diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index bc3884a5d2..ac24623eef 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,10 +1,11 @@
 exclude: ^.github/actions/assigner/dist
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v2.3.0
+    rev: v4.4.0
     hooks:
       - id: check-yaml
       - id: trailing-whitespace
+        exclude: ^docs
       - id: check-added-large-files
         args:
           - --maxkb=1000
@@ -13,18 +14,14 @@ repos:
       - id: mixed-line-ending
         args:
           - --fix=lf
-  - repo: https://github.com/psf/black
-    rev: 22.6.0
-    hooks:
-      - id: black
-        exclude: ^examples/custom_converters/elu_converter/setup.py
+        exclude: ^docs
   - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v14.0.6
+    rev: v16.0.6
     hooks:
       - id: clang-format
         types_or: [c++, c, cuda]
   - repo: https://github.com/keith/pre-commit-buildifier
-    rev: 5.1.0
+    rev: 6.1.0.2
     hooks:
       - id: buildifier
         args:
@@ -34,6 +31,26 @@ repos:
     rev: v0.13
     hooks:
       - id: validate-pyproject
+  - repo: https://github.com/pycqa/isort
+    rev: 5.12.0
+    hooks:
+      - id: isort
+        name: isort (python)
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: 'v1.4.1'
+    hooks:
+      - id: mypy
+        exclude: "^py/torch_tensorrt/fx|^examples|^tests|^tools|^docs|noxfile.py|setup.py|versions.py"
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: v0.0.278
+    hooks:
+      - id: ruff
+  - repo: https://github.com/psf/black
+    rev: 23.7.0
+    hooks:
+      - id: black
+        exclude: ^examples/custom_converters/elu_converter/setup.py|^docs
   - repo: local
     hooks:
       - id: dont-commit-upstream
@@ -42,4 +59,3 @@ repos:
         exclude: "^.pre-commit-config.yaml"
         language: pygrep
         types: [text]
-
diff --git a/BUILD b/BUILD
index 83e321c0b9..c40d52e0f9 100644
--- a/BUILD
+++ b/BUILD
@@ -47,8 +47,8 @@ pkg_tar(
         ":windows": ["//cpp/lib:torch_tensorrt.dll"],
         "//conditions:default": [
             "//cpp/lib:libtorchtrt.so",
-            "//cpp/lib:libtorchtrt_runtime.so",
             "//cpp/lib:libtorchtrt_plugins.so",
+            "//cpp/lib:libtorchtrt_runtime.so",
         ],
     }),
     mode = "0755",
@@ -74,9 +74,9 @@ pkg_tar(
     extension = "tar.gz",
     package_dir = "torch_tensorrt",
     deps = [
-        ":lib",
         ":include",
         ":include_core",
+        ":lib",
     ] + select({
         ":windows": [],
         "//conditions:default": [":bin"],
diff --git a/core/BUILD b/core/BUILD
index d802a8eff6..23ef0bfec0 100644
--- a/core/BUILD
+++ b/core/BUILD
@@ -27,14 +27,14 @@ cc_library(
    ],
    deps = [
        "//core/conversion",
-       "//core/runtime",
        "//core/lowering",
        "//core/partitioning",
+       "//core/runtime",
        "//core/util/logging",
        "@tensorrt//:nvinfer",
    ] + select({
        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
-       "//conditions:default": ["@libtorch//:libtorch"],
+       "//conditions:default": ["@libtorch"],
    }),
    alwayslink = True,
 )
diff --git a/core/conversion/BUILD b/core/conversion/BUILD
index c7d7b68934..318aebdfbf 100644
--- a/core/conversion/BUILD
+++ b/core/conversion/BUILD
@@ -20,16 +20,16 @@ cc_library(
        "conversion.h",
    ],
    deps = [
-       "@tensorrt//:nvinfer",
-       "//core/conversion/var",
        "//core/conversion/conversionctx",
        "//core/conversion/converters",
        "//core/conversion/evaluators",
+       "//core/conversion/var",
        "//core/ir",
        "//core/util:prelude",
+       "@tensorrt//:nvinfer",
    ] + select({
        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
-       "//conditions:default": ["@libtorch//:libtorch"],
+       "//conditions:default": ["@libtorch"],
    }),
 )
diff --git a/core/conversion/conversionctx/BUILD b/core/conversion/conversionctx/BUILD
index 6626ae4457..3c2dea1a3e 100644
--- a/core/conversion/conversionctx/BUILD
+++ b/core/conversion/conversionctx/BUILD
@@ -19,12 +19,12 @@ cc_library(
        "ConversionCtx.h",
    ],
    deps = [
-       "@tensorrt//:nvinfer",
-       "//core/util:prelude",
        "//core/ir",
+       "//core/util:prelude",
+       "@tensorrt//:nvinfer",
    ] + select({
        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
-       "//conditions:default": ["@libtorch//:libtorch"],
+       "//conditions:default": ["@libtorch"],
    }),
 )
diff --git a/core/conversion/converters/BUILD b/core/conversion/converters/BUILD
index 518c6a2ded..e184dd787e 100755
--- a/core/conversion/converters/BUILD
+++ b/core/conversion/converters/BUILD
@@ -19,12 +19,12 @@ cc_library(
        "Weights.h",
    ],
    deps = [
-       "@tensorrt//:nvinfer",
-       "//core/util:prelude",
        "//core/conversion/conversionctx",
+       "//core/util:prelude",
+       "@tensorrt//:nvinfer",
    ] + select({
        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
-       "//conditions:default": ["@libtorch//:libtorch"],
+       "//conditions:default": ["@libtorch"],
    }),
    alwayslink = True,
 )
@@ -39,12 +39,12 @@ cc_library(
    ],
    deps = [
        ":weights",
-       "@tensorrt//:nvinfer",
-       "//core/util:prelude",
        "//core/conversion/conversionctx",
+       "//core/util:prelude",
+       "@tensorrt//:nvinfer",
    ] + select({
        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
-       "//conditions:default": ["@libtorch//:libtorch"],
+       "//conditions:default": ["@libtorch"],
    }),
    alwayslink = True,
 )
@@ -91,16 +91,16 @@ cc_library(
        "converters.h",
    ],
    deps = [
-       "@tensorrt//:nvinfer",
-       "//core/util:prelude",
-       "//core/conversion/var",
-       "//core/conversion/tensorcontainer",
+       ":converter_util",
        "//core/conversion/conversionctx",
+       "//core/conversion/tensorcontainer",
+       "//core/conversion/var",
        "//core/plugins:torch_tensorrt_plugins",
-       ":converter_util",
+       "//core/util:prelude",
+       "@tensorrt//:nvinfer",
    ] + select({
        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
-       "//conditions:default": ["@libtorch//:libtorch"],
+       "//conditions:default": ["@libtorch"],
    }),
    alwayslink = True,
 )
diff --git a/core/conversion/evaluators/BUILD b/core/conversion/evaluators/BUILD
index bd2302f142..03789336be 100644
--- a/core/conversion/evaluators/BUILD
+++ b/core/conversion/evaluators/BUILD
@@ -24,12 +24,12 @@ cc_library(
        "evaluators.h",
    ],
    deps = [
-       "//core/util:prelude",
-       "//core/conversion/var",
        "//core/conversion/tensorcontainer",
+       "//core/conversion/var",
+       "//core/util:prelude",
    ] + select({
        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
-       "//conditions:default": ["@libtorch//:libtorch"],
+       "//conditions:default": ["@libtorch"],
    }),
    alwayslink = True,
 )
diff --git a/core/conversion/tensorcontainer/BUILD b/core/conversion/tensorcontainer/BUILD
index f2d508a3fe..187a84b787 100644
--- a/core/conversion/tensorcontainer/BUILD
+++ b/core/conversion/tensorcontainer/BUILD
@@ -19,11 +19,11 @@ cc_library(
        "TensorContainer.h",
    ],
    deps = [
-       "@tensorrt//:nvinfer",
        "//core/util:prelude",
+       "@tensorrt//:nvinfer",
    ] + select({
        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
-       "//conditions:default": ["@libtorch//:libtorch"],
+       "//conditions:default": ["@libtorch"],
    }),
    alwayslink = True,
 )
diff --git a/core/conversion/var/BUILD b/core/conversion/var/BUILD
index e06e8c64d4..b0df6b02f8 100644
--- a/core/conversion/var/BUILD
+++ b/core/conversion/var/BUILD
@@ -20,13 +20,13 @@ cc_library(
        "Var_inl.h",
    ],
    deps = [
-       "@tensorrt//:nvinfer",
-       "//core/util:prelude",
        "//core/conversion/converters:converter_util",
-       "//core/conversion/tensorcontainer:tensorcontainer",
+       "//core/conversion/tensorcontainer",
+       "//core/util:prelude",
+       "@tensorrt//:nvinfer",
    ] + select({
        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
["@libtorch_pre_cxx11_abi//:libtorch"], - "//conditions:default": ["@libtorch//:libtorch"], + "//conditions:default": ["@libtorch"], }), alwayslink = True, ) diff --git a/core/ir/BUILD b/core/ir/BUILD index 64de19d6de..1d4d4832b8 100644 --- a/core/ir/BUILD +++ b/core/ir/BUILD @@ -22,11 +22,11 @@ cc_library( "ir.h", ], deps = [ - "@tensorrt//:nvinfer", "//core/util:prelude", + "@tensorrt//:nvinfer", ] + select({ ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], - "//conditions:default": ["@libtorch//:libtorch"], + "//conditions:default": ["@libtorch"], }), ) diff --git a/core/lowering/BUILD b/core/lowering/BUILD index 9c08752bb5..81ad86e906 100644 --- a/core/lowering/BUILD +++ b/core/lowering/BUILD @@ -22,12 +22,12 @@ cc_library( "lowering.h", ], deps = [ + "//core/ir", "//core/lowering/passes", "//core/util:prelude", - "//core/ir", ] + select({ ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], - "//conditions:default": ["@libtorch//:libtorch"], + "//conditions:default": ["@libtorch"], }), alwayslink = True, ) diff --git a/core/lowering/passes/BUILD b/core/lowering/passes/BUILD index 2f6a15e650..939789402d 100644 --- a/core/lowering/passes/BUILD +++ b/core/lowering/passes/BUILD @@ -49,7 +49,7 @@ cc_library( "//core/util:prelude", ] + select({ ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], - "//conditions:default": ["@libtorch//:libtorch"], + "//conditions:default": ["@libtorch"], }), ) diff --git a/core/partitioning/BUILD b/core/partitioning/BUILD index 4204939684..2acfec2cc4 100644 --- a/core/partitioning/BUILD +++ b/core/partitioning/BUILD @@ -21,16 +21,16 @@ cc_library( "partitioning.h", ], deps = [ - "//core/util:prelude", - "//core/ir", "//core/conversion", + "//core/ir", "//core/lowering", "//core/partitioning/partitioningctx", "//core/partitioning/partitioninginfo", "//core/partitioning/segmentedblock", + "//core/util:prelude", ] + select({ ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], - "//conditions:default": ["@libtorch//:libtorch"], + "//conditions:default": ["@libtorch"], }), alwayslink = True, ) diff --git a/core/partitioning/partitioningctx/BUILD b/core/partitioning/partitioningctx/BUILD index 6895f8d451..dec1a1f992 100644 --- a/core/partitioning/partitioningctx/BUILD +++ b/core/partitioning/partitioningctx/BUILD @@ -19,14 +19,14 @@ cc_library( "PartitioningCtx.h", ], deps = [ - "//core/util:prelude", - "//core/ir", "//core/conversion", - "//core/partitioning/segmentedblock", + "//core/ir", "//core/partitioning/partitioninginfo", + "//core/partitioning/segmentedblock", + "//core/util:prelude", ] + select({ ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], - "//conditions:default": ["@libtorch//:libtorch"], + "//conditions:default": ["@libtorch"], }), alwayslink = True, ) diff --git a/core/partitioning/partitioninginfo/BUILD b/core/partitioning/partitioninginfo/BUILD index 74e34d134b..a4c1031295 100644 --- a/core/partitioning/partitioninginfo/BUILD +++ b/core/partitioning/partitioninginfo/BUILD @@ -19,13 +19,13 @@ cc_library( "PartitioningInfo.h", ], deps = [ - "//core/util:prelude", - "//core/ir", "//core/conversion", + "//core/ir", "//core/lowering", + "//core/util:prelude", ] + select({ ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], - "//conditions:default": ["@libtorch//:libtorch"], + "//conditions:default": ["@libtorch"], }), alwayslink = True, ) diff --git a/core/partitioning/segmentedblock/BUILD b/core/partitioning/segmentedblock/BUILD index 8efe1e6b0a..0ab6246c81 100644 --- 
+++ b/core/partitioning/segmentedblock/BUILD
@@ -19,13 +19,13 @@ cc_library(
        "SegmentedBlock.h",
    ],
    deps = [
-       "//core/util:prelude",
-       "//core/ir",
        "//core/conversion",
+       "//core/ir",
        "//core/lowering",
+       "//core/util:prelude",
    ] + select({
        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
-       "//conditions:default": ["@libtorch//:libtorch"],
+       "//conditions:default": ["@libtorch"],
    }),
    alwayslink = True,
 )
diff --git a/core/plugins/BUILD b/core/plugins/BUILD
index b720a683cb..6fcb0e8934 100644
--- a/core/plugins/BUILD
+++ b/core/plugins/BUILD
@@ -29,12 +29,12 @@ cc_library(
        "-lpthread",
    ],
    deps = [
+       "//core/util:prelude",
        "@tensorrt//:nvinfer",
        "@tensorrt//:nvinferplugin",
-       "//core/util:prelude",
    ] + select({
        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
-       "//conditions:default": ["@libtorch//:libtorch"],
+       "//conditions:default": ["@libtorch"],
    }),
    alwayslink = True,
 )
diff --git a/core/runtime/BUILD b/core/runtime/BUILD
index 669feda90e..ef70610a39 100644
--- a/core/runtime/BUILD
+++ b/core/runtime/BUILD
@@ -31,12 +31,12 @@ cc_library(
        "-lstdc++fs",
    ],
    deps = [
-       "@tensorrt//:nvinfer",
-       "//core/util:prelude",
        "//core/plugins:torch_tensorrt_plugins",
+       "//core/util:prelude",
+       "@tensorrt//:nvinfer",
    ] + select({
        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
-       "//conditions:default": ["@libtorch//:libtorch"],
+       "//conditions:default": ["@libtorch"],
    }),
    alwayslink = True,
 )
diff --git a/core/util/BUILD b/core/util/BUILD
index a221a6e080..e2f6684830 100644
--- a/core/util/BUILD
+++ b/core/util/BUILD
@@ -34,7 +34,7 @@ cc_library(
        ":macros",
    ] + select({
        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
-       "//conditions:default": ["@libtorch//:libtorch"],
+       "//conditions:default": ["@libtorch"],
    }),
 )
@@ -68,7 +68,7 @@ cc_library(
        "@tensorrt//:nvinfer",
    ] + select({
        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
-       "//conditions:default": ["@libtorch//:libtorch"],
+       "//conditions:default": ["@libtorch"],
    }),
 )
@@ -81,12 +81,12 @@ cc_library(
        "trt_util.h",
    ],
    deps = [
-       "@tensorrt//:nvinfer",
-       "//core/util/logging",
        ":macros",
+       "//core/util/logging",
+       "@tensorrt//:nvinfer",
    ] + select({
        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
-       "//conditions:default": ["@libtorch//:libtorch"],
+       "//conditions:default": ["@libtorch"],
    }),
 )
diff --git a/core/util/logging/BUILD b/core/util/logging/BUILD
index de826922d1..9aa91abee9 100644
--- a/core/util/logging/BUILD
+++ b/core/util/logging/BUILD
@@ -22,7 +22,7 @@ cc_library(
        "@tensorrt//:nvinfer",
    ] + select({
        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
-       "//conditions:default": ["@libtorch//:libtorch"],
+       "//conditions:default": ["@libtorch"],
    }),
 )
diff --git a/cpp/bin/torchtrtc/BUILD b/cpp/bin/torchtrtc/BUILD
index b95809dc20..ba0e4eb2c0 100644
--- a/cpp/bin/torchtrtc/BUILD
+++ b/cpp/bin/torchtrtc/BUILD
@@ -25,15 +25,15 @@ cc_binary(
        "-ldl",
    ],
    deps = [
-       "//third_party/args",
        "//cpp:torch_tensorrt",
+       "//third_party/args",
    ] + select({
        ":use_pre_cxx11_abi": [
-           "@libtorch_pre_cxx11_abi//:libtorch",
            "@libtorch_pre_cxx11_abi//:caffe2",
+           "@libtorch_pre_cxx11_abi//:libtorch",
        ],
        "//conditions:default": [
-           "@libtorch//:libtorch",
+           "@libtorch",
            "@libtorch//:caffe2",
        ],
    }),
diff --git a/docs/_downloads/e1ef5a42560a98a132f56a79d0b66f79/dynamo_compile_advanced_usage.py b/docs/_downloads/e1ef5a42560a98a132f56a79d0b66f79/dynamo_compile_advanced_usage.py
index a0b38ecc21..f73bd1e780 100644
--- a/docs/_downloads/e1ef5a42560a98a132f56a79d0b66f79/dynamo_compile_advanced_usage.py
+++ b/docs/_downloads/e1ef5a42560a98a132f56a79d0b66f79/dynamo_compile_advanced_usage.py
@@ -16,6 +16,7 @@
 
 # %%
 
+
 # We begin by defining a model
 class Model(torch.nn.Module):
     def __init__(self) -> None:
diff --git a/docs/v1.4.0/_downloads/e1ef5a42560a98a132f56a79d0b66f79/dynamo_compile_advanced_usage.py b/docs/v1.4.0/_downloads/e1ef5a42560a98a132f56a79d0b66f79/dynamo_compile_advanced_usage.py
index a0b38ecc21..f73bd1e780 100644
--- a/docs/v1.4.0/_downloads/e1ef5a42560a98a132f56a79d0b66f79/dynamo_compile_advanced_usage.py
+++ b/docs/v1.4.0/_downloads/e1ef5a42560a98a132f56a79d0b66f79/dynamo_compile_advanced_usage.py
@@ -16,6 +16,7 @@
 
 # %%
 
+
 # We begin by defining a model
 class Model(torch.nn.Module):
     def __init__(self) -> None:
diff --git a/docsrc/getting_started/getting_started_with_windows.rst b/docsrc/getting_started/getting_started_with_windows.rst
index 68b36dd939..edb3262d66 100644
--- a/docsrc/getting_started/getting_started_with_windows.rst
+++ b/docsrc/getting_started/getting_started_with_windows.rst
@@ -117,7 +117,7 @@ Building With Visual Studio Code
 
     e.g. /.vscode/settings.json
 
 .. code-block:: json
-    
+
     {
         "cmake.generator": "Ninja",
         "cmake.configureSettings": {
diff --git a/examples/dynamo/torch_compile_advanced_usage.py b/examples/dynamo/torch_compile_advanced_usage.py
index 1d301a16a8..96146a43d8 100644
--- a/examples/dynamo/torch_compile_advanced_usage.py
+++ b/examples/dynamo/torch_compile_advanced_usage.py
@@ -15,6 +15,7 @@
 
 # %%
 
+
 # We begin by defining a model
 class Model(torch.nn.Module):
     def __init__(self) -> None:
diff --git a/examples/fx/hugging_face_torchdynamo_example.py b/examples/fx/hugging_face_torchdynamo_example.py
index 388ccf2e47..a28159f32a 100644
--- a/examples/fx/hugging_face_torchdynamo_example.py
+++ b/examples/fx/hugging_face_torchdynamo_example.py
@@ -218,6 +218,7 @@ def check_correctness(args, mod, inputs, optimize_ctx, optimize_name):
 
     synchronize = torch.cuda.synchronize
 
+
     # timing function to record the repeated run time
     def timed(model, model_iter_fn, train_inputs, timings=1, return_result=False):
         synchronize()
diff --git a/examples/int8/training/vgg16/finetune_qat.py b/examples/int8/training/vgg16/finetune_qat.py
index 6ec20a9a46..0414af00de 100644
--- a/examples/int8/training/vgg16/finetune_qat.py
+++ b/examples/int8/training/vgg16/finetune_qat.py
@@ -184,7 +184,6 @@ def calibrate_model(
 
 
 def main():
-    global state
     global classes
     global writer
diff --git a/examples/torchtrt_runtime_example/network.py b/examples/torchtrt_runtime_example/network.py
index b82cd11378..5eac6399bc 100644
--- a/examples/torchtrt_runtime_example/network.py
+++ b/examples/torchtrt_runtime_example/network.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 import torch_tensorrt as torchtrt
 
+
 # create a simple norm layer.
 # This norm layer uses NormalizePlugin from Torch-TensorRT
 class Norm(torch.nn.Module):
@@ -27,7 +28,6 @@ def forward(self, x):
 
 
 def main():
-
     model = ConvGelu().eval().cuda()
     scripted_model = torch.jit.script(model)
diff --git a/py/torch_tensorrt/_Device.py b/py/torch_tensorrt/_Device.py
index 3eaa5aad4e..3ac276db13 100644
--- a/py/torch_tensorrt/_Device.py
+++ b/py/torch_tensorrt/_Device.py
@@ -1,13 +1,22 @@
+import sys
+from typing import Any, Optional, Tuple
+
+if sys.version_info >= (3, 11):
+    from typing import Self
+else:
+    from typing_extensions import Self
+
+import warnings
+
 import torch
+from torch_tensorrt import logging
 
 # from torch_tensorrt import _enums
 import tensorrt as trt
-from torch_tensorrt import logging
-import warnings
 
 try:
     from torch_tensorrt import _C
-except:
+except ImportError:
     warnings.warn(
         "Unable to import torchscript frontend core and torch-tensorrt runtime. Some dependent features may be unavailable."
     )
@@ -24,12 +33,14 @@ class Device(object):
         allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
     """
 
-    device_type = None  #: (torch_tensorrt.DeviceType): Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
-    gpu_id = -1  #: (int) Device ID for target GPU
-    dla_core = -1  #: (int) Core ID for target DLA core
-    allow_gpu_fallback = False  #: (bool) Whether falling back to GPU if DLA cannot support an op should be allowed
+    device_type: Optional[
+        trt.DeviceType
+    ] = None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    gpu_id: int = -1  #: Device ID for target GPU
+    dla_core: int = -1  #: Core ID for target DLA core
+    allow_gpu_fallback: bool = False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any):
         """__init__ Method for torch_tensorrt.Device
 
         Device accepts one of a few construction patterns
@@ -127,29 +138,28 @@ def _to_internal(self) -> _C.Device:
 
     def _to_serialized_rt_device(self) -> str:
         internal_dev = self._to_internal()
-        return internal_dev._to_serialized_rt_device()
+        serialized_rt_device: str = internal_dev._to_serialized_rt_device()
+        return serialized_rt_device
 
     @classmethod
-    def _from_torch_device(cls, torch_dev: torch.device):
+    def _from_torch_device(cls, torch_dev: torch.device) -> Self:
         if torch_dev.type != "cuda":
             raise ValueError('Torch Device specs must have type "cuda"')
         gpu_id = torch_dev.index
         return cls(gpu_id=gpu_id)
 
     @classmethod
-    def _current_device(cls):
-        try:
-            dev = _C._get_current_device()
-        except RuntimeError:
-            logging.log(logging.Level.Error, "Cannot get current device")
-            return None
+    def _current_device(cls) -> Self:
+        dev = _C._get_current_device()
         return cls(gpu_id=dev.gpu_id)
 
     @staticmethod
-    def _parse_device_str(s):
+    def _parse_device_str(s: str) -> Tuple[trt.DeviceType, int]:
         s = s.lower()
         spec = s.split(":")
         if spec[0] == "gpu" or spec[0] == "cuda":
             return (trt.DeviceType.GPU, int(spec[1]))
         elif spec[0] == "dla":
             return (trt.DeviceType.DLA, int(spec[1]))
+        else:
+            raise ValueError(f"Unknown device type {spec[0]}")
diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py
index 8d3a842a47..796b8c6253 100644
--- a/py/torch_tensorrt/_Input.py
+++ b/py/torch_tensorrt/_Input.py
@@ -1,8 +1,7 @@
 from enum import Enum
-from typing import List, Dict, Any, Tuple, Optional
+from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 import torch
-
 from torch_tensorrt import _enums
@@ -27,22 +26,26 @@ class _ShapeMode(Enum):
         STATIC = 0
         DYNAMIC = 1
 
-    shape_mode = None  #: (torch_tensorrt.Input._ShapeMode): Is input statically or dynamically shaped
-    shape = None  #: (Tuple or Dict): Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
-    dtype = (
+    shape_mode: Optional[
+        _ShapeMode
+    ] = None  #: Is input statically or dynamically shaped
+    shape: Optional[
+        Tuple[int, ...] | Dict[str, Tuple[int, ...]]
+    ] = None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    dtype: _enums.dtype = (  # type: ignore[name-defined]
         _enums.dtype.unknown
     )  #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
-    _explicit_set_dtype = False
-    format = (
+    _explicit_set_dtype: bool = False
+    format: _enums.TensorFormat = (  # type: ignore[name-defined]
         _enums.TensorFormat.contiguous
     )  #: The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
 
-    DOMAIN_OFFSET = 2.0
-    low_tensor_domain_incl = 0.0
-    high_tensor_domain_excl = low_tensor_domain_incl + DOMAIN_OFFSET
-    torch_dtype = torch.float32
+    DOMAIN_OFFSET: float = 2.0
+    low_tensor_domain_incl: float = 0.0
+    high_tensor_domain_excl: float = low_tensor_domain_incl + DOMAIN_OFFSET
+    torch_dtype: torch.dtype = torch.float32
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         """__init__ Method for torch_tensorrt.Input
 
         Input accepts one of a few construction patterns
@@ -93,7 +96,7 @@ def __init__(self, *args, **kwargs):
                 self.shape_mode = Input._ShapeMode.STATIC
 
         elif len(args) == 0:
-            if not ("shape" in kwargs) and not (
+            if "shape" not in kwargs and not (
                 all(k in kwargs for k in ["min_shape", "opt_shape", "max_shape"])
             ):
                 raise ValueError(
@@ -176,15 +179,20 @@ def __str__(self) -> str:
                 str(self.tensor_domain[1]),
             )
         elif self.shape_mode == Input._ShapeMode.DYNAMIC:
-            return "Input(min_shape={}, opt_shape={}, max_shape={}, dtype={}, format={}, domain=[{}, {}))".format(
-                self.shape["min_shape"],
-                self.shape["opt_shape"],
-                self.shape["max_shape"],
-                str(self.dtype),
-                str(self.format),
-                str(self.tensor_domain[0]),
-                str(self.tensor_domain[1]),
-            )
+            if isinstance(self.shape, dict):
+                return "Input(min_shape={}, opt_shape={}, max_shape={}, dtype={}, format={}, domain=[{}, {}))".format(
+                    self.shape["min_shape"],
+                    self.shape["opt_shape"],
+                    self.shape["max_shape"],
+                    str(self.dtype),
+                    str(self.format),
+                    str(self.tensor_domain[0]),
+                    str(self.tensor_domain[1]),
+                )
+            else:
+                raise RuntimeError(
+                    f"Input shape is dynamic but shapes are not provided as dictionary (found: {self.shape})"
+                )
         else:
             raise RuntimeError("Unknown input shape mode")
 
@@ -200,7 +208,7 @@ def _supported_input_size_type(input_size: Any) -> bool:
             return False
 
     @staticmethod
-    def _parse_dtype(dtype: Any) -> _enums.dtype:
+    def _parse_dtype(dtype: Any) -> _enums.dtype:  # type: ignore[name-defined]
         if isinstance(dtype, torch.dtype):
             if dtype == torch.long:
                 return _enums.dtype.long
@@ -228,7 +236,7 @@ def _parse_dtype(dtype: Any) -> _enums.dtype:
             )
 
     @staticmethod
-    def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype:
+    def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype:  # type: ignore[name-defined]
         if dtype == _enums.dtype.long:
             return torch.long
         elif dtype == _enums.dtype.int32:
@@ -244,10 +252,10 @@ def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype:
             return torch.float32
 
     def is_trt_dtype(self) -> bool:
-        return self.dtype != _enums.dtype.long
+        return bool(self.dtype != _enums.dtype.long)
 
     @staticmethod
-    def _parse_format(format: Any) -> _enums.TensorFormat:
+    def _parse_format(format: Any) -> _enums.TensorFormat:  # type: ignore[name-defined]
         if isinstance(format, torch.memory_format):
             if format == torch.contiguous_format:
                 return _enums.TensorFormat.contiguous
@@ -267,7 +275,9 @@ def _parse_format(format: Any) -> _enums.TensorFormat:
         )
 
     @staticmethod
-    def _parse_tensor_domain(domain: Optional[Tuple[float, float]]) -> Tuple:
+    def _parse_tensor_domain(
+        domain: Optional[Tuple[float, float]]
+    ) -> Tuple[float, float]:
         """
         Produce a tuple of integers which specifies a tensor domain in the interval format: [lo, hi)
 
@@ -287,8 +297,8 @@ def _parse_tensor_domain(domain: Optional[Tuple[float, float]]) -> Tuple:
             domain_lo, domain_hi = domain
 
             # Validate type and provided values for domain
-            valid_type_lo = isinstance(domain_lo, int) or isinstance(domain_lo, float)
-            valid_type_hi = isinstance(domain_hi, int) or isinstance(domain_hi, float)
+            valid_type_lo = isinstance(domain_lo, (int, float))
+            valid_type_hi = isinstance(domain_hi, (int, float))
 
             if not valid_type_lo:
                 raise ValueError(
@@ -346,7 +356,7 @@ def from_tensor(
 
     @classmethod
     def from_tensors(
-        cls, ts: torch.Tensor, disable_memory_format_check: bool = False
+        cls, ts: Sequence[torch.Tensor], disable_memory_format_check: bool = False
     ) -> List["Input"]:
         """
         Produce a list of Inputs which contain
@@ -366,7 +376,9 @@ def from_tensors(
             for t in ts
         ]
 
-    def example_tensor(self, optimization_profile_field: str = None) -> torch.Tensor:
+    def example_tensor(
+        self, optimization_profile_field: Optional[str] = None
+    ) -> torch.Tensor:
         """
         Get an example tensor of the shape specified by the Input object
 
@@ -376,38 +388,41 @@ def example_tensor(self, optimization_profile_field: str = None) -> torch.Tensor:
         Returns:
             A PyTorch Tensor
         """
-        if optimization_profile_field is not None:
-            try:
-                assert any(
-                    [
-                        optimization_profile_field == field_name
-                        for field_name in ["min_shape", "opt_shape", "max_shape"]
-                    ]
-                )
-            except:
+        if self.shape_mode == Input._ShapeMode.STATIC:
+            if optimization_profile_field is not None:
                 raise ValueError(
-                    "Invalid field name, expected one of min_shape, opt_shape, max_shape"
+                    "Specified a optimization profile field but the input is static"
                 )
+            else:
+                if isinstance(self.shape, tuple):
+                    return torch.rand(self.shape).to(dtype=self.torch_dtype)
+                else:
+                    RuntimeError(
+                        f"Input shape is dynamic but shapes are not provided as sequence (found: {self.shape})"
+                    )
+        else:
+            if optimization_profile_field is not None:
+                try:
+                    assert any(
+                        optimization_profile_field == field_name
+                        for field_name in ["min_shape", "opt_shape", "max_shape"]
+                    )
+                except AssertionError:
+                    raise ValueError(
+                        "Invalid field name, expected one of min_shape, opt_shape, max_shape"
+                    )
 
-        if (
-            optimization_profile_field is not None
-            and self.shape_mode == Input._ShapeMode.STATIC
-        ):
-            raise ValueError(
-                "Specified a optimization profile field but the input is static"
-            )
-
-        if (
-            optimization_profile_field is None
-            and self.shape_mode == Input._ShapeMode.DYNAMIC
-        ):
-            raise ValueError(
-                "Requested an example tensor from a dynamic shaped input but did not specific which profile field to use."
-            )
+                if isinstance(self.shape, dict):
+                    return torch.rand(self.shape[optimization_profile_field]).to(
+                        dtype=self.torch_dtype
+                    )
+                else:
+                    raise RuntimeError(
+                        f"Input shape is dynamic but shapes are not provided as dictionary (found: {self.shape})"
+                    )
 
-        if self.shape_mode == Input._ShapeMode.STATIC:
-            return torch.rand(self.shape).to(dtype=self.torch_dtype)
-        else:
-            return torch.rand(self.shape[optimization_profile_field]).to(
-                dtype=self.torch_dtype
-            )
+            else:
+                raise ValueError(
+                    "Requested an example tensor from a dynamic shaped input but did not specific which profile field to use."
+                )
+        raise
diff --git a/py/torch_tensorrt/__init__.py b/py/torch_tensorrt/__init__.py
index 7743aec3f1..1e9ce7e129 100644
--- a/py/torch_tensorrt/__init__.py
+++ b/py/torch_tensorrt/__init__.py
@@ -1,24 +1,24 @@
 import ctypes
-import glob
 import os
-import sys
 import platform
-import warnings
-from packaging import version
+import sys
+from typing import Dict, List
+
 from torch_tensorrt._version import (
-    __version__,
     __cuda_version__,
     __cudnn_version__,
     __tensorrt_version__,
 )
 
+from packaging import version
+
 if sys.version_info < (3,):
     raise Exception(
         "Python 2 has reached end-of-life and is not supported by Torch-TensorRT"
     )
 
 
-def _parse_semver(version):
+def _parse_semver(version: str) -> Dict[str, str]:
     split = version.split(".")
     if len(split) < 3:
         split.append("")
@@ -26,7 +26,7 @@ def _parse_semver(version):
     return {"major": split[0], "minor": split[1], "patch": split[2]}
 
 
-def _find_lib(name, paths):
+def _find_lib(name: str, paths: List[str]) -> str:
     for path in paths:
         libpath = os.path.join(path, name)
         if os.path.isfile(libpath):
@@ -36,8 +36,8 @@ def _find_lib(name, paths):
 
 
 try:
-    import tensorrt
-except:
+    import tensorrt  # noqa: F401
+except ImportError:
     cuda_version = _parse_semver(__cuda_version__)
     cudnn_version = _parse_semver(__cudnn_version__)
     tensorrt_version = _parse_semver(__tensorrt_version__)
@@ -82,24 +82,18 @@ def _find_lib(name, paths):
             ctypes.CDLL(_find_lib(lib, LINUX_PATHS))
 
 import torch
-
-from torch_tensorrt._compile import *
-from torch_tensorrt._util import *
-from torch_tensorrt import ts
-from torch_tensorrt import ptq
-from torch_tensorrt._enums import *
-from torch_tensorrt import logging
-from torch_tensorrt._Input import Input
-from torch_tensorrt._Device import Device
-
-from torch_tensorrt import fx
+from torch_tensorrt._compile import *  # noqa: F403
+from torch_tensorrt._Device import Device  # noqa: F401
+from torch_tensorrt._enums import *  # noqa: F403
+from torch_tensorrt._Input import Input  # noqa: F401
+from torch_tensorrt._utils import *  # noqa: F403
 
 if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
-    from torch_tensorrt import dynamo
-    from torch_tensorrt.dynamo import backend
+    from torch_tensorrt import dynamo  # noqa: F401
+    from torch_tensorrt.dynamo import backend  # noqa: F401
 
 
-def _register_with_torch():
+def _register_with_torch() -> None:
     trtorch_dir = os.path.dirname(__file__)
     torch.ops.load_library(trtorch_dir + "/lib/libtorchtrt.so")
diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py
index e83acf166f..8a95d6eada 100644
--- a/py/torch_tensorrt/_compile.py
+++ b/py/torch_tensorrt/_compile.py
@@ -1,13 +1,29 @@
-from typing import List, Dict, Any
-import torch_tensorrt.ts
+from enum import Enum
+from typing import Any, Callable, List, Optional, Sequence, Set, TypeGuard
 
-from torch_tensorrt import logging
 import torch
 import torch.fx
-from enum import Enum
-
-import torch_tensorrt.fx
+import torch_tensorrt.ts
+from torch_tensorrt import logging
+from torch_tensorrt._enums import dtype
+from torch_tensorrt._Input import Input
+from torch_tensorrt.dynamo.compile import compile as dynamo_compile
+from torch_tensorrt.fx import InputTensorSpec
+from torch_tensorrt.fx.lower import compile as fx_compile
 from torch_tensorrt.fx.utils import LowerPrecision
+from torch_tensorrt.ts._compiler import compile as torchscript_compile
+
+
+def _non_fx_input_interface(
+    inputs: Sequence[Input | torch.Tensor | InputTensorSpec],
+) -> TypeGuard[List[Input | torch.Tensor]]:
+    return all(isinstance(i, torch.Tensor | Input) for i in inputs)
+
+
+def _fx_input_interface(
+    inputs: Sequence[Input | torch.Tensor | InputTensorSpec],
+) -> TypeGuard[List[InputTensorSpec | torch.Tensor]]:
+    return all(isinstance(i, torch.Tensor | InputTensorSpec) for i in inputs)
 
 
 class _IRType(Enum):
@@ -42,10 +58,10 @@ def _parse_module_type(module: Any) -> _ModuleType:
 
 
 def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
-    module_is_tsable = any([module_type == t for t in [_ModuleType.nn, _ModuleType.ts]])
-    module_is_fxable = any([module_type == t for t in [_ModuleType.nn, _ModuleType.fx]])
+    module_is_tsable = any(module_type == t for t in [_ModuleType.nn, _ModuleType.ts])
+    module_is_fxable = any(module_type == t for t in [_ModuleType.nn, _ModuleType.fx])
 
-    ir_targets_torchscript = any([ir == opt for opt in ["torchscript", "ts"]])
+    ir_targets_torchscript = any(ir == opt for opt in ["torchscript", "ts"])
     ir_targets_fx = ir == "fx"
     ir_targets_dynamo = ir == "dynamo"
     ir_targets_torch_compile = ir == "torch_compile"
@@ -80,10 +96,12 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
 
 def compile(
     module: Any,
-    ir="default",
-    inputs=[],
-    enabled_precisions=set([torch.float]),
-    **kwargs,
+    ir: str = "default",
+    inputs: Optional[Sequence[Input | torch.Tensor | InputTensorSpec]] = None,
+    enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
+    **kwargs: Any,
+) -> (
+    torch.nn.Module | torch.jit.ScriptModule | torch.fx.GraphModule | Callable[..., Any]
 ):
     """Compile a PyTorch module for NVIDIA GPUs using TensorRT
 
@@ -120,6 +138,11 @@ def compile(
     Returns:
         torch.nn.Module: Compiled Module, when run it will execute via TensorRT
     """
+    input_list = inputs if inputs is not None else []
+    enabled_precisions_set = (
+        enabled_precisions if enabled_precisions is not None else {torch.float}
+    )
+
     module_type = _parse_module_type(module)
     target_ir = _get_target_ir(module_type, ir)
     if target_ir == _IRType.ts:
@@ -130,55 +153,65 @@ def compile(
             "Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript",
         )
         ts_mod = torch.jit.script(module)
-        return torch_tensorrt.ts.compile(
-            ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
+        assert _non_fx_input_interface(input_list)
+        compiled_ts_module: torch.jit.ScriptModule = torchscript_compile(
+            ts_mod,
+            inputs=input_list,
+            enabled_precisions=enabled_precisions_set,
+            **kwargs,
         )
+        return compiled_ts_module
     elif target_ir == _IRType.fx:
         if (
-            torch.float16 in enabled_precisions
-            or torch_tensorrt.dtype.half in enabled_precisions
+            torch.float16 in enabled_precisions_set
+            or torch_tensorrt.dtype.half in enabled_precisions_set
         ):
             lower_precision = LowerPrecision.FP16
         elif (
-            torch.float32 in enabled_precisions
-            or torch_tensorrt.dtype.float in enabled_precisions
+            torch.float32 in enabled_precisions_set
+            or torch_tensorrt.dtype.float in enabled_precisions_set
         ):
             lower_precision = LowerPrecision.FP32
         else:
-            raise ValueError(f"Precision {enabled_precisions} not supported on FX")
+            raise ValueError(f"Precision {enabled_precisions_set} not supported on FX")
 
-        return torch_tensorrt.fx.compile(
+        assert _fx_input_interface(input_list)
+        compiled_fx_module: torch.nn.Module = fx_compile(
             module,
-            inputs,
+            input_list,
             lower_precision=lower_precision,
-            max_batch_size=inputs[0].size(0),
             explicit_batch_dimension=True,
             dynamic_batch=False,
             **kwargs,
         )
+        return compiled_fx_module
     elif target_ir == _IRType.dynamo:
-        from torch_tensorrt import Device
-        from torch_tensorrt.dynamo.utils import prepare_inputs, prepare_device
         import collections.abc
 
+        from torch_tensorrt import Device
+        from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs
+
         if not isinstance(inputs, collections.abc.Sequence):
             inputs = [inputs]
         device = kwargs.get("device", Device._current_device())
         torchtrt_inputs, torch_inputs = prepare_inputs(inputs, prepare_device(device))
         module = torch_tensorrt.dynamo.trace(module, torch_inputs, **kwargs)
-        return torch_tensorrt.dynamo.compile(
+        compiled_aten_module: torch.fx.GraphModule = dynamo_compile(
             module,
-            inputs=inputs,
-            enabled_precisions=enabled_precisions,
+            inputs=input_list,
+            enabled_precisions=enabled_precisions_set,
             **kwargs,
         )
+        return compiled_aten_module
     elif target_ir == _IRType.torch_compile:
-        return torch_compile(module, enabled_precisions=enabled_precisions, **kwargs)
+        return torch_compile(
+            module, enabled_precisions=enabled_precisions_set, **kwargs
+        )
     else:
         raise RuntimeError("Module is an unknown format or the ir requested is unknown")
 
 
-def torch_compile(module, **kwargs):
+def torch_compile(module: torch.nn.Module, **kwargs: Any) -> Any:
     """
     Returns a boxed model which is the output of torch.compile.
     This does not compile the model to TRT. Execute this model on
@@ -193,12 +226,12 @@ def torch_compile(module, **kwargs):
 
 def convert_method_to_trt_engine(
     module: Any,
-    method_name: str,
-    ir="default",
-    inputs=[],
-    enabled_precisions=set([torch.float]),
-    **kwargs,
-):
+    method_name: str = "forward",
+    inputs: Optional[Sequence[Input | torch.Tensor]] = None,
+    ir: str = "default",
+    enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
+    **kwargs: Any,
+) -> bytes:
     """Convert a TorchScript module method to a serialized TensorRT engine
 
     Converts a specified method of a module to a serialized TensorRT engine given a dictionary of conversion settings
@@ -229,6 +262,10 @@ def convert_method_to_trt_engine(
     Returns:
         bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
     """
+    enabled_precisions_set = (
+        enabled_precisions if enabled_precisions is not None else {torch.float}
+    )
+
     module_type = _parse_module_type(module)
     target_ir = _get_target_ir(module_type, ir)
     if target_ir == _IRType.ts:
@@ -239,11 +276,11 @@ def convert_method_to_trt_engine(
             "Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript",
         )
         ts_mod = torch.jit.script(module)
-        return torch_tensorrt.ts.convert_method_to_trt_engine(
+        return torch_tensorrt.ts.convert_method_to_trt_engine(  # type: ignore[no-any-return]
             ts_mod,
-            method_name,
             inputs=inputs,
-            enabled_precisions=enabled_precisions,
+            method_name=method_name,
+            enabled_precisions=enabled_precisions_set,
             **kwargs,
         )
     elif target_ir == _IRType.fx:
diff --git a/py/torch_tensorrt/_enums.py b/py/torch_tensorrt/_enums.py
index 63dffceb9d..44cb772dc3 100644
--- a/py/torch_tensorrt/_enums.py
+++ b/py/torch_tensorrt/_enums.py
@@ -1,2 +1,3 @@
-from torch_tensorrt._C import dtype, EngineCapability, TensorFormat
-from tensorrt import DeviceType
+from torch_tensorrt._C import EngineCapability, TensorFormat, dtype  # noqa: F401
+
+from tensorrt import DeviceType  # noqa: F401
diff --git a/py/torch_tensorrt/_util.py b/py/torch_tensorrt/_utils.py
similarity index 73%
rename from py/torch_tensorrt/_util.py
rename to py/torch_tensorrt/_utils.py
index 6f7e1a6c83..b21696427b 100644
--- a/py/torch_tensorrt/_util.py
+++ b/py/torch_tensorrt/_utils.py
@@ -1,10 +1,11 @@
-from torch_tensorrt import __version__
-from torch_tensorrt import _C
+from typing import Any
 
 import torch
+from torch_tensorrt import _C
+from torch_tensorrt._version import __version__
 
 
-def dump_build_info():
+def dump_build_info() -> None:
     """Prints build information about the torch_tensorrt distribution to stdout"""
     print(get_build_info())
 
@@ -15,24 +16,24 @@ def get_build_info() -> str:
     Returns:
         str: String containing the build information for torch_tensorrt distribution
     """
-    build_info = _C.get_build_info()
-    build_info = (
+    core_build_info = _C.get_build_info()
+    build_info = str(
         "Torch-TensorRT Version: "
         + str(__version__)
         + "\n"
         + "Using PyTorch Version: "
         + str(torch.__version__)
         + "\n"
-        + build_info
+        + core_build_info
     )
     return build_info
 
 
-def set_device(gpu_id):
+def set_device(gpu_id: int) -> None:
     _C.set_device(gpu_id)
 
 
-def sanitized_torch_version() -> str:
+def sanitized_torch_version() -> Any:
     return (
         torch.__version__
         if ".nv" not in torch.__version__
diff --git a/py/torch_tensorrt/dynamo/_SourceIR.py b/py/torch_tensorrt/dynamo/_SourceIR.py
index c0547986c4..8e68ed0f0f 100644
--- a/py/torch_tensorrt/dynamo/_SourceIR.py
+++ b/py/torch_tensorrt/dynamo/_SourceIR.py
@@ -9,7 +9,7 @@ class SourceIR(Enum):
     TORCHTRT_LOWERED = auto()
     UNKNOWN = auto()
 
-    def __str__(self):
+    def __str__(self) -> str:
         if self == SourceIR.NN:
             return "nn"
         elif self == SourceIR.ACC:
diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/__init__.py
index cf9344d94a..49d8448bff 100644
--- a/py/torch_tensorrt/dynamo/__init__.py
+++ b/py/torch_tensorrt/dynamo/__init__.py
@@ -1,13 +1,12 @@
+from torch_tensorrt._utils import sanitized_torch_version
+
 from packaging import version
-from torch_tensorrt._util import sanitized_torch_version
 
 if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
-    from ._settings import *
-    from .conversion import *
-    from .aten_tracer import trace
-    from .conversion.converter_registry import (
-        DYNAMO_CONVERTERS,
-        dynamo_tensorrt_converter,
-    )
-    from .compile import compile
-    from ._SourceIR import SourceIR
+    from ._settings import *  # noqa: F403
+    from ._SourceIR import SourceIR  # noqa: F403
+    from .aten_tracer import trace  # noqa: F403
+    from .compile import compile  # noqa: F403
+    from .conversion import *  # noqa: F403
+    from .conversion.converter_registry import DYNAMO_CONVERTERS  # noqa: F403
+    from .conversion.converter_registry import dynamo_tensorrt_converter  # noqa: F403
diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py
index 88b2a1e8b2..4c75a38c66 100644
--- a/py/torch_tensorrt/dynamo/_defaults.py
+++ b/py/torch_tensorrt/dynamo/_defaults.py
@@ -8,5 +8,5 @@
 MAX_AUX_STREAMS = None
 VERSION_COMPATIBLE = False
 OPTIMIZATION_LEVEL = None
-USE_PYTHON_RUNTIME = None
 TRUNCATE_LONG_AND_DOUBLE = False
+USE_PYTHON_RUNTIME = False
diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py
index 99ec34ec27..e0eef45eb2 100644
--- a/py/torch_tensorrt/dynamo/_settings.py
+++ b/py/torch_tensorrt/dynamo/_settings.py
@@ -1,17 +1,18 @@
 from dataclasses import dataclass, field
-from typing import Optional, Sequence
+from typing import Optional, Set
+
 import torch
 from torch_tensorrt.dynamo._defaults import (
-    PRECISION,
     DEBUG,
-    WORKSPACE_SIZE,
-    MIN_BLOCK_SIZE,
-    PASS_THROUGH_BUILD_FAILURES,
     MAX_AUX_STREAMS,
-    VERSION_COMPATIBLE,
+    MIN_BLOCK_SIZE,
     OPTIMIZATION_LEVEL,
-    USE_PYTHON_RUNTIME,
+    PASS_THROUGH_BUILD_FAILURES,
+    PRECISION,
     TRUNCATE_LONG_AND_DOUBLE,
+    USE_PYTHON_RUNTIME,
+    VERSION_COMPATIBLE,
+    WORKSPACE_SIZE,
 )
 
 
@@ -21,7 +22,7 @@ class CompilationSettings:
     debug: bool = DEBUG
     workspace_size: int = WORKSPACE_SIZE
     min_block_size: int = MIN_BLOCK_SIZE
-    torch_executed_ops: Sequence[str] = field(default_factory=set)
+    torch_executed_ops: Set[str] = field(default_factory=set)
     pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES
     max_aux_streams: Optional[int] = MAX_AUX_STREAMS
     version_compatible: bool = VERSION_COMPATIBLE
diff --git a/py/torch_tensorrt/dynamo/aten_tracer.py b/py/torch_tensorrt/dynamo/aten_tracer.py
index 74c7d151ef..3ff41bac3d 100644
--- a/py/torch_tensorrt/dynamo/aten_tracer.py
+++ b/py/torch_tensorrt/dynamo/aten_tracer.py
@@ -1,13 +1,12 @@
 import copy
 import sys
 from contextlib import contextmanager
-from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union
-from packaging import version
+from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
 
 import torch
 import torch._dynamo as torchdynamo
-
-from torch_tensorrt.fx.utils import req_torch_version
+from torch.fx.passes.infra.pass_base import PassResult
+from torch_tensorrt.dynamo.utils import req_torch_version
 from torch_tensorrt.fx.passes.lower_basic_pass_aten import (
     compose_bmm,
     compose_chunk,
@@ -23,11 +22,7 @@
 )
 from typing_extensions import TypeAlias
 
-Value: TypeAlias = Union[
-    Tuple["Value", ...],
-    List["Value"],
-    Dict[str, "Value"],
-]
+Value: TypeAlias = Tuple["Value", ...] | List["Value"] | Dict[str, "Value"]
 
 
 class DynamoConfig:
@@ -43,7 +38,6 @@ def __init__(
         specialize_int: bool = True,
         verbose: bool = True,
     ) -> None:
-
         self.capture_scalar_outputs = capture_scalar_outputs
         self.guard_nn_modules = guard_nn_modules
         self.dynamic_shapes = dynamic_shapes
@@ -97,7 +91,7 @@ def dynamo_trace(
     aten_graph: bool,
     tracing_mode: str = "real",
     dynamo_config: Optional[DynamoConfig] = None,
-) -> Tuple[torch.fx.GraphModule, Set]:
+) -> Any:  # Tuple[torch.fx.GraphModule, Set[_guards.Guard]]:
     """
     TODO: Once we fully migrate to torchdynamo frontend, we will remove
     this config option alltogether.  For now, it helps with quick
@@ -126,7 +120,11 @@ def dynamo_trace(
 
 
 @req_torch_version("2.dev")
-def trace(model, inputs, **kwargs):
+def trace(
+    model: torch.nn.Module | torch.fx.GraphModule,
+    inputs: Tuple[Any, ...],
+    **kwargs: Any,
+) -> torch.fx.GraphModule:
     """
     Optimized trace with necessary passes which re-compose some ops or replace some ops
     These passes should be general and functional purpose
diff --git a/py/torch_tensorrt/dynamo/backend/__init__.py b/py/torch_tensorrt/dynamo/backend/__init__.py
index 596ff92589..aafcc5f9b0 100644
--- a/py/torch_tensorrt/dynamo/backend/__init__.py
+++ b/py/torch_tensorrt/dynamo/backend/__init__.py
@@ -1 +1 @@
-from .backends import torch_tensorrt_backend
+from .backends import torch_tensorrt_backend  # noqa: F401
diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index a89999b930..ca14ad264b 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -1,45 +1,37 @@
 import logging
-from typing import Sequence
-import torch
 from functools import partial
-import torch._dynamo as td
+from typing import Any, Callable, Sequence
 
+import torch
+import torch._dynamo as td
+from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
 from torch_tensorrt.dynamo import CompilationSettings
-from torch_tensorrt.dynamo.lowering._decompositions import (
-    get_decompositions,
-)
-from torch_tensorrt.dynamo.lowering._pre_aot_lowering import (
-    pre_aot_substitutions,
-)
-from torch_tensorrt.dynamo.lowering._partition import (
-    partition,
-    get_submod_inputs,
-)
-from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
 from torch_tensorrt.dynamo.conversion import (
     convert_module,
     repair_long_or_double_inputs,
 )
-
-from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
-
+from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
+from torch_tensorrt.dynamo.lowering._partition import get_submod_inputs, partition
+from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
+from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
 
 logger = logging.getLogger(__name__)
 
 
-@td.register_backend(name="torch_tensorrt")
+@td.register_backend(name="torch_tensorrt")  # type: ignore[misc]
 def torch_tensorrt_backend(
-    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs
-):
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs: Any
+) -> torch.nn.Module:
     DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend
 
-    return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
+    compiled_mod: torch.nn.Module = DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
+    return compiled_mod
 
 
-@td.register_backend(name="aot_torch_tensorrt_aten")
+@td.register_backend(name="aot_torch_tensorrt_aten")  # type: ignore[misc]
 def aot_torch_tensorrt_aten_backend(
-    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs
-):
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs: Any
+) -> torch.nn.Module:
     settings = parse_dynamo_kwargs(kwargs)
 
     custom_backend = partial(
@@ -63,7 +55,7 @@ def _pretraced_backend(
     gm: torch.fx.GraphModule,
     sample_inputs: Sequence[torch.Tensor],
     settings: CompilationSettings = CompilationSettings(),
-):
+) -> torch.fx.GraphModule | Callable[..., Any]:
     """Helper function to manage translation of traced FX module to TRT engines
 
     Args:
@@ -82,7 +74,7 @@ def _pretraced_backend(
             settings=settings,
         )
         return trt_compiled
-    except:
+    except AssertionError:
         if not settings.pass_through_build_failures:
             logger.warning(
                 "TRT conversion failed on the subgraph. See trace above. "
@@ -138,6 +130,7 @@ def _compile_module(
             partitioned_module, submodule, sample_inputs
         )
 
+        assert submodule_inputs is not None
         # Handle long/double inputs if requested by the user
         if settings.truncate_long_and_double:
             submodule_inputs = repair_long_or_double_inputs(
diff --git a/py/torch_tensorrt/dynamo/compile.py b/py/torch_tensorrt/dynamo/compile.py
index b27dcb45ee..0402a6af43 100644
--- a/py/torch_tensorrt/dynamo/compile.py
+++ b/py/torch_tensorrt/dynamo/compile.py
@@ -1,38 +1,36 @@
-import torch
-import logging
 import collections.abc
-import torch_tensorrt
-from functools import partial
+import logging
+from typing import Any, List, Optional, Set, Tuple
 
-from typing import Any, Optional, Sequence
-from torch_tensorrt import EngineCapability, Device
+import torch
+import torch_tensorrt
 from torch.fx.passes.pass_manager import PassManager
-from torch.fx.passes.shape_prop import ShapeProp
 from torch.fx.passes.splitter_base import SplitResult
-from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
-from torch_tensorrt.dynamo.lowering import (
-    fuse_permute_linear,
-    fuse_permute_matmul,
+from torch_tensorrt._Device import Device
+from torch_tensorrt._enums import (  # TODO: Should probabably be the TRT EngineCapability Enum
+    EngineCapability,
 )
 from torch_tensorrt.dynamo import CompilationSettings
-from torch_tensorrt.dynamo.utils import prepare_inputs, prepare_device
-from torch_tensorrt.dynamo.backend import torch_tensorrt_backend
-from torch_tensorrt.dynamo.backend.backends import _compile_module
-from torch_tensorrt.dynamo.conversion import convert_module
-
 from torch_tensorrt.dynamo._defaults import (
-    PRECISION,
     DEBUG,
-    WORKSPACE_SIZE,
-    MIN_BLOCK_SIZE,
-    PASS_THROUGH_BUILD_FAILURES,
     MAX_AUX_STREAMS,
-    VERSION_COMPATIBLE,
+    MIN_BLOCK_SIZE,
     OPTIMIZATION_LEVEL,
-    USE_PYTHON_RUNTIME,
+    PASS_THROUGH_BUILD_FAILURES,
+    PRECISION,
     TRUNCATE_LONG_AND_DOUBLE,
+    USE_PYTHON_RUNTIME,
+    VERSION_COMPATIBLE,
+    WORKSPACE_SIZE,
 )
-
+from torch_tensorrt.dynamo.backend.backends import _compile_module
+from torch_tensorrt.dynamo.conversion import convert_module
+from torch_tensorrt.dynamo.lowering._fusers import (
+    fuse_permute_linear,
+    fuse_permute_matmul,
+)
+from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs
+from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
 
 logger = logging.getLogger(__name__)
 
@@ -41,35 +39,37 @@ def compile(
     gm: Any,
     inputs: Any,
     *,
-    device=Device._current_device(),
-    disable_tf32=False,
-    sparse_weights=False,
-    enabled_precisions=set(),
-    refit=False,
-    debug=DEBUG,
-    capability=EngineCapability.default,
-    num_avg_timing_iters=1,
-    workspace_size=WORKSPACE_SIZE,
-    dla_sram_size=1048576,
-    dla_local_dram_size=1073741824,
-    dla_global_dram_size=536870912,
-    calibrator=None,
-    truncate_long_and_double=TRUNCATE_LONG_AND_DOUBLE,
-    require_full_compilation=False,
-    min_block_size=MIN_BLOCK_SIZE,
-    torch_executed_ops=[],
-    torch_executed_modules=[],
-    pass_through_build_failures=PASS_THROUGH_BUILD_FAILURES,
-    max_aux_streams=MAX_AUX_STREAMS,
-    version_compatible=VERSION_COMPATIBLE,
-    optimization_level=OPTIMIZATION_LEVEL,
-    use_python_runtime=USE_PYTHON_RUNTIME,
-    **kwargs,
-):
+    device: Device = Device._current_device(),
+    disable_tf32: bool = False,
+    sparse_weights: bool = False,
+    enabled_precisions: Set[torch.dtype] | Tuple[torch.dtype] = (torch.float32,),
+    refit: bool = False,
+    debug: bool = DEBUG,
+    capability: EngineCapability = EngineCapability.default,
+    num_avg_timing_iters: int = 1,
+    workspace_size: int = WORKSPACE_SIZE,
+    dla_sram_size: int = 1048576,
+    dla_local_dram_size: int = 1073741824,
+    dla_global_dram_size: int = 536870912,
+    calibrator: object = None,
+    truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
+    require_full_compilation: bool = False,
+    min_block_size: int = MIN_BLOCK_SIZE,
+    torch_executed_ops: Optional[List[str]] = None,
+    torch_executed_modules: Optional[List[str]] = None,
+    pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
+    max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
+    version_compatible: bool = VERSION_COMPATIBLE,
+    optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
+    use_python_runtime: bool = USE_PYTHON_RUNTIME,
+    **kwargs: Any,
+) -> torch.fx.GraphModule:
     if debug:
         logger.setLevel(logging.DEBUG)
 
-    logger.warn(
+    enabled_precisions = set(enabled_precisions)
+
+    logger.warning(
         "The Dynamo backend is an experimental feature, for which only the "
         + "following arguments are supported: "
         + "{enabled_precisions, debug, workspace_size, min_block_size, "
@@ -79,7 +79,7 @@ def compile(
     if not isinstance(inputs, collections.abc.Sequence):
         inputs = [inputs]
 
-    torchtrt_inputs, torch_inputs = prepare_inputs(inputs, prepare_device(device))
+    _, torch_inputs = prepare_inputs(inputs, prepare_device(device))
 
     if (
         torch.float16 in enabled_precisions
@@ -104,7 +104,9 @@ def compile(
         "debug": debug,
         "workspace_size": workspace_size,
         "min_block_size": min_block_size,
-        "torch_executed_ops": torch_executed_ops,
+        "torch_executed_ops": torch_executed_ops
+        if torch_executed_ops is not None
+        else [],
         "pass_through_build_failures": pass_through_build_failures,
         "max_aux_streams": max_aux_streams,
         "version_compatible": version_compatible,
@@ -128,9 +130,8 @@ def _compile_graph(
     split_result: SplitResult,
     inputs: Any,
     settings: CompilationSettings = CompilationSettings(),
-    **kwargs,
-):
-
+    **kwargs: Any,
+) -> torch.fx.GraphModule:
     for submod_name, submod_inputs in split_result.submodule_inputs.items():
         submod = getattr(split_result.split_module, submod_name)
         # Only acc submodules will be lowered.
@@ -147,7 +148,9 @@ def _compile_graph(
     return split_result.split_module
 
 
-def lower_model_using_trt_splitter(model: torch.nn.Module, inputs: Any, **kwargs):
+def lower_model_using_trt_splitter(
+    model: torch.nn.Module, inputs: Any, **kwargs: Any
+) -> SplitResult:
     # Perform basic lowering
     model = lower_model(model, inputs)
     splitter_setting = TRTSplitterSetting()
@@ -161,12 +164,13 @@ def lower_model_using_trt_splitter(model: torch.nn.Module, inputs: Any, **kwargs):
     return split_result
 
 
-def lower_model(model: torch.nn.Module, inputs: Any, **kwargs):
-
+def lower_model(
+    model: torch.nn.Module, inputs: Any, **kwargs: Any
+) -> torch.fx.GraphModule:
     graph_optimization_pm = PassManager.build_from_passlist(
         [fuse_permute_matmul, fuse_permute_linear]
     )
-    lowered_model = graph_optimization_pm(model)
+    lowered_model: torch.fx.GraphModule = graph_optimization_pm(model)
     # if isinstance(lowered_model, torch.fx.GraphModule):
     #     ShapeProp(lowered_model).propagate(*inputs)
diff --git a/py/torch_tensorrt/dynamo/conversion/trt_interpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
similarity index 79%
rename from py/torch_tensorrt/dynamo/conversion/trt_interpreter.py
rename to py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 4293fb65eb..8026fee686 100644
--- a/py/torch_tensorrt/dynamo/conversion/trt_interpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -1,26 +1,22 @@
 import logging
 import warnings
 from datetime import datetime
-from packaging import version
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set
 
 import numpy
-
-# @manual=//deeplearning/trt/python:py_tensorrt
-import tensorrt as trt
 import torch
 import torch.fx
 from torch.fx.node import _get_qualified_name
 from torch.fx.passes.shape_prop import TensorMetadata
+from torch_tensorrt._Input import Input
+from torch_tensorrt.fx.observer import Observer
+from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
+
+# @manual=//deeplearning/trt/python:py_tensorrt
+import tensorrt as trt
+from packaging import version
 
 from .converter_registry import DYNAMO_CONVERTERS as CONVERTERS
-from torch_tensorrt import Input
-from torch_tensorrt.fx.observer import Observer
-from torch_tensorrt.fx.utils import (
-    get_dynamic_dims,
-    unified_dtype_converter,
-    Frameworks,
-)
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -36,17 +32,18 @@ class TRTInterpreterResult(NamedTuple):
     serialized_cache: bytearray
 
 
-class TRTInterpreter(torch.fx.Interpreter):
+class TRTInterpreter(torch.fx.Interpreter):  # type: ignore[misc]
     def __init__(
         self,
         module: torch.fx.GraphModule,
         input_specs: List[Input],
-        logger_level=None,
-        output_dtypes=None,
+        logger_level: trt.ILogger.Severity = trt.ILogger.Severity.WARNING,
+        output_dtypes: Optional[List[torch.dtype]] = None,
     ):
         super().__init__(module)
 
-        self.logger = trt.Logger(logger_level or trt.Logger.WARNING)
+        # TODO: @narendasan replace with Torch-TensorRT Logger
+        self.logger = trt.Logger(logger_level)
         self.builder = trt.Builder(self.logger)
 
         flag = 0
@@ -59,12 +56,13 @@ def __init__(
 
         missing_ops = self.validate_conversion()
         if missing_ops:
+            # TODO: @narendasan make sure to set logging.captureWarnings(True)
             warnings.warn(
                 "Interpretation will fail due to missing operations \n"
                 + "\n".join(f"{i}" for i in missing_ops)
             )
 
-        self.optimization_profiles = (
+        self.optimization_profiles: Optional[List[trt.IOptimizationProfile]] = (
[self.builder.create_optimization_profile()] if any( input_spec.shape_mode == Input._ShapeMode.DYNAMIC @@ -86,37 +84,37 @@ def __init__( # Data types for TRT Module output Tensors self.output_dtypes = output_dtypes - def validate_conversion(self): - missing_converter = set() + def validate_conversion(self) -> Set[str]: + missing_converters: Set[str] = set() for node in self.module.graph.nodes: if node.op == "call_function" and not CONVERTERS.get(node): - missing_converter.add(f"{node.op} {_get_qualified_name(node.target)}") + missing_converters.add(f"{node.op} {_get_qualified_name(node.target)}") elif node.op == "call_method" and not CONVERTERS.get(node): - missing_converter.add(f"{node.op} torch.Tensor.{node.target}") + missing_converters.add(f"{node.op} torch.Tensor.{node.target}") elif node.op == "call_module": submod = self.fetch_attr(node.target) submod_type = getattr(submod, "_base_class_origin", type(submod)) if not CONVERTERS.get(node): - missing_converter.add(f"{node.op} {torch.typename(submod_type)}") + missing_converters.add(f"{node.op} {torch.typename(submod_type)}") - return missing_converter + return missing_converters def run( self, - workspace_size=0, - precision=torch.float32, - sparse_weights=False, - disable_tf32=False, - force_fp32_output=False, - strict_type_constraints=False, - algorithm_selector=None, - timing_cache=None, - profiling_verbosity=None, - tactic_sources=None, - max_aux_streams=None, - version_compatible=False, - optimization_level=None, + workspace_size: int = 0, + precision: torch.dtype = torch.float32, # TODO: @peri044 Needs to be expanded to set + sparse_weights: bool = False, + disable_tf32: bool = False, + force_fp32_output: bool = False, + strict_type_constraints: bool = False, + algorithm_selector: Optional[trt.IAlgorithmSelector] = None, + timing_cache: Optional[trt.ITimingCache] = None, + profiling_verbosity: Optional[trt.ProfilingVerbosity] = None, + tactic_sources: Optional[int] = None, + max_aux_streams: Optional[int] = None, + version_compatible: bool = False, + optimization_level: Optional[int] = None, ) -> TRTInterpreterResult: """ Build TensorRT engine with some configs. 
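Between the constructor and run() shown here, the full interpreter round-trip is compact. A sketch under stated assumptions: `gm` would normally be an aten-level GraphModule produced by the Dynamo frontend (the plain symbolic trace below is a placeholder), mirroring what convert_module() does further down in this diff:

```python
import torch
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo.conversion import TRTInterpreter

# Placeholder graph and inputs; real callers hand in an aten-level GraphModule.
gm = torch.fx.symbolic_trace(torch.nn.ReLU())
sample_inputs = [torch.randn(2, 4).cuda()]

interpreter = TRTInterpreter(
    gm,
    Input.from_tensors(sample_inputs, disable_memory_format_check=True),
    output_dtypes=[torch.float32],  # one entry per module output, per the new kwarg
)
result = interpreter.run(precision=torch.float32, version_compatible=False)
engine, input_names, output_names = (
    result.engine,
    result.input_names,
    result.output_names,
)
```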
@@ -183,7 +181,7 @@ def run( _LOGGER.info(f"Setting max aux streams to {max_aux_streams}") builder_config.max_aux_streams = max_aux_streams if version_compatible: - _LOGGER.info(f"Using version compatible") + _LOGGER.info("Using version compatible") builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE) if optimization_level is not None: _LOGGER.info(f"Using optimization level {optimization_level}") @@ -204,9 +202,10 @@ def run( if strict_type_constraints: builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES) - if self.optimization_profiles: - for optimization_profile in self.optimization_profiles: - builder_config.add_optimization_profile(optimization_profile) + if self.optimization_profiles is not None: + if len(self.optimization_profiles) > 0: + for optimization_profile in self.optimization_profiles: + builder_config.add_optimization_profile(optimization_profile) if algorithm_selector: builder_config.set_flag(trt.BuilderFlag.DISABLE_TIMING_CACHE) @@ -232,7 +231,7 @@ def run( engine, self._input_names, self._output_names, serialized_cache ) - def run_node(self, n): + def run_node(self, n: torch.fx.Node) -> torch.fx.Node: self._cur_node_name = str(n) self._cur_node = n # add "_itensor_to_tensor_meta" @@ -241,7 +240,7 @@ def run_node(self, n): n.kwargs = kwargs # run the node - trt_node = super().run_node(n) + trt_node: torch.fx.Node = super().run_node(n) # remove "_itensor_to_tensor_meta" kwargs = dict(n.kwargs) @@ -253,17 +252,20 @@ def run_node(self, n): return trt_node - def placeholder(self, target, args, kwargs): + def placeholder(self, target: str, args: Any, kwargs: Any) -> trt.ITensor: self._input_names.append(target) current_input = self.input_specs[self.input_specs_iter] self.input_specs_iter += 1 # Set optimization profile for dynamic input shape - shape = current_input.shape + shape = None if current_input.shape_mode == Input._ShapeMode.DYNAMIC: + assert isinstance(current_input.shape, dict) shape = [] min_shape = current_input.shape["min_shape"] opt_shape = current_input.shape["opt_shape"] max_shape = current_input.shape["max_shape"] + # TODO: Does not support disjoint optimization profiles? + assert self.optimization_profiles is not None self.optimization_profiles[0].set_shape( target, min_shape, opt_shape, max_shape ) @@ -274,6 +276,13 @@ def placeholder(self, target, args, kwargs): else: # -1 to represent the dynamic dimension shape.append(-1) + elif current_input.shape_mode == Input._ShapeMode.STATIC: + assert isinstance(current_input.shape, tuple) + shape = list(current_input.shape) + else: + raise RuntimeError( + f"Unable to access shape spec for input: {target} (got: {current_input})" + ) return self.network.add_input( name=target, @@ -281,7 +290,9 @@ def placeholder(self, target, args, kwargs): dtype=unified_dtype_converter(current_input.torch_dtype, Frameworks.TRT), ) - def call_module(self, target, args, kwargs): + def call_module( + self, target: str, args: Any, kwargs: Any + ) -> Any: # Probably should be Tuple[trt.ITensor]? Case for Any? assert isinstance(target, str) submod = self.fetch_attr(target) submod_type = getattr(submod, "_base_class_origin", type(submod)) @@ -295,7 +306,8 @@ def call_module(self, target, args, kwargs): assert self._cur_node_name is not None return converter(self.network, submod, args, kwargs, self._cur_node_name) - def call_function(self, target, args, kwargs): + def call_function(self, target: str, args: Any, kwargs: Any) -> Any: + # TODO: Why is this stateful? 
We should be able to take in the inputs converter = CONVERTERS.get(self._cur_node) if not converter: raise RuntimeError( @@ -305,7 +317,7 @@ def call_function(self, target, args, kwargs): assert self._cur_node_name is not None return converter(self.network, target, args, kwargs, self._cur_node_name) - def call_method(self, target, args, kwargs): + def call_method(self, target: str, args: Any, kwargs: Any) -> Any: assert isinstance(target, str) converter = CONVERTERS.get(self._cur_node) @@ -317,7 +329,7 @@ def call_method(self, target, args, kwargs): assert self._cur_node_name is not None return converter(self.network, target, args, kwargs, self._cur_node_name) - def output(self, target, args, kwargs): + def output(self, target: str, args: Any, kwargs: Any) -> List[Any]: assert len(args) == 1 if isinstance(args[0], tuple): outputs = args[0] diff --git a/py/torch_tensorrt/dynamo/conversion/__init__.py b/py/torch_tensorrt/dynamo/conversion/__init__.py index 16b7f61bca..4536ff0e7b 100644 --- a/py/torch_tensorrt/dynamo/conversion/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/__init__.py @@ -1,4 +1,4 @@ -from .aten_ops_converters import * -from .trt_interpreter import * -from .conversion import * +from ._TRTInterpreter import * # noqa: F403 +from .aten_ops_converters import * # noqa: F403 +from .conversion import * # noqa: F403 from .truncate_long_and_double import repair_long_or_double_inputs diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 6cb3a30abb..240ea47308 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1,21 +1,27 @@ import logging -from typing import Dict, Sequence, Tuple, Union -import torch -import tensorrt as trt -from torch_tensorrt.fx.converters import acc_ops_converters -from .converter_registry import dynamo_tensorrt_converter -from torch.fx.node import Argument, Target, Node +from typing import Any, Dict, Optional, Sequence, Tuple, Union -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +import torch +from torch.fx.node import Argument, Node, Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion import impl -from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor -from torch_tensorrt.dynamo.conversion.converter_utils import cast_int_int_div_trt_tensor +from torch_tensorrt.dynamo.conversion.converter_utils import ( + cast_int_int_div_trt_tensor, + cast_trt_tensor, +) +from torch_tensorrt.fx.converters import acc_ops_converters +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor + +import tensorrt as trt + +from .converter_registry import dynamo_tensorrt_converter _LOGGER: logging.Logger = logging.getLogger(__name__) -def args_bounds_check(args, i, replacement=None): +def args_bounds_check( + args: Tuple[Argument, ...], i: int, replacement: Optional[Any] = None +) -> Any: return args[i] if len(args) > i else replacement @@ -95,8 +101,7 @@ def aten_ops_div( ) -def embedding_param_validator(embedding_node: Node): - +def embedding_param_validator(embedding_node: Node) -> bool: max_norm = args_bounds_check(embedding_node.args, 2) norm_type = args_bounds_check(embedding_node.args, 3) scale_grad_by_freq = args_bounds_check(embedding_node.args, 4) @@ -223,7 +228,6 @@ def aten_ops_relu( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.activation.relu( network, target, @@ 
-241,7 +245,6 @@ def aten_ops_rsqrt( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.elementwise.rsqrt( network, target, diff --git a/py/torch_tensorrt/dynamo/conversion/conversion.py b/py/torch_tensorrt/dynamo/conversion/conversion.py index 3b194dd8bf..ae3c8b66b2 100644 --- a/py/torch_tensorrt/dynamo/conversion/conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/conversion.py @@ -1,11 +1,11 @@ -from typing import Sequence, Union -import torch import io -from torch_tensorrt.dynamo.runtime import _PythonTorchTRTModule +from typing import Sequence + +import torch +from torch_tensorrt._Input import Input from torch_tensorrt.dynamo import CompilationSettings -from torch_tensorrt import Input from torch_tensorrt.dynamo.conversion import TRTInterpreter - +from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule import tensorrt as trt @@ -15,7 +15,7 @@ def convert_module( inputs: Sequence[torch.Tensor], settings: CompilationSettings = CompilationSettings(), name: str = "", -): +) -> PythonTorchTensorRTModule | TorchTensorRTModule: """Convert an FX module to a TRT module Args: module: FX GraphModule to convert @@ -23,7 +23,7 @@ def convert_module( settings: Compilation settings name: TRT engine name Returns: - _PythonTorchTRTModule or TorchTensorRTModule + _PythonTorchTensorRTModule or TorchTensorRTModule """ # Specify module output data types to ensure TRT output types agree with # that of the equivalent Torch module @@ -32,7 +32,7 @@ def convert_module( if not isinstance(module_outputs, (list, tuple)): module_outputs = [module_outputs] - output_dtypes = list(output.dtype for output in module_outputs) + output_dtypes = [output.dtype for output in module_outputs] interpreter = TRTInterpreter( module, Input.from_tensors(inputs, disable_memory_format_check=True), @@ -53,10 +53,10 @@ def convert_module( ) if settings.use_python_runtime: - return _PythonTorchTRTModule( + return PythonTorchTensorRTModule( engine=interpreter_result.engine, - input_names=interpreter_result.input_names, - output_names=interpreter_result.output_names, + input_names=list(interpreter_result.input_names), + output_names=list(interpreter_result.output_names), ) else: @@ -68,6 +68,6 @@ def convert_module( return TorchTensorRTModule( serialized_engine=engine_str, name=name, - input_binding_names=interpreter_result.input_names, - output_binding_names=interpreter_result.output_names, + input_binding_names=list(interpreter_result.input_names), + output_binding_names=list(interpreter_result.output_names), ) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_registry.py b/py/torch_tensorrt/dynamo/conversion/converter_registry.py index 9bdfc9bf05..7275844500 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_registry.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_registry.py @@ -1,14 +1,36 @@ import logging from dataclasses import dataclass, field -from typing import Any, Callable, Dict, Optional, Sequence, Union, List from enum import Enum, auto +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Set, + Tuple, + Union, + cast, +) -from torch.fx.node import Target, Node, _get_qualified_name +from torch.fx.node import Argument, Node, Target, _get_qualified_name from torch_tensorrt.fx.converter_registry import CONVERTERS - +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor logger = logging.getLogger(__name__) +ConverterImplSignature = Callable[ + [ + TRTNetwork, + Target, + Tuple[Argument, 
...], + Dict[str, Argument], + str, + ], + TRTTensor | Sequence[TRTTensor], +] + class ConverterPriority(Enum): """Enum to set a converter's priority in the registry""" @@ -28,7 +50,7 @@ class ConverterSupport: this function must not modify the node or its graph """ - converter_implementation: Callable + converter_implementation: ConverterImplSignature capability_validator: Callable[[Node], bool] = field(default=lambda node: True) @@ -61,7 +83,7 @@ def dynamo_tensorrt_converter( The converter being decorated """ - def register_converter(converter): + def register_converter(converter: ConverterImplSignature) -> ConverterImplSignature: """Helper function to register the converter, then return it""" assert callable(converter), "Converter function must be callable" @@ -95,7 +117,7 @@ def register_converter(converter): return converter - def disable_converter(converter): + def disable_converter(converter: ConverterImplSignature) -> ConverterImplSignature: return converter # Select whether to cache/enable the converter @@ -123,15 +145,17 @@ class ConverterRegistry: def __init__( self, - registries: Sequence[Dict[Target, Union[Callable, Sequence[ConverterSupport]]]], + registries: Sequence[ + Dict[Target, Union[Callable[..., Any], Sequence[ConverterSupport]]] + ], registry_names: Optional[Sequence[str]] = None, ): # Copy reference to each dictionary object into attribute list - self.registries = [registry for registry in registries] + self.registries = list(registries) if registry_names is not None: assert len(self.registries) == len(registry_names) - self.registry_names = [name for name in registry_names] + self.registry_names = list(registry_names) else: self.registry_names = [ f"Registry {i + 1}" for i in range(len(self.registries)) @@ -139,7 +163,7 @@ def __init__( self.validate_invariants() - def validate_invariants(self): + def validate_invariants(self) -> None: """Validates the invariants required of the dictionaries in the registries Raises AssertionError if any invariants have been violated @@ -160,7 +184,11 @@ def validate_invariants(self): else: assert callable(converters), "Converter function must be callable" - def __getitem_without_validation__(self, key: Target): + def __getitem_without_validation__( + self, key: Target + ) -> ( + Any + ): # TODO: Narrow to ConverterImplSignature this when we can remove FX converters """Get the first-found converter in any registry Searches all registries in order and returns the first converter encountered @@ -185,7 +213,11 @@ def __getitem_without_validation__(self, key: Target): raise KeyError(f"None of the converter registries have an entry for {key}") - def __getitem__(self, node: Node): + def __getitem__( + self, node: Node + ) -> ( + Any + ): # TODO: Narrow to ConverterImplSignature this when we can remove FX converters """Get the first-found validated converter in any registry Searches all registries in order and returns the first converter @@ -218,25 +250,33 @@ def __getitem__(self, node: Node): f"None of the converter registries have a validated entry for {key}, with node {node}" ) - def keys(self): + def keys(self) -> Set[Target]: """Get all unique targets across all dictionaries""" return self.unique_targets() - def get_unvalidated(self, key: Target, value=None): + def get_unvalidated( + self, key: Target, value: Optional[ConverterImplSignature] = None + ) -> ( + Any + ): # TODO: Narrow to ConverterImplSignature this when we can remove FX converters """Get unvalidated converter for input target with a default return""" try: return 
self.__getitem_without_validation__(key) except KeyError: return value - def get(self, node: Node, value=None): + def get( + self, node: Node, value: Optional[ConverterImplSignature] = None + ) -> ( + Any + ): # TODO: Narrow to ConverterImplSignature this when we can remove FX converters """Get validated converter for input node with a default return""" try: return self.__getitem__(node) except KeyError: return value - def __contains__(self, key: Union[Target, Node]): + def __contains__(self, key: Target | Node) -> bool: """Check whether a converter for an input node or target exists""" try: # Attempt to access the item in the registry @@ -251,7 +291,9 @@ def __contains__(self, key: Union[Target, Node]): def get_all_converters_with_target( self, key: Target, return_registry_info: bool = False - ): + ) -> Tuple[ + List[Any], Dict[str, int] | None + ]: # TODO: Narrow to ConverterImplSignature this when we can remove FX converters """Get all converters across all registries for the target Returns a list of all converterts having the specified target @@ -283,25 +325,25 @@ def get_all_converters_with_target( if return_registry_info: return converters_with_target, registry_data else: - return converters_with_target + return converters_with_target, None - def __setitem__(self, key, value): + def __setitem__(self, key: Any, value: Any) -> None: raise AssertionError( - f"Do not set registry members directly through the ConverterRegistry object. " + "Do not set registry members directly through the ConverterRegistry object. " + f"Attempted to set {key}: {value} via direct assignment to ConverterRegistry." ) - def __delitem__(self, key): + def __delitem__(self, key: Any) -> None: raise AssertionError( - f"Do not delete registry members directly through the ConverterRegistry object. " + "Do not delete registry members directly through the ConverterRegistry object. " + f"Attempted to delete {key} via direct del on ConverterRegistry." 
) - def __len__(self): + def __len__(self) -> int: """Returns the sum of lengths of all registries stored""" return sum(len(registry) for registry in self.registries) - def unique_targets(self): + def unique_targets(self) -> Set[Target]: """Returns the set of unique converter targets stored across all registries""" return set.union(*[set(registry.keys()) for registry in self.registries]) @@ -311,9 +353,9 @@ def qualified_name_or_str(self, target: Target) -> str: if isinstance(target, str): return target else: - return _get_qualified_name(target) + return cast(str, _get_qualified_name(target)) - def get_converter_support_info(self) -> Dict[str, Dict[str, int]]: + def get_converter_support_info(self) -> Dict[str, Optional[Dict[str, int]]]: """Returns a dictionary of targets backed by at least one converter""" available_converters = {} for target in sorted( @@ -330,7 +372,7 @@ def display_all_available_converters(self) -> str: available_converters = "Available converters in ATen registries with counts:\n" support_info = self.get_converter_support_info() - for target, registry_data in support_info.keys(): + for target, registry_data in support_info.items(): available_converters += f"Node: {self.qualified_name_or_str(target)} - Registry Presence Counts: {registry_data}\n" return available_converters @@ -339,6 +381,6 @@ def display_all_available_converters(self) -> str: # Initialize dynamo converter registry with the FX and Dynamo aten registries # Note the Dynamo registry is listed first, for precedence DYNAMO_CONVERTERS: ConverterRegistry = ConverterRegistry( - [DYNAMO_ATEN_CONVERTERS, CONVERTERS], + [DYNAMO_ATEN_CONVERTERS, CONVERTERS], # type: ignore[list-item] ["Dynamo ATen Converters Registry", "FX ATen Converters Registry"], ) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 584e15b263..243c2675df 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -1,18 +1,13 @@ -import torch - -from torch_tensorrt.fx.types import ( - TRTDataType, - TRTNetwork, - TRTTensor, -) +from typing import List +import torch from torch_tensorrt.fx.converters.converter_utils import ( - unified_dtype_converter, Frameworks, + unified_dtype_converter, ) +from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor import tensorrt as trt -from typing import List def dynamic_unsupported(node: torch.fx.Node) -> bool: @@ -93,7 +88,7 @@ def cast_int_int_div_trt_tensor( ): lhs_val = cast_trt_tensor(network, lhs_val, trt.float32, name) rhs_val = cast_trt_tensor(network, rhs_val, trt.float32, name) - return list((lhs_val, rhs_val)) + return [lhs_val, rhs_val] def broadcastable( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index db6e405978..b402240b84 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -1,14 +1,17 @@ from torch_tensorrt.fx.converters.impl import convolution -from . import condition -from . import elementwise -from . import embedding -from . import normalization -from . import slice -from . import unary -from . import activation -from . import matmul -from . import select -from . import shape -from . import squeeze -from . import unsqueeze -from . import permutation + +from . 
import ( + activation, + condition, + elementwise, + embedding, + matmul, + normalization, + permutation, + select, + shape, + slice, + squeeze, + unary, + unsqueeze, +) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/activation.py b/py/torch_tensorrt/dynamo/conversion/impl/activation.py index 6a15454f54..0190768223 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/activation.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/activation.py @@ -1,24 +1,20 @@ -import numpy as np -from typing import Any, Optional import math +from typing import Any, Optional, Tuple -import tensorrt as trt +import numpy as np import torch +from torch import Tensor from torch.fx.node import Target - from torch_tensorrt.dynamo._SourceIR import SourceIR -from torch_tensorrt.fx.converters.impl.activation import * from torch_tensorrt.fx.converters.converter_utils import ( + get_trt_plugin, mark_as_int8_layer, set_layer_name, - get_trt_plugin, ) +from torch_tensorrt.fx.converters.impl.activation import * # noqa: F403 +from torch_tensorrt.fx.types import TRTNetwork, TRTPluginFieldCollection, TRTTensor -from torch_tensorrt.fx.types import ( - TRTNetwork, - TRTTensor, - TRTPluginFieldCollection, -) +import tensorrt as trt def gelu( @@ -28,7 +24,7 @@ def gelu( name: str, input_val: TRTTensor, alpha: Optional[Any] = None, -): +) -> TRTTensor: approximate = alpha if approximate is not None: raise RuntimeError("GeLU converter currently doesn't support fast gelu compute") @@ -53,7 +49,9 @@ def gelu( layer = network.add_plugin_v2([input_val], plugin) - def gelu_dyn_range_fn(dyn_range): + def gelu_dyn_range_fn( + dyn_range: Tuple[Tensor, Tensor] + ) -> Tuple[Tensor, Tensor]: # TODO: This probably will not work with fake tensor return ( dyn_range[0] * 0.5 * (1.0 + torch.erf(dyn_range[0] / math.sqrt(2.0))) ), (dyn_range[1] * 0.5 * (1.0 + torch.erf(dyn_range[0] / math.sqrt(2.0)))) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py index 803e60a2b9..1b46f106b8 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py @@ -1,19 +1,18 @@ from typing import Optional - -import tensorrt as trt import torch from torch.fx.node import Target - -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion.converter_utils import broadcastable +from torch_tensorrt.dynamo.conversion.impl.slice import expand from torch_tensorrt.fx.converters.converter_utils import ( broadcast, get_trt_tensor, set_layer_name, ) -from torch_tensorrt.dynamo.conversion.impl.slice import expand +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor + +import tensorrt as trt def where( @@ -36,7 +35,7 @@ def where( assert type(other) is torch.Tensor, f"value {other} is not torch.Tensor!" if not (broadcastable(input, other)): - assert f"The two torch tensors should be broadcastable" + assert "The two torch tensors should be broadcastable" # get output shape # purpose of this is to bring input and other rank same as @@ -66,7 +65,7 @@ def where( condition_val = condition_layer.get_output(0) else: assert condition.dtype == trt.bool, "mask dtype is not bool!" - if condition_shape != condition_dim: + if condition_shape != condition_dim: # TODO: What is this checking? 
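# For orientation, the contract the where() converter in this hunk enforces,
# expressed at the Torch level (shapes illustrative, not from this diff):
# condition, input and other must be mutually broadcastable, and the select
# produces the broadcast shape.
#
#     cond = torch.rand(2, 1) > 0.5
#     x, y = torch.randn(2, 3), torch.randn(1, 3)
#     torch.where(cond, x, y).shape  # -> torch.Size([2, 3])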
condition_val = expand( network, target, source_ir, f"{name}_expand", condition, output_shape ) @@ -74,7 +73,7 @@ def where( condition_val = condition if type(input) != TRTTensor: - if x_shape != input_dim: + if x_shape != input_dim: # TODO: What is this checking? # special case where 1 element in input if len(input.shape) == 0: input = input.unsqueeze(0) @@ -96,7 +95,7 @@ def where( y_val = get_trt_tensor(network, other, f"{name}_y") else: y_val = other - if y_shape != other_dim: + if y_shape != other_dim: # TODO: What is this checking? y_val = expand( network, target, source_ir, f"{name}_y_expand", y_val, output_shape ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/__init__.py index 25d71e3702..6965f89636 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/__init__.py @@ -1,2 +1 @@ from .ops import * -from .clamp import clamp diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py index a8e4067493..c4cc744aa9 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py @@ -1,26 +1,21 @@ import operator import warnings -from typing import Union, Callable, Any, Optional +from typing import Any, Callable, Optional, Union -import tensorrt as trt import torch from torch.fx.node import Target - -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, TRTElementWiseOp -from torch_tensorrt.fx.utils import ( - unified_dtype_converter, - Frameworks, -) from torch_tensorrt.dynamo._SourceIR import SourceIR -from torch_tensorrt.dynamo.conversion.converter_utils import ( - cast_trt_tensor, -) +from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor from torch_tensorrt.fx.converters.converter_utils import ( - set_layer_name, broadcast, - squeeze_left, get_trt_tensor, + set_layer_name, + squeeze_left, ) +from torch_tensorrt.fx.types import TRTElementWiseOp, TRTNetwork, TRTTensor +from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter + +import tensorrt as trt def get_python_op_from_trt_elementwise_op( @@ -158,5 +153,6 @@ def convert_binary_elementwise( layer = network.add_elementwise(lhs_val, rhs_val, op_type) set_layer_name(layer, target, name, source_ir) output = layer.get_output(0) - output.name = output.name + "_" + target.__name__ + kind: str = str(target.__name__) if callable(target) else target + output.name = output.name + "_" + kind return output diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/clamp.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/clamp.py deleted file mode 100644 index 8fc9df586c..0000000000 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/clamp.py +++ /dev/null @@ -1,78 +0,0 @@ -import numpy as np -from typing import Optional -import tensorrt as trt -from torch.fx.node import Target - -from torch_tensorrt.dynamo._SourceIR import SourceIR -from torch_tensorrt.fx.utils import ( - unified_dtype_converter, - Frameworks, -) - -from torch_tensorrt.fx.converters.converter_utils import ( - set_layer_name, - squeeze_left, - get_trt_tensor, -) - -from torch_tensorrt.fx.types import ( - TRTNetwork, - TRTTensor, -) - - -def add_clamp(network, input, val, op, name): - if not len(input.shape): - # clamping scalar - acc_ops_clamp_trt = get_trt_tensor( - network, - squeeze_left( - 
np.array( - [val], dtype=unified_dtype_converter(input.dtype, Frameworks.NUMPY) - ) - ), - f"{name}_clamp_{val}", - ) - else: - acc_ops_clamp_shape = (1,) * len(input.shape) # broadcast all dimensions - acc_ops_clamp_tensor = np.full( - acc_ops_clamp_shape, - val, - dtype=unified_dtype_converter(input.dtype, Frameworks.NUMPY), - ) - acc_ops_clamp_trt = network.add_constant( - acc_ops_clamp_shape, acc_ops_clamp_tensor - ).get_output(0) - layer = network.add_elementwise(input, acc_ops_clamp_trt, op) - return layer - - -def clamp( - network: TRTNetwork, - target: Target, - source_ir: Optional[SourceIR], - name: str, - input_val, - min_val=None, - max_val=None, -) -> TRTTensor: - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"Clamp received input {input_val} that is not part " - "of the TensorRT region!" - ) - - if min_val is not None: - clamp_min_layer = add_clamp( - network, input_val, min_val, trt.ElementWiseOperation.MAX, name - ) - set_layer_name(clamp_min_layer, target, f"{name}_clamp_min") - input_val = clamp_min_layer.get_output(0) - if max_val is not None: - clamp_max_layer = add_clamp( - network, input_val, max_val, trt.ElementWiseOperation.MIN, name - ) - set_layer_name(clamp_max_layer, target, f"{name}_clamp_max") - input_val = clamp_max_layer.get_output(0) - - return input_val diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py index a1ec485c31..0ae27e0933 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py @@ -1,21 +1,22 @@ -from typing import Any, Optional +from typing import Optional -import tensorrt as trt +import numpy as np from torch.fx.node import Target - -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor -from torch_tensorrt.fx.utils import ( - unified_dtype_converter, - Frameworks, -) from torch_tensorrt.dynamo._SourceIR import SourceIR -from torch_tensorrt.fx.converters.converter_utils import get_trt_tensor - from torch_tensorrt.dynamo.conversion.impl.elementwise.base import ( convert_binary_elementwise, ) -from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary from torch_tensorrt.dynamo.conversion.impl.unary import sign +from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary +from torch_tensorrt.fx.converters.converter_utils import ( + get_trt_tensor, + set_layer_name, + squeeze_left, +) +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter + +import tensorrt as trt def trunc_div( @@ -116,7 +117,6 @@ def rsqrt( name: str, input: TRTTensor, ) -> TRTTensor: - sqrt_trt_output = convert_unary( network, target, @@ -175,3 +175,68 @@ def fmod( prod_value, ) return sub_value + + +def clamp( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, + min_val: Optional[float] = None, + max_val: Optional[float] = None, +) -> TRTTensor: + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"Clamp received input {input_val} that is not part " + "of the TensorRT region!" 
+ ) + + def _add_layer( + network: TRTNetwork, + input: TRTTensor, + val: float, + op: trt.ElementWiseOperation, + name: str, + ) -> ( + trt.ILayer + ): # TODO: Simplify and merge implementations, should just be max and min stacked + if not len(input.shape): + # clamping scalar + acc_ops_clamp_trt = get_trt_tensor( + network, + squeeze_left( + np.array( + [val], + dtype=unified_dtype_converter(input.dtype, Frameworks.NUMPY), + ) + ), + f"{name}_clamp_{val}", + ) + else: + acc_ops_clamp_shape = (1,) * len(input.shape) # broadcast all dimensions + acc_ops_clamp_tensor = np.full( + acc_ops_clamp_shape, + val, + dtype=unified_dtype_converter(input.dtype, Frameworks.NUMPY), + ) + acc_ops_clamp_trt = network.add_constant( + acc_ops_clamp_shape, acc_ops_clamp_tensor + ).get_output(0) + layer = network.add_elementwise(input, acc_ops_clamp_trt, op) + return layer + + if min_val is not None: + clamp_min_layer = _add_layer( + network, input_val, min_val, trt.ElementWiseOperation.MAX, name + ) + set_layer_name(clamp_min_layer, target, f"{name}_clamp_min") + input_val = clamp_min_layer.get_output(0) + if max_val is not None: + clamp_max_layer = _add_layer( + network, input_val, max_val, trt.ElementWiseOperation.MIN, name + ) + set_layer_name(clamp_max_layer, target, f"{name}_clamp_max") + input_val = clamp_max_layer.get_output(0) + + return input_val diff --git a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py index 7e914a1d89..48d5b55d7e 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py @@ -1,20 +1,10 @@ -import operator -import warnings -from typing import Optional, cast, Any +from typing import Optional -import numpy as np - -import tensorrt as trt import torch from torch.fx.node import Target - from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.fx.converters.converter_utils import get_trt_tensor, set_layer_name from torch_tensorrt.fx.types import TRTNetwork, TRTTensor -from torch_tensorrt.fx.converters.converter_utils import ( - set_layer_name, -) - -from torch_tensorrt.fx.converters.converter_utils import get_trt_tensor def embedding( @@ -29,7 +19,6 @@ def embedding( scale_grad_by_freq: bool, sparse: bool, ) -> TRTTensor: - if network.has_implicit_batch_dimension: raise RuntimeError( "The `embedding` function should be called with explicit batch dimension." 
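Since the embedding converter above was slimmed down to two imports, it is worth stating what it lowers. A plain-PyTorch sketch with illustrative shapes (not taken from this diff):

```python
# embedding(indices, weight) is a row-gather over the weight matrix, which is
# the lookup the TRT converter implements.
import torch

weight = torch.randn(10, 4)           # 10-entry vocabulary, embedding dim 4
indices = torch.tensor([1, 3, 3, 7])
out = torch.nn.functional.embedding(indices, weight)
assert torch.equal(out, weight[indices])  # identical to direct row indexing
```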
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/matmul.py b/py/torch_tensorrt/dynamo/conversion/impl/matmul.py index 9907d3e40d..3e1bef66ef 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/matmul.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/matmul.py @@ -1,20 +1,16 @@ from typing import Optional - -import tensorrt as trt from torch.fx.node import Target - -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor -from torch_tensorrt.fx.utils import ( - unified_dtype_converter, - Frameworks, -) from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.fx.converters.converter_utils import ( - get_trt_tensor, broadcast, + get_trt_tensor, set_layer_name, ) +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter + +import tensorrt as trt def matrix_multiply( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py index 70f71055d1..2ab74ef86b 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py @@ -1,32 +1,25 @@ -from typing import cast, Union, Any, Optional, Sequence +import logging +from typing import Any, List, Optional, Sequence, Union, cast import numpy as np - -import tensorrt as trt import torch from torch.fx.node import Target - -import logging - -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor -from torch_tensorrt.fx.utils import get_dynamic_dims from torch_tensorrt.dynamo._SourceIR import SourceIR - +from torch_tensorrt.dynamo.conversion.impl.elementwise.base import ( + convert_binary_elementwise, +) +from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary from torch_tensorrt.fx.converters.converter_utils import ( + get_positive_dim, get_trt_plugin, + has_dynamic_shape, set_layer_name, to_numpy, - has_dynamic_shape, - get_positive_dim, -) - -from torch_tensorrt.dynamo.conversion.impl.unary.base import ( - convert_unary, ) +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.utils import get_dynamic_dims -from torch_tensorrt.dynamo.conversion.impl.elementwise.base import ( - convert_binary_elementwise, -) +import tensorrt as trt _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -43,9 +36,8 @@ def batch_norm( running_var: torch.Tensor, training: torch.Tensor, momentum: torch.Tensor, - eps: list, + eps: List[float], ) -> Union[TRTTensor, Sequence[TRTTensor]]: - if not isinstance(input, TRTTensor): raise RuntimeError( f"BatchNorm2d received input {input} that is not part " @@ -55,14 +47,11 @@ def batch_norm( if has_dynamic_shape(input.shape): assert input.shape[1] != -1, "Channel dim can't be dynamic for batch norm." 
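# The cast cleanup just below relies on the standard inference-time folding of
# batch-norm statistics into a single scale and shift (eps assumed scalar):
#
#     scale = weight / sqrt(running_var + eps)
#     shift = bias - running_mean * scale
#     y = scale * x + shift   # equivalent to batch norm in eval mode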
- scale = cast(torch.Tensor, to_numpy(cast(torch.Tensor, weight))) / np.sqrt( - cast(torch.Tensor, to_numpy(cast(torch.Tensor, running_var))) + cast(float, eps) + scale = cast(torch.Tensor, to_numpy(weight)) / np.sqrt( + cast(torch.Tensor, to_numpy(running_var)) + cast(float, eps) ) - bias = ( - to_numpy(cast(torch.Tensor, bias)) - - to_numpy(cast(torch.Tensor, running_mean)) * scale - ) + bias = to_numpy(bias) - to_numpy(running_mean) * scale power = np.ones_like(scale) # For BatchNorm1d, reshape 1d to 2d @@ -101,10 +90,10 @@ def layer_norm( source_ir: Optional[SourceIR], name: str, input: TRTTensor, - normalized_shape: list, + normalized_shape: List[int], weight: torch.Tensor, bias: torch.Tensor, - eps: list, + eps: List[float], ) -> Union[TRTTensor, Sequence[TRTTensor]]: if not isinstance(input, trt.tensorrt.ITensor): raise RuntimeError( @@ -120,13 +109,13 @@ def layer_norm( "eps", np.array(eps, dtype=np.float32), trt.PluginFieldType.FLOAT32 ) try: - normalized_shape = np.array(normalized_shape, dtype=np.int32) + normalized_shape_arr = np.array(normalized_shape, dtype=np.int32) except TypeError: _LOGGER.error("Unable to convert normalized_shape to a field, fall back to []") - normalized_shape = np.array([], dtype=np.int32) + normalized_shape_arr = np.array([], dtype=np.int32) normalized_shape_filed = trt.PluginField( - "normalized_shape", normalized_shape, trt.PluginFieldType.INT32 + "normalized_shape", normalized_shape_arr, trt.PluginFieldType.INT32 ) field_collection = trt.PluginFieldCollection( [gamma_field, beta_field, eps_field, normalized_shape_filed] @@ -155,22 +144,21 @@ def layer_norm_no_plugin( source_ir: Optional[SourceIR], name: str, input: TRTTensor, - normalized_shape: list, + normalized_shape: List[int], weight: torch.Tensor, bias: torch.Tensor, - eps: list, + eps: List[float], ) -> Union[TRTTensor, Sequence[TRTTensor]]: - if not isinstance(input, TRTTensor): raise RuntimeError( f"LayerNorm received input {input} that is not part " "of the TensorRT region!" 
) - shape = weight.shape # type: ignore[union-attr] + shape = weight.shape broadcasted_shape = (1,) * (len(input.shape) - len(shape)) + shape - gamma = to_numpy(weight.reshape(*shape)) # type: ignore[union-attr] - beta = to_numpy(bias.reshape(*shape)) # type: ignore[union-attr] + gamma = to_numpy(weight.reshape(*shape)) + beta = to_numpy(bias.reshape(*shape)) axes = 0 for d in range(len(shape)): @@ -247,10 +235,14 @@ def layer_norm_no_plugin( ) assert gamma is not None - gamma_tensor = network.add_constant(gamma.shape, trt.Weights(np.ascontiguousarray(gamma))) # type: ignore[attr-defined] + gamma_tensor = network.add_constant( + gamma.shape, trt.Weights(np.ascontiguousarray(gamma)) + ) gamma_tensor.name = f"{name}_gamma" assert beta is not None - beta_tensor = network.add_constant(gamma.shape, trt.Weights(np.ascontiguousarray(beta))) # type: ignore[attr-defined] + beta_tensor = network.add_constant( + gamma.shape, trt.Weights(np.ascontiguousarray(beta)) + ) beta_tensor.name = f"{name}_beta" # y * gamma + beta scale_layer = convert_binary_elementwise( @@ -281,7 +273,7 @@ def softmax( input: TRTTensor, dim: Optional[Any] = None, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_ranks = len(input.shape) + (1 if network.has_implicit_batch_dimension else 0) # type: ignore[union-attr] + input_ranks = len(input.shape) + (1 if network.has_implicit_batch_dimension else 0) if not isinstance(input, TRTTensor): raise RuntimeError( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py index 34fc36de0e..4ab7e31bc5 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py @@ -1,14 +1,12 @@ -from typing import Optional, Sequence, cast - +from typing import Optional, Sequence from torch.fx.node import Target - -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.fx.converters.converter_utils import ( - set_layer_name, get_positive_dim, + set_layer_name, ) +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor def permute( @@ -24,9 +22,7 @@ def permute( f"permute received input {input} that is not a TensorRT ITensor" ) - permutation = [ - get_positive_dim(i, len(input.shape)) for i in cast(Sequence[int], permutation) - ] + permutation = [get_positive_dim(i, len(input.shape)) for i in permutation] layer = network.add_shuffle(input) layer.second_transpose = tuple(permutation) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index ae8a72592a..5cd679b6a6 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -2,15 +2,14 @@ import numpy as np from torch.fx.node import Target - -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, Shape from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape from torch_tensorrt.fx.converters.converter_utils import ( get_positive_dim, has_dynamic_shape, to_numpy, ) -from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape +from torch_tensorrt.fx.types import Shape, TRTNetwork, TRTTensor def select( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shape.py b/py/torch_tensorrt/dynamo/conversion/impl/shape.py index aff161c560..86c126552b 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/shape.py +++ 
b/py/torch_tensorrt/dynamo/conversion/impl/shape.py @@ -1,29 +1,24 @@ -from typing import Union +from typing import List, Optional, Tuple import numpy as np - -import tensorrt as trt import torch from torch.fx.node import Target - -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor from torch_tensorrt.dynamo._SourceIR import SourceIR -from torch_tensorrt.fx.converters.converter_utils import ( - set_layer_name, - to_numpy, -) - from torch_tensorrt.dynamo.conversion.impl.elementwise.base import ( convert_binary_elementwise, ) +from torch_tensorrt.fx.converters.converter_utils import set_layer_name, to_numpy +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor + +import tensorrt as trt def get_shape_with_dynamic_shape( network: TRTNetwork, target: Target, - source_ir: SourceIR, + source_ir: Optional[SourceIR], name: str, - shape: Union[list, tuple, torch.Tensor], + shape: List[int] | Tuple[int, ...] | torch.Tensor, input_val: TRTTensor, ) -> TRTTensor: """ diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/base.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/base.py index 71a31b746e..57e72803a8 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/slice/base.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/base.py @@ -1,15 +1,13 @@ from typing import Optional from torch.fx.node import Target - -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, Shape from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape from torch_tensorrt.fx.converters.converter_utils import ( has_dynamic_shape, set_layer_name, ) - -from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape +from torch_tensorrt.fx.types import Shape, TRTNetwork, TRTTensor def slice( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py index 1405486b4a..3835253219 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py @@ -1,20 +1,19 @@ -from typing import Optional, cast import math +from typing import Optional from torch.fx.node import Target - -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, Shape from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion.impl.slice.base import slice from torch_tensorrt.fx.converters.converter_utils import ( - get_positive_dim, - has_dynamic_shape, broadcast, + get_positive_dim, get_trt_tensor, + has_dynamic_shape, ) -from torch_tensorrt.dynamo.conversion.impl.slice.base import slice +from torch_tensorrt.fx.types import Shape, TRTNetwork, TRTTensor -def slice_op( +def slice_op( # TODO: This should be slice not whatever is in base network: TRTNetwork, target: Target, source_ir: Optional[SourceIR], @@ -32,7 +31,7 @@ def slice_op( ) ranks = len(input.shape) + (1 if network.has_implicit_batch_dimension else 0) - dim = get_positive_dim(cast(int, dim), ranks) + dim = get_positive_dim(dim, ranks) dynamic_shape = has_dynamic_shape(input.shape) if network.has_implicit_batch_dimension: if dim == 0: @@ -44,19 +43,21 @@ def slice_op( if dynamic_shape: # Check whether slice target dim is dynamic shape dim assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!" 
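# Note: 2**63 - 1 below is the INT64_MAX sentinel PyTorch emits for an
# open-ended slice stop (e.g. x[:, 2:]), which slice_op clamps to the real
# dimension size before computing the output shape.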
-    start_int = cast(int, start)
-    stop_int = cast(int, stop)
+    start_int = start
+    stop_int = stop
     if stop_int == 2**63 - 1:
         stop_int = input.shape[dim]
-    step_int = cast(int, step)
-    start = [0] * len(input.shape)
-    start[dim] = start_int
-    stride = [1] * len(start)
-    stride[dim] = step_int
+    step_int = step
+    start_slice = [0] * len(input.shape)
+    start_slice[dim] = start_int
+    stride_slice = [1] * len(start_slice)
+    stride_slice[dim] = step_int
     output_shape = list(input.shape)
     output_shape[dim] = math.ceil((stop_int - start_int) / step_int)
-    return slice(network, target, source_ir, name, input, start, output_shape, stride)
+    return slice(
+        network, target, source_ir, name, input, start_slice, output_shape, stride_slice
+    )
 
 
 def expand(
@@ -88,9 +89,9 @@
     ranks = len(shape)
     inshape = tuple(input_val.shape)
-    shape = tuple(shape)
+    shape_t = tuple(shape)
     start = tuple([0] * ranks)
     stride = tuple(
         [int(i == o) for i, o in zip(inshape, shape)]
     )  # stride == 1 if dimensions match, 0 otherwise
-    return slice(network, target, source_ir, name, input_val, start, shape, stride)
+    return slice(network, target, source_ir, name, input_val, start, shape_t, stride)
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py b/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py
index 16a086754e..46e0620590 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py
@@ -1,14 +1,12 @@
-from typing import Optional, cast, Any
+from typing import Any, Optional, cast
 
 from torch.fx.node import Target
-
-from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.fx.converters.converter_utils import (
     get_positive_dim,
     set_layer_name,
 )
-
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 from torch_tensorrt.fx.utils import get_dynamic_dims
@@ -38,7 +36,6 @@ def squeeze(
     assert not (len(dims) == 0), "We don't support dim=None right now for squeeze."
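# Worked example for the static-slice arithmetic in slice_op() above
# (illustrative values): slicing dim 1 of an (8, 16) input with
# start=2, stop=11, step=3 yields
#
#     start_slice  = [0, 2]
#     stride_slice = [1, 3]
#     output_shape = [8, math.ceil((11 - 2) / 3)]  # -> [8, 3]
#
# i.e. the elements at indices 2, 5 and 8.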
for dim in dims: - dim = cast(Optional[int], dim) dim = get_positive_dim( dim, len(input.shape) + (1 if network.has_implicit_batch_dimension else 0), diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/base.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/base.py index b738b05591..4c5011eeec 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/unary/base.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/base.py @@ -1,14 +1,11 @@ from typing import Optional -import tensorrt as trt from torch.fx.node import Target - -from torch_tensorrt.fx.types import ( - TRTNetwork, - TRTTensor, -) from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.fx.converters.converter_utils import set_layer_name +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor + +import tensorrt as trt def convert_unary( @@ -40,5 +37,6 @@ def convert_unary( layer = network.add_unary(input_val, operation_type) set_layer_name(layer, target, name, source_ir) output = layer.get_output(0) - output.name = output.name + "_" + target.__name__ + kind: str = str(target.__name__) if callable(target) else target + output.name = output.name + "_" + kind return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py index c1d490104d..22376deedd 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py @@ -1,20 +1,14 @@ from typing import Optional -import tensorrt as trt from torch.fx.node import Target - -from torch_tensorrt.fx.types import ( - TRTNetwork, - TRTTensor, -) - from torch_tensorrt.dynamo._SourceIR import SourceIR - - from torch_tensorrt.dynamo.conversion.impl.elementwise.base import ( convert_binary_elementwise, ) from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor + +import tensorrt as trt def sign( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py index d67f790701..b16fee1eec 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py @@ -1,15 +1,13 @@ from typing import Optional, cast from torch.fx.node import Target - -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.fx.converters.converter_utils import ( get_positive_dim, get_trt_tensor, set_layer_name, ) - +from torch_tensorrt.fx.types import Shape, TRTNetwork, TRTTensor from torch_tensorrt.fx.utils import get_dynamic_dims @@ -18,8 +16,8 @@ def unsqueeze( target: Target, source_ir: Optional[SourceIR], name: str, - input_t, - dim, + input_t: TRTTensor, + dim: Shape, ) -> TRTTensor: input_val = get_trt_tensor(network, input_t, f"{name}_input_t") if not isinstance(input_val, TRTTensor): diff --git a/py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py b/py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py index fc3263de57..5cb5d118be 100644 --- a/py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py +++ b/py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py @@ -1,10 +1,11 @@ +from typing import Optional, Sequence, Set + import torch from torch.fx.node import _get_qualified_name -from typing import Optional, Sequence, Union def _extract_downstream_get_nodes( - module_node: torch.fx.Node, output_indices: Sequence[int] + module_node: 
torch.fx.Node, output_indices: Set[int] ) -> Sequence[torch.fx.Node]: """Extracts downstream users of a node which get the item at a particular index @@ -35,9 +36,9 @@ def _repair_64bit_input( gm: torch.fx.GraphModule, position: int, submodule_name: str, - submodule_outputs: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]], + submodule_outputs: Optional[torch.Tensor | Sequence[torch.Tensor]], dtype: torch.dtype, -): +) -> None: """Fixes a single Long/Double input to a TRT-accelerated subgraph In-Place modifies the provided graph @@ -90,14 +91,15 @@ def _repair_64bit_input( module_node.replace_input_with(node_64bit, node_32bit) output_positions_64bit = set() - outputs_list = ( - [submodule_outputs] - if isinstance(submodule_outputs, torch.Tensor) - else submodule_outputs - ) # Determine if any outputs of the model are 64-bit type and store their indices if submodule_outputs is not None: + outputs_list = ( + [submodule_outputs] + if isinstance(submodule_outputs, torch.Tensor) + else submodule_outputs + ) + for output_position, output in enumerate(outputs_list): if output.dtype == dtype_64bit: output_positions_64bit.add(output_position) @@ -199,9 +201,11 @@ def repair_long_or_double_inputs( # Repair submodule inputs in accordance with inserted casts dtype_32bit = torch.int32 if (param.dtype == torch.int64) else torch.float32 submodule_inputs = ( - submodule_inputs[:position] - + (param.to(dtype_32bit),) - + submodule_inputs[position + 1 :] + list(submodule_inputs[:position]) + + [ + param.to(dtype_32bit), + ] + + list(submodule_inputs[position + 1 :]) ) return submodule_inputs diff --git a/py/torch_tensorrt/dynamo/lowering/__init__.py b/py/torch_tensorrt/dynamo/lowering/__init__.py index 44c37789a9..0e13125fc6 100644 --- a/py/torch_tensorrt/dynamo/lowering/__init__.py +++ b/py/torch_tensorrt/dynamo/lowering/__init__.py @@ -1,10 +1,10 @@ -from ._decompositions import ( - get_decompositions, +from ._decompositions import get_decompositions # noqa: F401 +from ._fusers import * # noqa: F403 +from ._partition import ( # noqa: F401 + DEFAULT_SINGLE_NODE_PARTITIONS, + get_submod_inputs, + partition, ) -from ._pre_aot_lowering import ( - SUBSTITUTION_REGISTRY, - register_substitution, -) -from ._partition import partition, get_submod_inputs, DEFAULT_SINGLE_NODE_PARTITIONS -from .substitutions import * -from ._fusers import * +from ._pre_aot_lowering import SUBSTITUTION_REGISTRY # noqa: F401 +from ._pre_aot_lowering import register_substitution # noqa: F401 +from .substitutions import * # noqa: F403 diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index d56a3a8616..666d04e779 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -1,20 +1,21 @@ -import torch -from torch._decomp import register_decomposition, core_aten_decompositions +from typing import Any, Callable, Dict +import torch +from torch._decomp import OpOverload, core_aten_decompositions, register_decomposition -DECOMPOSITIONS = {**core_aten_decompositions()} +DECOMPOSITIONS: Dict[OpOverload, Callable[..., Any]] = {**core_aten_decompositions()} aten = torch.ops.aten -def replace_inplace_op(aten_op, outplace_op): +def replace_inplace_op(aten_op: OpOverload, outplace_op: OpOverload) -> Any: """Replace inplace operation with functional equivalent Adapted from: https://github.com/pytorch/pytorch/blob/3344d79e3f732dadd5c85b99a7aa1a022f187929/torch/_decomp/decompositions.py#L3355-L3361 """ - 
@register_decomposition(aten_op, registry=DECOMPOSITIONS) - def inplace_op(*args, **kwargs): + @register_decomposition(aten_op, registry=DECOMPOSITIONS) # type: ignore[misc] + def inplace_op(*args: Any, **kwargs: Any) -> Any: out = outplace_op(*args, **kwargs) return args[0].copy_(out) @@ -36,46 +37,51 @@ def inplace_op(*args, **kwargs): replace_inplace_op(aten.scatter_reduce_, aten.scatter_reduce) -@register_decomposition(aten.std, registry=DECOMPOSITIONS) -def std_replacement(*args, **kwargs) -> torch.Tensor: +@register_decomposition(aten.std, registry=DECOMPOSITIONS) # type: ignore[misc] +def std_replacement(*args: Any, **kwargs: Any) -> torch.Tensor: return torch.sqrt(torch.var(*args, **kwargs)) -@register_decomposition(aten.rsqrt, registry=DECOMPOSITIONS) -def rsqrt_replacement(*args, **kwargs) -> torch.Tensor: +@register_decomposition(aten.rsqrt, registry=DECOMPOSITIONS) # type: ignore[misc] +def rsqrt_replacement(*args: Any, **kwargs: Any) -> torch.Tensor: return torch.reciprocal(torch.sqrt(*args, **kwargs)) -@register_decomposition(aten._unsafe_view, registry=DECOMPOSITIONS) -def unsafe_view_replacement(x: torch.Tensor, *args, **kwargs) -> torch.Tensor: +@register_decomposition(aten._unsafe_view, registry=DECOMPOSITIONS) # type: ignore[misc] +def unsafe_view_replacement(x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: return torch.reshape(x, *args, **kwargs) -@register_decomposition(torch.ops.aten.lift_fresh_copy, registry=DECOMPOSITIONS) +@register_decomposition(torch.ops.aten.lift_fresh_copy, registry=DECOMPOSITIONS) # type: ignore[misc] def lift_fresh_copy_replacement(x: torch.Tensor) -> torch.Tensor: return x -@register_decomposition(aten.alias, registry=DECOMPOSITIONS) +@register_decomposition(aten.alias, registry=DECOMPOSITIONS) # type: ignore[misc] def alias_replacement(x: torch.Tensor) -> torch.Tensor: return x -@register_decomposition(torch.ops.aten.addmm, registry=DECOMPOSITIONS) +@register_decomposition(torch.ops.aten.addmm, registry=DECOMPOSITIONS) # type: ignore[misc] def addmm_replacement( - input_: torch.Tensor, mat1: torch.Tensor, mat2: torch.Tensor, *, beta=1, alpha=1 + input_: torch.Tensor, + mat1: torch.Tensor, + mat2: torch.Tensor, + *, + beta: int = 1, + alpha: int = 1, ) -> torch.Tensor: return torch.add( torch.mul(input_, beta), torch.mul(torch.matmul(mat1, mat2), alpha) ) -@register_decomposition(torch.ops.aten.reciprocal.default, registry=DECOMPOSITIONS) +@register_decomposition(torch.ops.aten.reciprocal.default, registry=DECOMPOSITIONS) # type: ignore[misc] def reciprocal_replacement( input_: torch.Tensor, ) -> torch.Tensor: return torch.div(1, input_) -def get_decompositions(): +def get_decompositions() -> Dict[OpOverload, Callable[..., Any]]: return DECOMPOSITIONS diff --git a/py/torch_tensorrt/dynamo/lowering/_fusers.py b/py/torch_tensorrt/dynamo/lowering/_fusers.py index c845204a65..720e4ab030 100644 --- a/py/torch_tensorrt/dynamo/lowering/_fusers.py +++ b/py/torch_tensorrt/dynamo/lowering/_fusers.py @@ -2,16 +2,26 @@ from torch_tensorrt.fx.tracer.acc_tracer import acc_ops -def check_permute(node: torch.fx.Node): +def check_permute(node: torch.fx.Node) -> bool: ranks = len(node.meta["tensor_meta"].shape) - permutation = list(i % ranks for i in node.kwargs["permutation"]) # type: ignore[union-attr] - allowed_permutation = list(i for i in range(ranks)) + permutation = [i % ranks for i in node.kwargs["permutation"]] + allowed_permutation = list(range(ranks)) allowed_permutation[-1] = ranks - 2 allowed_permutation[-2] = ranks - 1 return 
permutation == allowed_permutation -def fuse_permute_matmul(gm: torch.fx.GraphModule): +def trt_transposed_matmul( + lhs: torch.Tensor, rhs: torch.Tensor, lhs_transposed: bool, rhs_transposed: bool +) -> torch.Tensor: + if lhs_transposed: + lhs = lhs.transpose(-1, -2) + if rhs_transposed: + rhs = rhs.transpose(-1, -2) + return torch.matmul(lhs, rhs) + + +def fuse_permute_matmul(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: """ Fuse pattern like permute + matmul if permute is transposing the last two dimension. """ @@ -45,11 +55,11 @@ def fuse_permute_matmul(gm: torch.fx.GraphModule): def trt_transposed_linear( input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor -): +) -> torch.Tensor: return torch.matmul(input.transpose(-1, -2), weight.t()) + bias -def fuse_permute_linear(gm: torch.fx.GraphModule): +def fuse_permute_linear(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: """ Fuse pattern like permute + linear if permute is transposing the last two dimension. """ diff --git a/py/torch_tensorrt/dynamo/lowering/_partition.py b/py/torch_tensorrt/dynamo/lowering/_partition.py index 18d0b5a69d..246549461a 100644 --- a/py/torch_tensorrt/dynamo/lowering/_partition.py +++ b/py/torch_tensorrt/dynamo/lowering/_partition.py @@ -1,27 +1,26 @@ import logging -from typing import Dict, List, Optional, Sequence, Set +from typing import Dict, List, Mapping, Optional, Sequence, Set import torch - -from torch_tensorrt.dynamo.lowering import SUBSTITUTION_REGISTRY -from torch_tensorrt.dynamo._defaults import MIN_BLOCK_SIZE -from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition from torch.fx.graph_module import GraphModule from torch.fx.node import _get_qualified_name -from torch.fx.passes.operator_support import OperatorSupport - -from torch_tensorrt.dynamo import DYNAMO_CONVERTERS as CONVERTERS - +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition +from torch.fx.passes.operator_support import OperatorSupport, SupportDict +from torch_tensorrt.dynamo._defaults import MIN_BLOCK_SIZE +from torch_tensorrt.dynamo.conversion.converter_registry import ( + DYNAMO_CONVERTERS as CONVERTERS, +) +from torch_tensorrt.dynamo.lowering._pre_aot_lowering import SUBSTITUTION_REGISTRY logger = logging.getLogger(__name__) -DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = set( +DEFAULT_SINGLE_NODE_PARTITIONS: List[str] = [ _get_qualified_name(to_replace.new_operator) for to_replace in SUBSTITUTION_REGISTRY.values() -) +] -class TRTPartitioner(CapabilityBasedPartitioner): +class TRTPartitioner(CapabilityBasedPartitioner): # type: ignore[misc] """Partitioner to split an FX graph into subgraphs based on operator support Args: @@ -44,7 +43,7 @@ def __init__( allowed_single_node_partition_ops: Optional[ Sequence[str] ] = DEFAULT_SINGLE_NODE_PARTITIONS, - min_block_size=MIN_BLOCK_SIZE, + min_block_size: int = MIN_BLOCK_SIZE, ) -> None: super().__init__( graph_module, @@ -59,7 +58,7 @@ def __init__( def propose_partitions(self) -> List[Partition]: # Propose partitions using the default, then refine the results initial_proposed_partitions = super().propose_partitions() - partitions = {i: part for i, part in enumerate(initial_proposed_partitions)} + partitions = dict(enumerate(initial_proposed_partitions)) # For each partition, determine whether or not the number of computational operators # exceeds the threshold, and if not, remove that partition @@ -103,19 +102,25 @@ def partition_and_fuse(self) -> GraphModule: return fused_gm -class 
TorchTensorRTOperatorSupport(OperatorSupport): +class TorchTensorRTOperatorSupport(OperatorSupport): # type: ignore[misc] """Class to determine whether operators within a module are supported""" - def __init__(self, support_dict=None, torch_executed_ops=set()): + def __init__( + self, + support_dict: Optional[SupportDict] = None, + torch_executed_ops: Optional[Set[str]] = None, + ): super().__init__(support_dict) # Initialize sets of supported/unsupported operators - self.supported_operators = {} - self.unsupported_operators = {} - self.torch_executed_ops = torch_executed_ops + self.supported_operators: Dict[str, int] = {} + self.unsupported_operators: Dict[str, int] = {} + self.torch_executed_ops: Set[str] = ( + torch_executed_ops if torch_executed_ops is not None else set() + ) def is_node_supported( - self, submodules: Dict[str, torch.nn.Module], node: torch.fx.Node + self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node ) -> bool: node_name = ( _get_qualified_name(node.target) @@ -141,7 +146,7 @@ def is_node_supported( return False - def print_support_overview(self, num_trt_blocks: Optional[int] = None): + def print_support_overview(self, num_trt_blocks: Optional[int] = None) -> None: if num_trt_blocks is not None: logger.debug( f"\nNumber of TensorRT-Accelerated Engines Generated: {num_trt_blocks}" @@ -168,7 +173,7 @@ def partition( gm: torch.fx.GraphModule, verbose: bool = True, min_block_size: int = MIN_BLOCK_SIZE, - torch_executed_ops: Sequence[str] = set(), + torch_executed_ops: Optional[Set[str]] = None, ) -> torch.fx.GraphModule: """Partition an FX GraphModule with aten ops into TRT engines Partitioning is based on converter operator support @@ -181,7 +186,11 @@ def partition( Returns: torch.fx.GraphModule """ - supported_ops = TorchTensorRTOperatorSupport(torch_executed_ops=torch_executed_ops) + supported_ops = TorchTensorRTOperatorSupport( + torch_executed_ops=torch_executed_ops + if torch_executed_ops is not None + else set() + ) partitioner = TRTPartitioner(gm, supported_ops, min_block_size=min_block_size) # Determine partitions based on user specifications and operator support @@ -199,7 +208,7 @@ def get_submod_inputs( mod: torch.fx.GraphModule, submod: torch.fx.GraphModule, inputs: Sequence[torch.Tensor], -) -> Sequence[torch.Tensor]: +) -> Optional[Sequence[torch.Tensor]]: """Helper function to get inputs to a Torch submodule Args: @@ -209,9 +218,9 @@ def get_submod_inputs( Returns: Sequence of Tensors representing inputs to child module """ - acc_inputs = None + acc_inputs: Optional[Sequence[torch.Tensor]] = None - def get_input(self, inputs): + def get_input(_: torch.fx.GraphModule, inputs: Sequence[torch.Tensor]) -> None: nonlocal acc_inputs acc_inputs = inputs diff --git a/py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py b/py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py index 8a47fc04d2..32250607df 100644 --- a/py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py +++ b/py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py @@ -1,11 +1,17 @@ -from dataclasses import dataclass -from typing import Any, Callable, Dict, Optional, Type, Union -import torch import logging +from dataclasses import dataclass +from typing import Any, Callable, Dict, Optional, Type, TypeAlias +import torch +from torch._ops import OpOverload +from torch.fx import GraphModule, Node logger = logging.getLogger(__name__) +SubgraphInsertionFnType: TypeAlias = Callable[ + [GraphModule, Node, Optional[torch.nn.Module]], Node +] + @dataclass(frozen=True) class Substitution: 
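A hedged illustration of the SubgraphInsertionFnType contract defined above (example_insertion_fn is hypothetical; real registrations appear in the substitutions package later in this diff, and GraphModule, Node and Optional are already imported in this file):

    def example_insertion_fn(
        gm: GraphModule, node: Node, submodule: Optional[torch.nn.Module]
    ) -> Node:
        # Insert a replacement call_function node after the matched node; per the
        # contract documented below, do not delete nodes or recompile the graph here.
        with gm.graph.inserting_after(node):
            return gm.graph.call_function(torch.ops.tensorrt.einsum, args=node.args)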
@@ -18,22 +24,20 @@ class Substitution: # and returning a replacement node, with type 'call_function', or raising an Error if # incompatibility is detected # Note: subgraph_insertion_fn should NOT delete nodes or recompile the graph - subgraph_insertion_fn: Callable[ - [torch.fx.GraphModule, torch.fx.Node, Optional[torch.nn.Module]], torch.fx.Node - ] + subgraph_insertion_fn: SubgraphInsertionFnType # Dictionary mapping module to Substitution instance SUBSTITUTION_REGISTRY: Dict[ - Union[Type[torch.nn.Module], Callable], Substitution + (Type[torch.nn.Module] | Callable[..., Any]), Substitution ] = dict() def register_substitution( - module_or_function_to_replace: Union[Type[torch.nn.Module], Callable], - new_operator: torch._ops.OpOverload, + module_or_function_to_replace: (Type[torch.nn.Module] | Callable[..., Any]), + new_operator: OpOverload, enabled: bool = True, -) -> Callable[[Any], Any]: +) -> Callable[[SubgraphInsertionFnType], SubgraphInsertionFnType]: """Decorator to register subgraph insertion functions Args: @@ -44,7 +48,9 @@ def register_substitution( torch.fx.GraphModule """ - def enable_substitution(subgraph_insertion_fn): + def enable_substitution( + subgraph_insertion_fn: SubgraphInsertionFnType, + ) -> SubgraphInsertionFnType: """Function for use if substitution is enabled""" replacement = Substitution( new_operator=new_operator, subgraph_insertion_fn=subgraph_insertion_fn @@ -52,14 +58,16 @@ def enable_substitution(subgraph_insertion_fn): SUBSTITUTION_REGISTRY[module_or_function_to_replace] = replacement return subgraph_insertion_fn - def disable_substitution(subgraph_insertion_fn): + def disable_substitution( + subgraph_insertion_fn: SubgraphInsertionFnType, + ) -> SubgraphInsertionFnType: """Function for use if substitution is disabled""" return subgraph_insertion_fn return enable_substitution if enabled else disable_substitution -def pre_aot_substitutions(gm: torch.fx.GraphModule): +def pre_aot_substitutions(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: """Perform graph substitutions prior to AOT tracing Args: @@ -92,6 +100,7 @@ def pre_aot_substitutions(gm: torch.fx.GraphModule): # If submodule/function is a member of the substitution registry, replace it if exists_in_registry: try: + assert to_replace is not None replacement = SUBSTITUTION_REGISTRY[to_replace] op, insertion_fn = ( replacement.new_operator, diff --git a/py/torch_tensorrt/dynamo/lowering/substitutions/__init__.py b/py/torch_tensorrt/dynamo/lowering/substitutions/__init__.py index 8d3acc8874..bd348b3e47 100644 --- a/py/torch_tensorrt/dynamo/lowering/substitutions/__init__.py +++ b/py/torch_tensorrt/dynamo/lowering/substitutions/__init__.py @@ -1,2 +1,2 @@ -from .maxpool1d import * -from .einsum import * +from .einsum import * # noqa: F403 +from .maxpool1d import * # noqa: F403 diff --git a/py/torch_tensorrt/dynamo/lowering/substitutions/einsum.py b/py/torch_tensorrt/dynamo/lowering/substitutions/einsum.py index e1c7a5d68e..cfcbdce761 100644 --- a/py/torch_tensorrt/dynamo/lowering/substitutions/einsum.py +++ b/py/torch_tensorrt/dynamo/lowering/substitutions/einsum.py @@ -1,31 +1,30 @@ -from typing import Dict, Tuple +from typing import Any, Dict, Optional, Sequence, Tuple + import torch from torch._custom_op.impl import custom_op from torch.fx.node import Argument, Target - +from torch_tensorrt.dynamo.lowering._pre_aot_lowering import register_substitution from torch_tensorrt.fx.converter_registry import tensorrt_converter from torch_tensorrt.fx.converters.converter_utils import set_layer_name from 
torch_tensorrt.fx.types import TRTNetwork, TRTTensor -from torch_tensorrt.dynamo.lowering import register_substitution - @custom_op( qualname="tensorrt::einsum", manual_schema="(str equation, Tensor[] tensors) -> Tensor", ) -def einsum(equation, tensors): +def einsum(equation, tensors): # type: ignore[no-untyped-def] # Defines operator schema, name, namespace, and function header ... -@einsum.impl("cpu") -@einsum.impl("cuda") -@einsum.impl_abstract() +@einsum.impl("cpu") # type: ignore[misc] +@einsum.impl("cuda") # type: ignore[misc] +@einsum.impl_abstract() # type: ignore[misc] def einsum_generic( - *args, - **kwargs, -): + *args: Any, + **kwargs: Any, +) -> Any: # Defines a converter implementation for AOT Autograd to use for shape analysis/propagation return torch.einsum( *args, @@ -34,7 +33,7 @@ def einsum_generic( # TODO: @gs-olive Port to dynamo converter -@tensorrt_converter(torch.ops.tensorrt.einsum.default) +@tensorrt_converter(torch.ops.tensorrt.einsum.default) # type: ignore[misc] def aten_ops_einsum( network: TRTNetwork, target: Target, @@ -43,6 +42,7 @@ def aten_ops_einsum( name: str, ) -> TRTTensor: # Defines converter replacing the default operator for this function + assert isinstance(args[1], Sequence) for input_trt in args[1]: if not isinstance(input_trt, TRTTensor): raise RuntimeError(f"Einsum received non-TRTTensor input: {input_trt}") @@ -53,11 +53,11 @@ def aten_ops_einsum( return einsum_layer.get_output(0) -@register_substitution(torch.einsum, torch.ops.tensorrt.einsum) +@register_substitution(torch.einsum, torch.ops.tensorrt.einsum) # type: ignore[misc] def einsum_insertion_fn( gm: torch.fx.GraphModule, node: torch.fx.Node, - _unused: None = None, + submodule: Optional[torch.nn.Module] = None, ) -> torch.fx.Node: equation = node.args[0] @@ -72,7 +72,7 @@ def einsum_insertion_fn( ), f"TRT Einsum currently only supports 1 or 2 Tensors, got {len(inputs)} Tensors" # Ensure the input is formatted as an equation and - new_node = gm.graph.call_function( + new_node: torch.fx.Node = gm.graph.call_function( torch.ops.tensorrt.einsum, args=(equation, inputs), kwargs=node.kwargs, diff --git a/py/torch_tensorrt/dynamo/lowering/substitutions/maxpool1d.py b/py/torch_tensorrt/dynamo/lowering/substitutions/maxpool1d.py index b3265e7875..6db2664efb 100644 --- a/py/torch_tensorrt/dynamo/lowering/substitutions/maxpool1d.py +++ b/py/torch_tensorrt/dynamo/lowering/substitutions/maxpool1d.py @@ -1,15 +1,13 @@ -from typing import Dict, Tuple +from typing import Any, Dict, Optional, Tuple + import torch from torch._custom_op.impl import custom_op from torch.fx.node import Argument, Target - +from torch_tensorrt.dynamo.lowering._pre_aot_lowering import register_substitution from torch_tensorrt.fx.converter_registry import tensorrt_converter from torch_tensorrt.fx.converters import acc_ops_converters from torch_tensorrt.fx.types import TRTNetwork, TRTTensor -from torch_tensorrt.dynamo.lowering import register_substitution - - # This file serves as an example and a tutorial for excluding custom modules from # torch.compile tracing. Each required step is labeled with a number indicating the # preferable implementation order. 
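Once the numbered steps in this file are complete, the custom op is callable like any other torch.ops entry; a hedged usage sketch with illustrative argument values (see the schema registered below):

    x = torch.randn(1, 3, 16, device="cuda")
    # (x, kernel_size, stride, padding, dilation, ceil_mode)
    y = torch.ops.tensorrt.maxpool1d(x, [3], [1], [0], [1], False)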
@@ -26,7 +24,7 @@ qualname="tensorrt::maxpool1d", manual_schema="(Tensor x, int[1] kernel_size, int[1] stride, int[1] padding, int[1] dilation, bool ceil_mode) -> Tensor", ) -def maxpool1d(x, kernel_size, stride, padding, dilation, ceil_mode): +def maxpool1d(x, kernel_size, stride, padding, dilation, ceil_mode): # type: ignore[no-untyped-def] # Defines operator schema, name, namespace, and function header ...
@@ -38,13 +36,13 @@ def maxpool1d(x, kernel_size, stride, padding, dilation, ceil_mode): # is desirable. If the operator to replace is a custom module you've written, then add its Torch # implementation here. Note that the function header to the generic function can have specific arguments # as in the above placeholder -@maxpool1d.impl("cpu") -@maxpool1d.impl("cuda") -@maxpool1d.impl_abstract() +@maxpool1d.impl("cpu") # type: ignore[misc] +@maxpool1d.impl("cuda") # type: ignore[misc] +@maxpool1d.impl_abstract() # type: ignore[misc] def maxpool1d_generic( - *args, - **kwargs, -): + *args: Any, + **kwargs: Any, +) -> Any: # Defines an implementation for AOT Autograd to use for shape analysis/propagation return torch.nn.functional.max_pool1d( *args,
@@ -75,10 +73,11 @@ def maxpool1d_generic( def maxpool1d_insertion_fn( gm: torch.fx.GraphModule, node: torch.fx.Node, - submodule: torch.nn.Module, + submodule: Optional[torch.nn.Module], ) -> torch.fx.Node: # Defines insertion function for new node - new_node = gm.graph.call_function( + assert submodule is not None + new_node: torch.fx.Node = gm.graph.call_function( torch.ops.tensorrt.maxpool1d, args=node.args, kwargs={
@@ -99,7 +98,7 @@ def maxpool1d_insertion_fn( # This accelerated implementation should consume the args/kwargs specified in step 3. # One should expect that torch.compile will compress all kwargs into the args field in # the order specified in the schema written in step 1. -@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default) +@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default) # type: ignore[misc] def tensorrt_maxpool1d( network: TRTNetwork, target: Target,
diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py similarity index 80% rename from py/torch_tensorrt/dynamo/runtime/_PythonTorchTRTModule.py rename to py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 79070cea31..b5760161a6 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -1,33 +1,37 @@ -from typing import Any, List, Sequence +from typing import Any, Dict, List, Optional, Sequence, Tuple + +import torch +from torch.nn import Module +from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter # @manual=//deeplearning/trt/python:py_tensorrt import tensorrt as trt -import torch -from torch_tensorrt.fx.utils import unified_dtype_converter, Frameworks -class PythonTorchTRTModule(torch.nn.Module): - """PythonTorchTRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine. +class PythonTorchTensorRTModule(Module): # type: ignore[misc] + """PythonTorchTensorRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine. This module is backed by the Torch-TensorRT runtime and is only compatible with FX / Dynamo / Python deployments. This module cannot be serialized to torchscript via torch.jit.trace for C++ deployment.
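Example (illustrative; assumes a prebuilt trt.ICudaEngine named engine and matching binding names):

    trt_mod = PythonTorchTensorRTModule(engine, input_names=["x"], output_names=["y"])
    out = trt_mod(torch.randn(1, 3, 224, 224, device="cuda"))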
""" def __init__( - self, engine=None, input_names=None, output_names=None, cuda_graph_batch_size=-1 + self, + engine: trt.ICudaEngine, + input_names: Optional[List[str]] = None, + output_names: Optional[List[str]] = None, + cuda_graph_batch_size: int = -1, ): - super(PythonTorchTRTModule, self).__init__() - self._register_state_dict_hook(PythonTorchTRTModule._on_state_dict) + super(PythonTorchTensorRTModule, self).__init__() + self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict) self.engine = engine - self.input_names = input_names - self.output_names = output_names + self.input_names = input_names if input_names is not None else [] + self.output_names = output_names if output_names is not None else [] self.cuda_graph_batch_size = cuda_graph_batch_size self.initialized = False + self._initialize() - if engine: - self._initialize() - - def _initialize(self): + def _initialize(self) -> None: self.initialized = True self.context = self.engine.create_execution_context() @@ -57,7 +61,7 @@ def _initialize(self): + len(self.hidden_output_names) ) - self.input_dtypes: Sequence[torch.dtype] = [ + self.input_dtypes = [ unified_dtype_converter( self.engine.get_binding_dtype(idx), Frameworks.TORCH ) @@ -67,7 +71,7 @@ def _initialize(self): tuple(self.engine.get_binding_shape(idx)) for idx in self.input_binding_indices_in_order ] - self.output_dtypes: Sequence[torch.dtype] = [ + self.output_dtypes = [ unified_dtype_converter( self.engine.get_binding_dtype(idx), Frameworks.TORCH ) @@ -79,7 +83,7 @@ def _initialize(self): else tuple() for idx in self.output_binding_indices_in_order ] - self.hidden_output_dtypes: Sequence[torch.dtype] = [ + self.hidden_output_dtypes = [ unified_dtype_converter( self.engine.get_binding_dtype(idx), Frameworks.TORCH ) @@ -92,11 +96,11 @@ def _initialize(self): for idx in self.hidden_output_binding_indices_in_order ] - def _check_initialized(self): + def _check_initialized(self) -> None: if not self.initialized: - raise RuntimeError("PythonTorchTRTModule is not initialized.") + raise RuntimeError("PythonTorchTensorRTModule is not initialized.") - def _on_state_dict(self, state_dict, prefix, local_metadata): + def _on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> None: self._check_initialized() state_dict[prefix + "engine"] = bytearray(self.engine.serialize()) state_dict[prefix + "input_names"] = self.input_names @@ -105,14 +109,14 @@ def _on_state_dict(self, state_dict, prefix, local_metadata): def _load_from_state_dict( self, - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ): + state_dict: Dict[str, Any], + prefix: str, + local_metadata: Any, + strict: Any, + missing_keys: Any, + unexpected_keys: Any, + error_msgs: Any, + ) -> None: engine_bytes = state_dict[prefix + "engine"] logger = trt.Logger() @@ -123,13 +127,13 @@ def _load_from_state_dict( self.output_names = state_dict[prefix + "output_names"] self._initialize() - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: state = self.__dict__.copy() state["engine"] = bytearray(self.engine.serialize()) state.pop("context", None) return state - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, Any]) -> None: logger = trt.Logger() runtime = trt.Runtime(logger) state["engine"] = runtime.deserialize_cuda_engine(state["engine"]) @@ -137,12 +141,14 @@ def __setstate__(self, state): if self.engine: self.context = self.engine.create_execution_context() - def forward(self, *inputs): - with 
torch.autograd.profiler.record_function("PythonTorchTRTModule:Forward"): + def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]: + with torch.autograd.profiler.record_function( + "PythonTorchTensorRTModule:Forward" + ): self._check_initialized() with torch.autograd.profiler.record_function( - "PythonTorchTRTModule:ProcessInputs" + "PythonTorchTensorRTModule:ProcessInputs" ): assert len(inputs) == len( self.input_names @@ -179,7 +185,7 @@ def forward(self, *inputs): ) with torch.autograd.profiler.record_function( - "PythonTorchTRTModule:ProcessOutputs" + "PythonTorchTensorRTModule:ProcessOutputs" ): # create output tensors outputs: List[torch.Tensor] = [] @@ -190,7 +196,7 @@ def forward(self, *inputs): else: shape = tuple(self.context.get_binding_shape(idx)) - output = torch.empty( # type: ignore[call-overload] + output = torch.empty( size=shape, dtype=self.output_dtypes[i], device=torch.cuda.current_device(), @@ -204,7 +210,7 @@ def forward(self, *inputs): else: shape = tuple(self.context.get_binding_shape(idx)) - output = torch.empty( # type: ignore[call-overload] + output = torch.empty( size=shape, dtype=self.hidden_output_dtypes[i], device=torch.cuda.current_device(), @@ -212,7 +218,7 @@ def forward(self, *inputs): bindings[idx] = output.data_ptr() with torch.autograd.profiler.record_function( - "PythonTorchTRTModule:TensorRTRuntime" + "PythonTorchTensorRTModule:TensorRTRuntime" ): if self.engine.has_implicit_batch_dimension: self.context.execute_async( @@ -228,7 +234,7 @@ def forward(self, *inputs): return tuple(outputs) - def enable_profiling(self, profiler: "trt.IProfiler" = None): + def enable_profiling(self, profiler: "trt.IProfiler" = None) -> None: """ Enable TensorRT profiling. After calling this function, TensorRT will report time spent on each layer in stdout for each forward run. @@ -238,7 +244,7 @@ def enable_profiling(self, profiler: "trt.IProfiler" = None): if not self.context.profiler: self.context.profiler = trt.Profiler() if profiler is None else profiler - def disable_profiling(self): + def disable_profiling(self) -> None: """ Disable TensorRT profiling. """ @@ -253,4 +259,7 @@ def get_layer_info(self) -> str: Get layer info of the engine. Only support for TRT > 8.2. """ inspector = self.engine.create_engine_inspector() - return inspector.get_engine_information(trt.LayerInformationFormat.JSON) + engine_json: str = inspector.get_engine_information( + trt.LayerInformationFormat.JSON + ) + return engine_json diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index c1fd86af8a..bddee9b93b 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -1,14 +1,20 @@ import logging -from typing import Any, List, Tuple +from typing import Any, List, Optional, Tuple import torch -from torch_tensorrt import _C from torch_tensorrt._Device import Device logger = logging.getLogger(__name__) +SerializedTensorRTEngineFmt = Tuple[ + str, str, bytes, str, str +] # Defined in //core/runtime/register_jit_hooks.cpp +SerializedTorchTensorRTModuleFmt = Tuple[ + str, SerializedTensorRTEngineFmt, List[str], List[str] +] -class TorchTensorRTModule(torch.nn.Module): + +class TorchTensorRTModule(torch.nn.Module): # type: ignore[misc] """TorchTensorRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine. 
This module is backed by the Torch-TensorRT runtime and is fully compatible with both
@@ -30,10 +36,10 @@ class TorchTensorRTModule(torch.nn.Module): def __init__( self, - serialized_engine: bytearray = bytearray(), + serialized_engine: Optional[bytes] = None, name: str = "", - input_binding_names: List[str] = [], - output_binding_names: List[str] = [], + input_binding_names: Optional[List[str]] = None, + output_binding_names: Optional[List[str]] = None, target_device: Device = Device._current_device(), ): """__init__ method for torch_tensorrt.dynamo.runtime._TorchTensorRTModule.TorchTensorRTModule
@@ -74,11 +80,15 @@ def __init__( if not isinstance(serialized_engine, bytearray): ValueError("Expected serialized engine as bytearray") - self.input_binding_names = input_binding_names - self.output_binding_names = output_binding_names + self.input_binding_names = ( + input_binding_names if input_binding_names is not None else [] + ) + self.output_binding_names = ( + output_binding_names if output_binding_names is not None else [] + ) self.name = name - if serialized_engine != bytearray(): + if serialized_engine is not None: self.engine = torch.classes.tensorrt.Engine( [ torch.ops.tensorrt.ABI_VERSION(),
@@ -92,7 +102,7 @@ def __init__( else: self.engine = None - def get_extra_state(self): + def get_extra_state(self) -> SerializedTorchTensorRTModuleFmt: return ( self.name, self.engine.__getstate__() if self.engine is not None else None,
@@ -100,7 +110,7 @@ def get_extra_state(self): self.output_binding_names, ) - def set_extra_state(self, state): + def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None: self.name = state[0] if state[1] is not None: serialized_engine_info = state[1][0]
@@ -123,7 +133,7 @@ def set_extra_state(self, state): self.input_binding_names = state[2] self.output_binding_names = state[3] - def forward(self, *inputs): + def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]: """Implementation of the forward pass for a TensorRT engine Args:
@@ -139,28 +149,30 @@ def forward(self, *inputs): self.input_binding_names ), f"Wrong number of inputs, expected {len(self.input_binding_names)} got {len(inputs)}."
- types = [issubclass(type(i), torch.Tensor) for i in inputs] + types: List[bool] = [issubclass(type(i), torch.Tensor) for i in inputs] try: assert all(types) - except: + except AssertionError: def is_non_tensor(i: Tuple[Any, bool]) -> bool: return not i[1] - non_tensors = [i[0] for i in filter(zip(inputs, types), is_non_tensor)] + non_tensors = [i[0] for i in filter(is_non_tensor, zip(inputs, types))] raise RuntimeError( f"TorchTensorRTModule expects a flattened list of tensors as input, found non tensors: {non_tensors}" ) - outputs = torch.ops.tensorrt.execute_engine(list(inputs), self.engine) + outputs: List[torch.Tensor] = torch.ops.tensorrt.execute_engine( + list(inputs), self.engine + ) if len(outputs) == 1: return outputs[0] return tuple(outputs) - def enable_profiling(self, profiling_results_dir: str = None): + def enable_profiling(self, profiling_results_dir: Optional[str] = None) -> None: """Enable the profiler to collect latency information about the execution of the engine Traces can be visualized using https://ui.perfetto.dev/ or compatible alternatives
@@ -175,7 +187,7 @@ def enable_profiling(self, profiling_results_dir: str = None): self.engine.profile_path_prefix = profiling_results_dir self.engine.enable_profiling() - def disable_profiling(self): + def disable_profiling(self) -> None: """Disable the profiler""" if self.engine is None: raise RuntimeError("Engine has not been initialized yet.")
@@ -192,16 +204,18 @@ def get_layer_info(self) -> str: if self.engine is None: raise RuntimeError("Engine has not been initialized yet.") - return self.engine.get_engine_layer_info() + layer_info: str = self.engine.get_engine_layer_info() + return layer_info - def dump_layer_info(self): + def dump_layer_info(self) -> None: """Dump layer information encoded by the TensorRT engine in this module to STDOUT""" if self.engine is None: raise RuntimeError("Engine has not been initialized yet.") - return self.engine.dump_engine_layer_info() + self.engine.dump_engine_layer_info() @staticmethod def _pack_binding_names(binding_names: List[str]) -> str: delim = torch.ops.tensorrt.SERIALIZED_ENGINE_BINDING_DELIM()[0] - return delim.join(binding_names) + packed_bindings: str = delim.join(binding_names) + return packed_bindings
diff --git a/py/torch_tensorrt/dynamo/runtime/__init__.py b/py/torch_tensorrt/dynamo/runtime/__init__.py index a4586eae0a..ee1caab972 100644 --- a/py/torch_tensorrt/dynamo/runtime/__init__.py +++ b/py/torch_tensorrt/dynamo/runtime/__init__.py
@@ -1,2 +1,2 @@ -from ._PythonTorchTRTModule import PythonTorchTRTModule -from ._TorchTensorRTModule import TorchTensorRTModule +from ._PythonTorchTensorRTModule import PythonTorchTensorRTModule # noqa: F401 +from ._TorchTensorRTModule import TorchTensorRTModule # noqa: F401
diff --git a/py/torch_tensorrt/dynamo/tools/opset_coverage.py b/py/torch_tensorrt/dynamo/tools/opset_coverage.py index a46236c408..7e1560c855 100644 --- a/py/torch_tensorrt/dynamo/tools/opset_coverage.py +++ b/py/torch_tensorrt/dynamo/tools/opset_coverage.py
@@ -1,6 +1,7 @@ import dataclasses import json import os +import warnings from collections import OrderedDict from dataclasses import dataclass from enum import Enum, auto
@@ -10,8 +11,8 @@ import torch import torch._prims as prims import torchgen -from torch._ops import OpOverload from torch._dynamo.variables import BuiltinVariable +from torch._ops import OpOverload from torch_tensorrt.dynamo.conversion.converter_registry import ( DYNAMO_CONVERTERS, ConverterRegistry,
@@ -110,7 +111,6 @@ def opset_coverage(
converter_registry: Optional[ConverterRegistry] = None, decomposition_registry: Optional[Dict[OpOverload, Callable[..., Any]]] = None, ) -> OpsetCoverage: - opset_schemas = dict(opset) opset_targets = set(opset_schemas.keys()) @@ -134,17 +134,20 @@ def opset_coverage( _, registry_data = c_registry.get_all_converters_with_target( target, return_registry_info=True ) - if registry_data["Dynamo ATen Converters Registry"] >= 1: - status = SupportStatus.CONVERTED - support_count += 1 - elif registry_data["FX ATen Converters Registry"] >= 1: - status = SupportStatus.LEGACY_CONVERTED - legacy_count += 1 - - support_status[target_str] = { - "schema": f"{target_str.split('.')[0]}.{opset_schemas[target_str]}", - "status": str(status), - } + if registry_data is not None: + if registry_data["Dynamo ATen Converters Registry"] >= 1: + status = SupportStatus.CONVERTED + support_count += 1 + elif registry_data["FX ATen Converters Registry"] >= 1: + status = SupportStatus.LEGACY_CONVERTED + legacy_count += 1 + + support_status[target_str] = { + "schema": f"{target_str.split('.')[0]}.{opset_schemas[target_str]}", + "status": str(status), + } + else: + warnings.warn(f"No registry data for op: {target_str}") l_registry = ( decomposition_registry diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 0688d2e169..bbb1a4354b 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -1,10 +1,13 @@ -import torch import logging -from dataclasses import replace, fields +from dataclasses import fields, replace +from typing import Any, Callable, Dict, Optional, Sequence + +import torch +from torch_tensorrt._Device import Device +from torch_tensorrt._Input import Input from torch_tensorrt.dynamo import CompilationSettings -from typing import Any, Union, Sequence, Dict -from torch_tensorrt import Input, Device -from typing import Optional + +from packaging import version logger = logging.getLogger(__name__) @@ -27,7 +30,7 @@ def use_python_runtime_parser(use_python_runtime: Optional[bool] = None) -> bool # Runtime was not manually specified by the user, automatically detect runtime else: try: - from torch_tensorrt.dynamo.runtime import TorchTensorRTModule + from torch_tensorrt.dynamo.runtime import TorchTensorRTModule # noqa: F401 using_python_runtime = False reason = "since C++ dependency was detected as present" @@ -42,20 +45,22 @@ def use_python_runtime_parser(use_python_runtime: Optional[bool] = None) -> bool return using_python_runtime -def cosine_similarity(gt_tensor, pred_tensor): +def cosine_similarity(gt_tensor: torch.Tensor, pred_tensor: torch.Tensor) -> float: gt_tensor = gt_tensor.flatten().to(torch.float32) pred_tensor = pred_tensor.flatten().to(torch.float32) if torch.sum(gt_tensor) == 0.0 or torch.sum(pred_tensor) == 0.0: if torch.allclose(gt_tensor, pred_tensor, atol=1e-4, rtol=1e-4, equal_nan=True): return 1.0 - res = torch.nn.functional.cosine_similarity(gt_tensor, pred_tensor, dim=0, eps=1e-6) - res = res.cpu().detach().item() + res_t = torch.nn.functional.cosine_similarity( + gt_tensor, pred_tensor, dim=0, eps=1e-6 + ) + res: float = res_t.cpu().detach().item() return res def prepare_inputs( - inputs: Union[Input, torch.Tensor, Sequence, Dict], + inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any], device: torch.device = torch.device("cuda"), ) -> Any: if isinstance(inputs, Input): @@ -70,36 +75,35 @@ def prepare_inputs( return Input.from_tensor(inputs), inputs elif isinstance(inputs, list): - prepared_input = list() - 
torchtrt_inputs = [] - torch_inputs = [] + torchtrt_input_list = [] + torch_input_list = [] for input_obj in inputs: torchtrt_input, torch_input = prepare_inputs(input_obj) - torchtrt_inputs.append(torchtrt_input) - torch_inputs.append(torch_input) + torchtrt_input_list.append(torchtrt_input) + torch_input_list.append(torch_input) - return torchtrt_inputs, torch_inputs + return torchtrt_input_list, torch_input_list elif isinstance(inputs, tuple): - torchtrt_inputs = [] - torch_inputs = [] + torchtrt_inputs_tup = [] + torch_inputs_tup = [] for input_obj in inputs: torchtrt_input, torch_input = prepare_inputs(input_obj) - torchtrt_inputs.append(torchtrt_input) - torch_inputs.append(torch_input) + torchtrt_inputs_tup.append(torchtrt_input) + torch_inputs_tup.append(torch_input) - return tuple(torchtrt_inputs), tuple(torch_inputs) + return tuple(torchtrt_inputs_tup), tuple(torch_inputs_tup) elif isinstance(inputs, dict): - torchtrt_inputs = dict() - torch_inputs = dict() + torchtrt_inputs_dict: Dict[Any, Any] = dict() + torch_inputs_dict: Dict[Any, Any] = dict() for key, input_obj in inputs.items(): torchtrt_input, torch_input = prepare_inputs(input_obj) - torchtrt_inputs[key] = torchtrt_input - torch_inputs[key] = torch_input + torchtrt_inputs_dict[key] = torchtrt_input + torch_inputs_dict[key] = torch_input - return torchtrt_inputs, torch_inputs + return torchtrt_inputs_dict, torch_inputs_dict else: raise ValueError( @@ -108,25 +112,26 @@ def prepare_inputs( ) -def prepare_device(device: Union[Device, torch.device]) -> torch.device: +def prepare_device(device: Device | torch.device) -> torch.device: + _device: torch.device if isinstance(device, Device): if device.gpu_id != -1: - device = torch.device(device.gpu_id) + _device = torch.device(device.gpu_id) else: raise ValueError("Invalid GPU ID provided for the CUDA device provided") elif isinstance(device, torch.device): - device = device + _device = device else: raise ValueError( "Invalid device provided. Supported options: torch.device | torch_tensorrt.Device" ) - return device + return _device -def parse_dynamo_kwargs(kwargs: Dict) -> CompilationSettings: +def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings: """Parses the kwargs field of a Dynamo backend Args: @@ -160,3 +165,36 @@ def parse_dynamo_kwargs(kwargs: Dict) -> CompilationSettings: logger.debug(f"Compiling with Settings:\n{settings}") return settings + + +def req_torch_version(min_torch_version: str = "2.dev") -> Callable[..., Any]: + """ + Create a decorator which verifies the Torch version installed + against a specified version range + + Args: + min_torch_version (str): The minimum required Torch version + for the decorated function to work properly + + Returns: + A decorator which raises a descriptive error message if + an unsupported Torch version is used + """ + + def nested_decorator(f: Callable[..., Any]) -> Callable[..., Any]: + def function_wrapper(*args: Any, **kwargs: Any) -> Any: + # Parse minimum and current Torch versions + min_version = version.parse(min_torch_version) + current_version = version.parse(torch.__version__) + + if current_version < min_version: + raise AssertionError( + f"Expected Torch version {min_torch_version} or greater, " + + f"when calling {f}. 
Detected version {torch.__version__}" + ) + else: + return f(*args, **kwargs) + + return function_wrapper + + return nested_decorator diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index 9532c7072c..1765077930 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -899,7 +899,6 @@ def acc_ops_relu( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return activation.relu( network, target, @@ -917,7 +916,6 @@ def acc_ops_leaky_relu( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return activation.leaky_relu( network, target, SourceIR.ACC, name, kwargs["input"], kwargs["negative_slope"] ) @@ -931,7 +929,6 @@ def acc_ops_elu( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return activation.elu( network, target, @@ -950,7 +947,6 @@ def acc_ops_selu( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return activation.selu( network, target, @@ -2494,7 +2490,6 @@ def acc_ops_where( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - condition_t = kwargs["condition"] x_t = kwargs["x"] y_t = kwargs["y"] @@ -3081,7 +3076,6 @@ def acc_ops_sigmoid( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return activation.sigmoid( network, target, @@ -3473,7 +3467,6 @@ def acc_ops_hardtanh( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return activation.hardtanh( network, target, diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index cf2101ef1a..17c19eda33 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -24,6 +24,7 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) + ## converter list in alphabetic order @tensorrt_converter(torch.ops.aten.add.Tensor) def aten_ops_add( @@ -199,7 +200,6 @@ def aten_ops_elu( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - if len(args) > 2: return activation.selu( network, @@ -257,7 +257,6 @@ def aten_ops_hardtanh( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return activation.hardtanh( network, target, SourceIR.ATEN, name, args[0], args[1], args[2] ) @@ -286,7 +285,6 @@ def aten_ops_leaky_relu( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return activation.leaky_relu(network, target, SourceIR.ATEN, name, args[0], args[1]) @@ -380,7 +378,6 @@ def aten_ops_relu( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return activation.relu( network, target, @@ -450,7 +447,6 @@ def aten_ops_tanh( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return activation.tanh( network, target, @@ -601,7 +597,6 @@ def aten_ops_sigmoid( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return activation.sigmoid( network, target, diff --git a/py/torch_tensorrt/fx/converters/impl/convolution.py b/py/torch_tensorrt/fx/converters/impl/convolution.py index a0e7537fde..84071ed2d4 100644 --- a/py/torch_tensorrt/fx/converters/impl/convolution.py +++ b/py/torch_tensorrt/fx/converters/impl/convolution.py @@ -40,7 
+40,6 @@ def convNd( scale: Optional[Union[torch.Tensor, float]] = None, zero_point: Optional[Union[torch.Tensor, float]] = None, ) -> TRTTensor: - if has_dynamic_shape(input_val.shape): assert input_val.shape[1] != -1, "Channel dim can't be dynamic for convolution." diff --git a/py/torch_tensorrt/fx/diagnostics.py b/py/torch_tensorrt/fx/diagnostics.py index 0d78513a81..a0c049bff1 100644 --- a/py/torch_tensorrt/fx/diagnostics.py +++ b/py/torch_tensorrt/fx/diagnostics.py @@ -81,7 +81,6 @@ def set_current_collector(collector: "DiagnosticsCollector"): class DiagnosticsWriter: - # the root dir in which the diagnostics will be written _root_dir: str diff --git a/py/torch_tensorrt/fx/passes/lower_basic_pass.py b/py/torch_tensorrt/fx/passes/lower_basic_pass.py index e98a9371c5..b203bc82e0 100644 --- a/py/torch_tensorrt/fx/passes/lower_basic_pass.py +++ b/py/torch_tensorrt/fx/passes/lower_basic_pass.py @@ -39,7 +39,6 @@ def replace_mutable_op(module: torch.fx.GraphModule) -> torch.fx.GraphModule: # only through this op if set(n.args[0].users.keys()) == {n}: with module.graph.inserting_after(n): - # TODO: move this outside? def fill_with_mul_zero_and_add(*args): return args[0].mul(0.0).add(args[1]) diff --git a/py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py b/py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py index cd5548b30a..8cb79f4958 100644 --- a/py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py +++ b/py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py @@ -11,7 +11,7 @@ from torch.testing._internal.common_utils import run_tests, TestCase from torch_tensorrt.fx.passes.lower_basic_pass import fix_reshape_batch_dim from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer -from torch_tensorrt._util import sanitized_torch_version +from torch_tensorrt._utils import sanitized_torch_version _LOGGER = logging.getLogger(__name__) diff --git a/py/torch_tensorrt/fx/test/passes/test_pass_utils.py b/py/torch_tensorrt/fx/test/passes/test_pass_utils.py index 6f5edde004..6ea1043203 100644 --- a/py/torch_tensorrt/fx/test/passes/test_pass_utils.py +++ b/py/torch_tensorrt/fx/test/passes/test_pass_utils.py @@ -70,7 +70,6 @@ def model_transform_pass_bad(model, input): input = gen_input(bs=10) with diagnostics.collect_when(diagnostics.CollectionConditions.always()): - with override_alternative_batch_size_exception_should_throw(True): # This should succeed: the validate_inference decorator will # run both bs=10 and bs=1 successfully diff --git a/py/torch_tensorrt/fx/test/passes/test_remove_duplicate_output_args.py b/py/torch_tensorrt/fx/test/passes/test_remove_duplicate_output_args.py index 2dc1c404ee..2cc97c46be 100644 --- a/py/torch_tensorrt/fx/test/passes/test_remove_duplicate_output_args.py +++ b/py/torch_tensorrt/fx/test/passes/test_remove_duplicate_output_args.py @@ -8,7 +8,7 @@ import torch.nn as nn import torch_tensorrt.fx.passes.remove_duplicate_output_args as dedup -from torch_tensorrt._util import sanitized_torch_version +from torch_tensorrt._utils import sanitized_torch_version from torch.testing._internal.common_utils import run_tests, TestCase diff --git a/py/torch_tensorrt/fx/test/trt_lower/test_diagnostics.py b/py/torch_tensorrt/fx/test/trt_lower/test_diagnostics.py index 3ce3b7ade8..02a9f57a93 100644 --- a/py/torch_tensorrt/fx/test/trt_lower/test_diagnostics.py +++ b/py/torch_tensorrt/fx/test/trt_lower/test_diagnostics.py @@ -96,7 +96,6 @@ def test_write_without_collect(self): assert not res # root dir should be empty def test_conditions(self): - 
_test_cond( diag.CollectionConditions.when_called_by_function( self.test_conditions.__name__ diff --git a/py/torch_tensorrt/fx/test/trt_lower/test_observer_gpu.py b/py/torch_tensorrt/fx/test/trt_lower/test_observer_gpu.py index 9ed3b9df06..85d3ee278b 100644 --- a/py/torch_tensorrt/fx/test/trt_lower/test_observer_gpu.py +++ b/py/torch_tensorrt/fx/test/trt_lower/test_observer_gpu.py @@ -32,7 +32,6 @@ def forward(self, x, y): mod(*inp) with execution_verifier() as verify_execution: - lowerer = lower.Lowerer.create( lower_setting=LowerSetting(min_acc_module_size=0) ) diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py index 1ed25d66f1..deac84eeb0 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py @@ -2138,7 +2138,6 @@ def linalg_norm(*, input, ord, dim, keepdim): ], ) def norm_mapper(node: torch.fx.Node, _: torch.nn.Module) -> torch.fx.Node: - input_node = node.kwargs["input"] p = node.kwargs["p"] dim = node.kwargs["dim"] diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py index bc8c613fee..9d5576bd63 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py @@ -59,7 +59,6 @@ def __init__(self): def rewrite( self, fn: FunctionType ) -> Tuple[FunctionType, Set[Type[Exception]], Set[Type[Exception]]]: - # Normalize the source lines sourcelines, _ = inspect.getsourcelines(fn) sourcelines = normalize_source_lines(sourcelines) diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py index 248ec3f920..7a2d9a2fc9 100644 --- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py +++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py @@ -3,7 +3,7 @@ from contextlib import contextmanager from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union from packaging import version -from torch_tensorrt._util import sanitized_torch_version +from torch_tensorrt._utils import sanitized_torch_version import torch @@ -47,7 +47,6 @@ def __init__( specialize_int: bool = True, verbose: bool = True, ) -> None: - self.capture_scalar_outputs = capture_scalar_outputs self.guard_nn_modules = guard_nn_modules self.dynamic_shapes = dynamic_shapes diff --git a/py/torch_tensorrt/fx/utils.py b/py/torch_tensorrt/fx/utils.py index 859529d861..4202e1e96b 100644 --- a/py/torch_tensorrt/fx/utils.py +++ b/py/torch_tensorrt/fx/utils.py @@ -12,7 +12,7 @@ replace_op_with_indices, run_const_fold, ) -from torch_tensorrt._util import sanitized_torch_version +from torch_tensorrt._utils import sanitized_torch_version from .types import Shape, TRTDataType @@ -127,7 +127,6 @@ def get_dynamic_dims(shape: Shape) -> List[int]: def proxytensor_trace(mod, inputs): - mod.eval() def f(*inp): diff --git a/py/torch_tensorrt/logging.py b/py/torch_tensorrt/logging.py index 922992eb19..e48e3c6317 100644 --- a/py/torch_tensorrt/logging.py +++ b/py/torch_tensorrt/logging.py @@ -1,13 +1,15 @@ from enum import Enum +from typing import Any + from torch_tensorrt._C import ( + LogLevel, + _get_is_colored_output_on, _get_logging_prefix, - _set_logging_prefix, _get_reportable_log_level, - _set_reportable_log_level, - _get_is_colored_output_on, - _set_is_colored_output_on, _log, - LogLevel, + _set_is_colored_output_on, + _set_logging_prefix, + _set_reportable_log_level, ) @@ -22,19 
+24,21 @@ class Level(Enum): Graph = LogLevel.GRAPH @staticmethod - def _to_internal_level(external) -> LogLevel: + def _to_internal_level(external: "Level") -> LogLevel: if external == Level.InternalError: return LogLevel.INTERNAL_ERROR - if external == Level.Error: + elif external == Level.Error: return LogLevel.ERROR - if external == Level.Warning: + elif external == Level.Warning: return LogLevel.WARNING - if external == Level.Info: + elif external == Level.Info: return LogLevel.INFO - if external == Level.Debug: + elif external == Level.Debug: return LogLevel.DEBUG - if external == Level.Graph: + elif external == Level.Graph: return LogLevel.GRAPH + else: + raise ValueError("Unknown log severity") def get_logging_prefix() -> str:
@@ -43,10 +47,10 @@ def get_logging_prefix() -> str: Returns: str: Prefix used for logger """ - return _get_logging_prefix() + return str(_get_logging_prefix()) -def set_logging_prefix(prefix: str): +def set_logging_prefix(prefix: str) -> None: """Set the prefix used when logging messages Args:
@@ -64,7 +68,7 @@ def get_reportable_log_level() -> Level: return Level(_get_reportable_log_level()) -def set_reportable_log_level(level: Level): +def set_reportable_log_level(level: Level) -> None: """Set the level required for a message to be printed to the log Args:
@@ -79,10 +83,10 @@ def get_is_colored_output_on() -> bool: Returns: bool: If colored output is on """ - return _get_is_colored_output_on() + return bool(_get_is_colored_output_on()) -def set_is_colored_output_on(colored_output_on: bool): +def set_is_colored_output_on(colored_output_on: bool) -> None: """Enable or disable color in the log output Args:
@@ -91,7 +95,7 @@ def set_is_colored_output_on(colored_output_on: bool): _set_is_colored_output_on(colored_output_on) -def log(level: Level, msg: str): +def log(level: Level, msg: str) -> None: """Add a new message to the log Adds a new message to the log at a specified level.
The message @@ -120,11 +124,11 @@ class internal_errors: outputs = model_torchtrt(inputs) """ - def __enter__(self): + def __enter__(self) -> None: self.external_lvl = get_reportable_log_level() set_reportable_log_level(Level.InternalError) - def __exit__(self, exc_type, exc_value, exc_tb): + def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: set_reportable_log_level(self.external_lvl) @@ -137,11 +141,11 @@ class errors: outputs = model_torchtrt(inputs) """ - def __enter__(self): + def __enter__(self) -> None: self.external_lvl = get_reportable_log_level() set_reportable_log_level(Level.Error) - def __exit__(self, exc_type, exc_value, exc_tb): + def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: set_reportable_log_level(self.external_lvl) @@ -154,11 +158,11 @@ class warnings: model_trt = torch_tensorrt.compile(model, **spec) """ - def __enter__(self): + def __enter__(self) -> None: self.external_lvl = get_reportable_log_level() set_reportable_log_level(Level.Warning) - def __exit__(self, exc_type, exc_value, exc_tb): + def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: set_reportable_log_level(self.external_lvl) @@ -171,11 +175,11 @@ class info: model_trt = torch_tensorrt.compile(model, **spec) """ - def __enter__(self): + def __enter__(self) -> None: self.external_lvl = get_reportable_log_level() set_reportable_log_level(Level.Info) - def __exit__(self, exc_type, exc_value, exc_tb): + def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: set_reportable_log_level(self.external_lvl) @@ -188,11 +192,11 @@ class debug: model_trt = torch_tensorrt.compile(model, **spec) """ - def __enter__(self): + def __enter__(self) -> None: self.external_lvl = get_reportable_log_level() set_reportable_log_level(Level.Debug) - def __exit__(self, exc_type, exc_value, exc_tb): + def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: set_reportable_log_level(self.external_lvl) @@ -206,9 +210,9 @@ class graphs: model_trt = torch_tensorrt.compile(model, **spec) """ - def __enter__(self): + def __enter__(self) -> None: self.external_lvl = get_reportable_log_level() set_reportable_log_level(Level.Graph) - def __exit__(self, exc_type, exc_value, exc_tb): + def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: set_reportable_log_level(self.external_lvl) diff --git a/py/torch_tensorrt/ptq.py b/py/torch_tensorrt/ptq.py index f60dd74b52..5d13ab9108 100644 --- a/py/torch_tensorrt/ptq.py +++ b/py/torch_tensorrt/ptq.py @@ -1,12 +1,17 @@ -from typing import List, Dict, Any -import torch +import sys +from typing import Any, List, Optional + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + import os +from enum import Enum +import torch from torch_tensorrt import _C -from torch_tensorrt._version import __version__ -from torch_tensorrt.logging import * -from types import FunctionType -from enum import Enum +from torch_tensorrt.logging import Level, log class CalibrationAlgo(Enum): @@ -16,15 +21,15 @@ class CalibrationAlgo(Enum): MINMAX_CALIBRATION = _C.CalibrationAlgo.MINMAX_CALIBRATION -def get_cache_mode_batch(self): +def get_cache_mode_batch(self: object) -> None: return None -def get_batch_size(self): +def get_batch_size(self: object) -> int: return 1 -def get_batch(self, names): +def get_batch(self: object, _: Any) -> Optional[List[int]]: if self.current_batch_idx + self.batch_size > len(self.data_loader.dataset): return None @@ -39,27 +44,30 @@ def 
get_batch(self, names): return inputs_gpu -def read_calibration_cache(self): +def read_calibration_cache(self: object) -> bytes: if self.cache_file and self.use_cache: if os.path.exists(self.cache_file): with open(self.cache_file, "rb") as f: - return f.read() + b: bytes = f.read() + return b + else: + raise FileNotFoundError(self.cache_file) else: return b"" -def write_calibration_cache(self, cache): +def write_calibration_cache(self: object, cache: bytes) -> None: if self.cache_file: with open(self.cache_file, "wb") as f: f.write(cache) else: - return b"" + return # deepcopy (which involves pickling) is performed on the compile_spec internally during compilation. # We register this __reduce__ function for pickler to identify the calibrator object returned by DataLoaderCalibrator during deepcopy. # This should be the object's local name relative to the module https://docs.python.org/3/library/pickle.html#object.__reduce__ -def __reduce__(self): +def __reduce__(self: object) -> str: return self.__class__.__name__
@@ -75,10 +83,10 @@ class DataLoaderCalibrator(object): device: device on which calibration data is copied to. """ - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): pass - def __new__(cls, *args, **kwargs): + def __new__(cls, *args: Any, **kwargs: Any) -> Self: dataloader = args[0] algo_type = kwargs.get("algo_type", CalibrationAlgo.ENTROPY_CALIBRATION_2) cache_file = kwargs.get("cache_file", None)
@@ -126,27 +134,30 @@ def __new__(cls, *args, **kwargs): # Using type metaclass to construct calibrator class based on algorithm type if algo_type == CalibrationAlgo.ENTROPY_CALIBRATION: - return type( + calib_ec: Self = type( "Int8EntropyCalibrator", (_C.IInt8EntropyCalibrator,), attribute_mapping )() + return calib_ec elif algo_type == CalibrationAlgo.ENTROPY_CALIBRATION_2: - return type( + calib_ec2: Self = type( "Int8EntropyCalibrator2", (_C.IInt8EntropyCalibrator2,), attribute_mapping, )() + return calib_ec2 elif algo_type == CalibrationAlgo.LEGACY_CALIBRATION: - return type( + calib_lc: Self = type( "Int8LegacyCalibrator", (_C.IInt8LegacyCalibrator,), attribute_mapping )() + return calib_lc elif algo_type == CalibrationAlgo.MINMAX_CALIBRATION: - return type( + calib_mmc: Self = type( "Int8MinMaxCalibrator", (_C.IInt8MinMaxCalibrator,), attribute_mapping )() + return calib_mmc else: - log( - Level.Error, - "Invalid calibration algorithm type. Please select among ENTROPY_CALIBRATION, ENTROPY_CALIBRATION, LEGACY_CALIBRATION or MINMAX_CALIBRATION", + raise ValueError( + "Invalid calibration algorithm type. Please select among ENTROPY_CALIBRATION, ENTROPY_CALIBRATION_2, LEGACY_CALIBRATION or MINMAX_CALIBRATION" )
@@ -158,10 +169,10 @@ class CacheCalibrator(object): algo_type: choice of calibration algorithm.
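Example (illustrative; the cache path is hypothetical):

    calibrator = CacheCalibrator("/tmp/calibration.cache", algo_type=CalibrationAlgo.ENTROPY_CALIBRATION_2)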
""" - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): pass - def __new__(cls, *args, **kwargs): + def __new__(cls, *args: Any, **kwargs: Any) -> Self: cache_file = args[0] algo_type = kwargs.get("algo_type", CalibrationAlgo.ENTROPY_CALIBRATION_2) @@ -184,23 +195,26 @@ def __new__(cls, *args, **kwargs): } # Using type metaclass to construct calibrator class based on algorithm type if algo_type == CalibrationAlgo.ENTROPY_CALIBRATION: - return type( + calib_ec: Self = type( "DataLoaderCalibrator", (_C.IInt8EntropyCalibrator,), attribute_mapping )() + return calib_ec elif algo_type == CalibrationAlgo.ENTROPY_CALIBRATION_2: - return type( + calib_ec2: Self = type( "DataLoaderCalibrator", (_C.IInt8MinMaxCalibrator,), attribute_mapping )() + return calib_ec2 elif algo_type == CalibrationAlgo.LEGACY_CALIBRATION: - return type( + calib_lc: Self = type( "DataLoaderCalibrator", (_C.IInt8LegacyCalibrator,), attribute_mapping )() + return calib_lc elif algo_type == CalibrationAlgo.MINMAX_CALIBRATION: - return type( + calib_mmc: Self = type( "DataLoaderCalibrator", (_C.IInt8MinMaxCalibrator,), attribute_mapping )() + return calib_mmc else: - log( - Level.Error, - "Invalid calibration algorithm type. Please select among ENTROPY_CALIBRATION, ENTROPY_CALIBRATION, LEGACY_CALIBRATION or MINMAX_CALIBRATION", + raise ValueError( + "Invalid calibration algorithm type. Please select among ENTROPY_CALIBRATION, ENTROPY_CALIBRATION, LEGACY_CALIBRATION or MINMAX_CALIBRATION" ) diff --git a/py/torch_tensorrt/ts/ts_input.py b/py/torch_tensorrt/ts/_Input.py similarity index 70% rename from py/torch_tensorrt/ts/ts_input.py rename to py/torch_tensorrt/ts/_Input.py index 00055d4f13..f9cbf2c333 100644 --- a/py/torch_tensorrt/ts/ts_input.py +++ b/py/torch_tensorrt/ts/_Input.py @@ -1,15 +1,10 @@ -from enum import Enum -from typing import List, Dict, Any, Tuple, Optional +from typing import Any -import torch - -from torch_tensorrt import _C -from torch_tensorrt import _enums -from torch_tensorrt import _Input +from torch_tensorrt import _C, _enums from torch_tensorrt._Input import Input -class TSInput(Input): +class TorchScriptInput(Input): """ Defines an input to a module in terms of expected shape, data type and tensor format. 
@@ -26,7 +21,7 @@ class TSInput(Input): format (torch_tensorrt.TensorFormat): The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW) """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: """__init__ Method for torch_tensorrt.Input Input accepts one of a few construction patterns @@ -52,38 +47,39 @@ def __init__(self, *args, **kwargs): - Input(shape=(1,3,32,32), dtype=torch_tensorrt.dtype.int32, format=torch_tensorrt.TensorFormat.NCHW) - Input(min_shape=(1,3,32,32), opt_shape=[2,3,32,32], max_shape=(3,3,32,32)) #Implicitly dtype=torch_tensorrt.dtype.float32, format=torch_tensorrt.TensorFormat.NCHW """ - super(TSInput, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) def _to_internal(self) -> _C.Input: internal_in = _C.Input() if self.shape_mode == Input._ShapeMode.DYNAMIC: - if not Input._supported_input_size_type(self.shape["min_shape"]): - raise TypeError( - "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " - + str(type(self.shape["min_shape"])) - + " for min_shape" - ) - else: - internal_in.min = self.shape["min_shape"] - - if not Input._supported_input_size_type(self.shape["opt_shape"]): - raise TypeError( - "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " - + str(type(self.shape["opt_shape"])) - + " for opt_shape" - ) - else: - internal_in.opt = self.shape["opt_shape"] - - if not Input._supported_input_size_type(self.shape["max_shape"]): - raise TypeError( - "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " - + str(type(self.shape["max_shape"])) - + " for max_shape" - ) - else: - internal_in.max = self.shape["max_shape"] - internal_in.input_is_dynamic = True + if isinstance(self.shape, dict): + if not Input._supported_input_size_type(self.shape["min_shape"]): + raise TypeError( + "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " + + str(type(self.shape["min_shape"])) + + " for min_shape" + ) + else: + internal_in.min = self.shape["min_shape"] + + if not Input._supported_input_size_type(self.shape["opt_shape"]): + raise TypeError( + "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " + + str(type(self.shape["opt_shape"])) + + " for opt_shape" + ) + else: + internal_in.opt = self.shape["opt_shape"] + + if not Input._supported_input_size_type(self.shape["max_shape"]): + raise TypeError( + "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " + + str(type(self.shape["max_shape"])) + + " for max_shape" + ) + else: + internal_in.max = self.shape["max_shape"] + internal_in.input_is_dynamic = True else: if not Input._supported_input_size_type(self.shape): raise TypeError( diff --git a/py/torch_tensorrt/ts/__init__.py b/py/torch_tensorrt/ts/__init__.py index 47ef249e55..5cb45cba5c 100644 --- a/py/torch_tensorrt/ts/__init__.py +++ b/py/torch_tensorrt/ts/__init__.py @@ -1,3 +1,3 @@ -from torch_tensorrt.ts._compiler import * -from torch_tensorrt.ts._compile_spec import TensorRTCompileSpec -from torch_tensorrt.ts.ts_input import TSInput +from torch_tensorrt.ts._compile_spec import TensorRTCompileSpec # noqa: F401 +from torch_tensorrt.ts._compiler import * # noqa: F403 +from torch_tensorrt.ts._Input import TorchScriptInput # noqa: F401 diff --git a/py/torch_tensorrt/ts/_compile_spec.py 
b/py/torch_tensorrt/ts/_compile_spec.py index 08f18a22dd..6803259985 100644 --- a/py/torch_tensorrt/ts/_compile_spec.py +++ b/py/torch_tensorrt/ts/_compile_spec.py @@ -1,15 +1,14 @@ -from typing import List, Dict, Any, Set +from copy import deepcopy +from typing import Any, Dict, List, Optional, Set + import torch -from torch_tensorrt import _C import torch_tensorrt._C.ts as _ts_C -from torch_tensorrt import _enums -from torch_tensorrt._Input import Input +from torch_tensorrt import _C, _enums from torch_tensorrt._Device import Device +from torch_tensorrt._Input import Input from torch_tensorrt.logging import Level, log -from typing import Tuple, List, Dict -import warnings -from copy import deepcopy -from torch_tensorrt.ts.ts_input import TSInput +from torch_tensorrt.ts._Input import TorchScriptInput + import tensorrt as trt @@ -40,7 +39,7 @@ def _supported_input_size_type(input_size: Any) -> bool: ) -def _parse_op_precision(precision: Any) -> _enums.dtype: +def _parse_op_precision(precision: Any) -> _enums.dtype: # type: ignore[name-defined] if isinstance(precision, torch.dtype): if precision == torch.int8: return _enums.dtype.int8 @@ -64,9 +63,9 @@ def _parse_op_precision(precision: Any) -> _enums.dtype: ) -def _parse_enabled_precisions(precisions: Any) -> Set: +def _parse_enabled_precisions(precisions: Any) -> Set[_enums.dtype]: # type: ignore[name-defined] parsed_precisions = set() - if any([isinstance(precisions, type) for type in [list, tuple, set]]): + if any(isinstance(precisions, type) for type in [list, tuple, set]): for p in precisions: parsed_precisions.add(_parse_op_precision(p)) else: @@ -74,7 +73,7 @@ def _parse_enabled_precisions(precisions: Any) -> Set: return parsed_precisions -def _parse_device_type(device: Any) -> _enums.DeviceType: +def _parse_device_type(device: Any) -> _enums.DeviceType: # type: ignore[name-defined] if isinstance(device, torch.device): if device.type == "cuda": return _C.DeviceType.gpu @@ -159,7 +158,7 @@ def _parse_torch_fallback(fallback_info: Dict[str, Any]) -> _ts_C.TorchFallback: return info -def _parse_input_signature(input_signature: Any, depth: int = 0): +def _parse_input_signature(input_signature: Any, depth: int = 0) -> Any: if depth > 2: raise AssertionError( "Input nesting depth exceeds max supported depth, use 1 level: [A, B], or 2 level: [A, (B, C)]" @@ -177,9 +176,7 @@ def _parse_input_signature(input_signature: Any, depth: int = 0): input = _parse_input_signature(item, depth + 1) input_list.append(input) return input_list - elif isinstance(input_signature, Input) or isinstance( - input_signature, torch.Tensor - ): + elif isinstance(input_signature, (Input, torch.Tensor)): i = ( Input.from_tensor(input_signature) if isinstance(input_signature, torch.Tensor) @@ -195,15 +192,20 @@ def _parse_input_signature(input_signature: Any, depth: int = 0): ts_i = i if i.shape_mode == Input._ShapeMode.STATIC: - ts_i = TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + ts_i = TorchScriptInput(shape=i.shape, dtype=i.dtype, format=i.format) elif i.shape_mode == Input._ShapeMode.DYNAMIC: - ts_i = TSInput( - min_shape=i.shape["min_shape"], - opt_shape=i.shape["opt_shape"], - max_shape=i.shape["max_shape"], - dtype=i.dtype, - format=i.format, - ) + if isinstance(i.shape, dict): + ts_i = TorchScriptInput( + min_shape=i.shape["min_shape"], + opt_shape=i.shape["opt_shape"], + max_shape=i.shape["max_shape"], + dtype=i.dtype, + format=i.format, + ) + else: + raise ValueError( + f"Input set as dynamic, expected dictionary of shapes but found {i.shape}" + 
) else: raise ValueError( "Invalid shape mode detected for input while parsing the input_signature" @@ -226,10 +228,7 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec: if len(compile_spec["inputs"]) > 0: if not all( - [ - isinstance(i, torch.Tensor) or isinstance(i, Input) - for i in compile_spec["inputs"] - ] + isinstance(i, (torch.Tensor, Input)) for i in compile_spec["inputs"] ): raise KeyError( "Input specs should be either torch_tensorrt.Input or torch.Tensor, found types: {}".format( @@ -245,13 +244,13 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec: for i in inputs: if i.shape_mode == Input._ShapeMode.STATIC: ts_inputs.append( - TSInput( + TorchScriptInput( shape=i.shape, dtype=i.dtype, format=i.format )._to_internal() ) elif i.shape_mode == Input._ShapeMode.DYNAMIC: ts_inputs.append( - TSInput( + TorchScriptInput( min_shape=i.shape["min_shape"], opt_shape=i.shape["opt_shape"], max_shape=i.shape["max_shape"], @@ -342,23 +341,23 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec: def TensorRTCompileSpec( - inputs=[], - input_signature=None, - device=Device._current_device(), - disable_tf32=False, - sparse_weights=False, - enabled_precisions=set(), - refit=False, - debug=False, - capability=_enums.EngineCapability.default, - num_avg_timing_iters=1, - workspace_size=0, - dla_sram_size=1048576, - dla_local_dram_size=1073741824, - dla_global_dram_size=536870912, - truncate_long_and_double=False, - calibrator=None, - allow_shape_tensors=False, + inputs: Optional[List[torch.Tensor | Input]] = None, + input_signature: Optional[Any] = None, + device: torch.device | Device = Device._current_device(), + disable_tf32: bool = False, + sparse_weights: bool = False, + enabled_precisions: Optional[Set[torch.dtype | _enums.dtype]] = None, # type: ignore[name-defined] + refit: bool = False, + debug: bool = False, + capability: _enums.EngineCapability = _enums.EngineCapability.default, # type: ignore[name-defined] + num_avg_timing_iters: int = 1, + workspace_size: int = 0, + dla_sram_size: int = 1048576, + dla_local_dram_size: int = 1073741824, + dla_global_dram_size: int = 536870912, + truncate_long_and_double: bool = False, + calibrator: object = None, + allow_shape_tensors: bool = False, ) -> torch.classes.tensorrt.CompileSpec: """Utility to create a formatted spec dictionary for using the PyTorch TensorRT backend @@ -400,12 +399,14 @@ def TensorRTCompileSpec( """ compile_spec = { - "inputs": inputs, + "inputs": inputs if inputs is not None else [], # "input_signature": input_signature, "device": device, "disable_tf32": disable_tf32, # Force FP32 layers to use traditional FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas "sparse_weights": sparse_weights, # Enable sparsity for convolution and fully connected layers.
- "enabled_precisions": enabled_precisions, # Enabling FP16 kernels + "enabled_precisions": enabled_precisions + if enabled_precisions is not None + else set(), # Enabling FP16 kernels "refit": refit, # enable refit "debug": debug, # enable debuggable engine "capability": capability, # Restrict kernel selection to safe gpu kernels or safe dla kernels diff --git a/py/torch_tensorrt/ts/_compiler.py b/py/torch_tensorrt/ts/_compiler.py index 9dc0731014..30828ce5d8 100644 --- a/py/torch_tensorrt/ts/_compiler.py +++ b/py/torch_tensorrt/ts/_compiler.py @@ -1,37 +1,36 @@ -from typing import List, Dict, Any -import torch -from torch import nn +from typing import Any, List, Optional, Sequence, Set, Tuple +import torch import torch_tensorrt._C.ts as _C from torch_tensorrt import _enums -from torch_tensorrt.ts._compile_spec import _parse_compile_spec, _parse_device from torch_tensorrt._Device import Device -from types import FunctionType +from torch_tensorrt._Input import Input +from torch_tensorrt.ts._compile_spec import _parse_compile_spec, _parse_device def compile( module: torch.jit.ScriptModule, - inputs=[], - input_signature=None, - device=Device._current_device(), - disable_tf32=False, - sparse_weights=False, - enabled_precisions=set(), - refit=False, - debug=False, - capability=_enums.EngineCapability.default, - num_avg_timing_iters=1, - workspace_size=0, - dla_sram_size=1048576, - dla_local_dram_size=1073741824, - dla_global_dram_size=536870912, - calibrator=None, - truncate_long_and_double=False, - require_full_compilation=False, - min_block_size=3, - torch_executed_ops=[], - torch_executed_modules=[], - allow_shape_tensors=False, + inputs: Optional[Sequence[Input | torch.Tensor]] = None, + input_signature: Optional[Tuple[Input | torch.Tensor | Sequence[Any]]] = None, + device: Device = Device._current_device(), + disable_tf32: bool = False, + sparse_weights: bool = False, + enabled_precisions: Optional[Set[torch.dtype | _enums.dtype]] = None, + refit: bool = False, + debug: bool = False, + capability: _enums.EngineCapability = _enums.EngineCapability.default, + num_avg_timing_iters: int = 1, + workspace_size: int = 0, + dla_sram_size: int = 1048576, + dla_local_dram_size: int = 1073741824, + dla_global_dram_size: int = 536870912, + calibrator: object = None, + truncate_long_and_double: bool = False, + require_full_compilation: bool = False, + min_block_size: int = 3, + torch_executed_ops: Optional[List[str]] = None, + torch_executed_modules: Optional[List[str]] = None, + allow_shape_tensors: bool = False, ) -> torch.jit.ScriptModule: """Compile a TorchScript module for NVIDIA GPUs using TensorRT @@ -101,25 +100,36 @@ def compile( torch.jit.ScriptModule: Compiled TorchScript Module, when run it will execute via TensorRT """ + input_list = list(inputs) if inputs is not None else [] + enabled_precisions_set = ( + enabled_precisions if enabled_precisions is not None else set() + ) + torch_executed_module_list = ( + torch_executed_modules if torch_executed_modules is not None else [] + ) + torch_executed_op_list = ( + torch_executed_ops if torch_executed_ops is not None else [] + ) + if isinstance(module, torch.jit.ScriptFunction): raise TypeError( "torch.jit.ScriptFunction currently is not directly supported, wrap the function in a module to compile" ) if require_full_compilation and ( - len(torch_executed_modules) > 0 or len(torch_executed_ops) > 0 + len(torch_executed_module_list) > 0 or len(torch_executed_op_list) > 0 ): raise ValueError( f"require_full_compilation is enabled however the 
list of modules and ops to run in torch is not empty. Found: torch_executed_ops: {torch_executed_ops}, torch_executed_modules: {torch_executed_modules}" ) spec = { - "inputs": inputs, + "inputs": input_list, "input_signature": input_signature, "device": device, "disable_tf32": disable_tf32, # Force FP32 layers to use traditional FP32 format "sparse_weights": sparse_weights, # Enable sparsity for convolution and fully connected layers. - "enabled_precisions": enabled_precisions, # Enabling FP16 kernels + "enabled_precisions": enabled_precisions_set, # Enabling FP16 kernels "refit": refit, # enable refit "debug": debug, # enable debuggable engine "capability": capability, # Restrict kernel selection to safe gpu kernels or safe dla kernels @@ -129,38 +139,40 @@ def compile( "truncate_long_and_double": truncate_long_and_double, "torch_fallback": { "enabled": not require_full_compilation, - "forced_fallback_ops": torch_executed_ops, - "forced_fallback_modules": torch_executed_modules, + "forced_fallback_ops": torch_executed_op_list, + "forced_fallback_modules": torch_executed_module_list, "min_block_size": min_block_size, }, "allow_shape_tensors": allow_shape_tensors, } compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec)) - compiled_module = torch.jit._recursive.wrap_cpp_module(compiled_cpp_mod) + compiled_module: torch.jit.ScriptModule = torch.jit._recursive.wrap_cpp_module( + compiled_cpp_mod + ) return compiled_module def convert_method_to_trt_engine( module: torch.jit.ScriptModule, method_name: str = "forward", - inputs=[], - device=Device._current_device(), - disable_tf32=False, - sparse_weights=False, - enabled_precisions=set(), - refit=False, - debug=False, - capability=_enums.EngineCapability.default, - num_avg_timing_iters=1, - workspace_size=0, - dla_sram_size=1048576, - dla_local_dram_size=1073741824, - dla_global_dram_size=536870912, - truncate_long_and_double=False, - calibrator=None, - allow_shape_tensors=False, -) -> bytearray: + inputs: Optional[Sequence[Input | torch.Tensor]] = None, + device: Device = Device._current_device(), + disable_tf32: bool = False, + sparse_weights: bool = False, + enabled_precisions: Optional[Set[torch.dtype | _enums.dtype]] = None, + refit: bool = False, + debug: bool = False, + capability: _enums.EngineCapability = _enums.EngineCapability.default, + num_avg_timing_iters: int = 1, + workspace_size: int = 0, + dla_sram_size: int = 1048576, + dla_local_dram_size: int = 1073741824, + dla_global_dram_size: int = 536870912, + truncate_long_and_double: bool = False, + calibrator: object = None, + allow_shape_tensors: bool = False, +) -> bytes: """Convert a TorchScript module method to a serialized TensorRT engine Converts a specified method of a module to a serialized TensorRT engine given a dictionary of conversion settings Arguments: module (torch.jit.ScriptModule): Source module, a result of tracing or scripting a PyTorch ``torch.nn.Module`` - method_name (str): Name of method to convert Keyword Args: inputs (List[Union(torch_tensorrt.Input, torch.Tensor)]): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists.
dtypes can be specified using @@ -187,6 +198,7 @@ def convert_method_to_trt_engine( torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings ] + method_name (str): Name of method to convert input_signature Union(List, Tuple, torch_tensorrt.Input, torch.Tensor): A formatted collection of input specifications for the module. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum to select device type. **This API should be considered beta-level stable and may change in the future** :: @@ -221,19 +233,24 @@ def convert_method_to_trt_engine( allow_shape_tensors: (Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT Returns: - bytearray: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs + bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs """ + input_list = list(inputs) if inputs is not None else [] + enabled_precisions_set = ( + enabled_precisions if enabled_precisions is not None else {torch.float} + ) + if isinstance(module, torch.jit.ScriptFunction): raise TypeError( "torch.jit.ScriptFunctions currently are not directly supported, wrap the function in a module to compile" ) compile_spec = { - "inputs": inputs, + "inputs": input_list, "device": device, "disable_tf32": disable_tf32, # Force FP32 layers to use traditional FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas "sparse_weights": sparse_weights, # Enable sparsity for convolution and fully connected layers. - "enabled_precisions": enabled_precisions, # Enabling FP16 kernels + "enabled_precisions": enabled_precisions_set, # Enabling FP16 kernels "refit": refit, # enable refit "debug": debug, # enable debuggable engine "capability": capability, # Restrict kernel selection to safe gpu kernels or safe dla kernels @@ -259,9 +276,9 @@ def convert_method_to_trt_engine( def embed_engine_in_new_module( serialized_engine: bytes, + input_binding_names: Optional[List[str]] = None, + output_binding_names: Optional[List[str]] = None, device: Device = Device._current_device(), - input_binding_names: List[str] = [], - output_binding_names: List[str] = [], ) -> torch.jit.ScriptModule: """Takes a pre-built serialized TensorRT engine and embeds it within a TorchScript module @@ -278,22 +295,29 @@ def embed_engine_in_new_module( Module can be saved with engine embedded with torch.jit.save and moved / loaded according to torch_tensorrt portability rules Arguments: - serialized_engine (bytes): Serialized TensorRT engine from either torch_tensorrt or TensorRT APIs + serialized_engine (bytearray): Serialized TensorRT engine from either torch_tensorrt or TensorRT APIs Keyword Arguments: - device (Union(torch_tensorrt.Device, torch.device, dict)): Target device to run engine on. Must be compatible with engine provided. Default: Current active device input_binding_names (List[str]): List of names of TensorRT bindings in order to be passed to the encompassing PyTorch module output_binding_names (List[str]): List of names of TensorRT bindings in order that should be returned from the encompassing PyTorch module + device (Union(torch_tensorrt.Device, torch.device, dict)): Target device to run engine on. Must be compatible with engine provided.
Default: Current active device Returns: torch.jit.ScriptModule: New TorchScript module with engine embedded """ + input_binding_name_list = ( + input_binding_names if input_binding_names is not None else [] + ) + output_binding_name_list = ( + output_binding_names if output_binding_names is not None else [] + ) cpp_mod = _C.embed_engine_in_new_module( serialized_engine, _parse_device(device), - input_binding_names, - output_binding_names, + input_binding_name_list, + output_binding_name_list, ) - return torch.jit._recursive.wrap_cpp_module(cpp_mod) + wrapped_mod: torch.jit.ScriptModule = torch.jit._recursive.wrap_cpp_module(cpp_mod) + return wrapped_mod def check_method_op_support( @@ -312,4 +336,5 @@ def check_method_op_support( Returns: bool: True if supported Method """ - return _C.check_method_op_support(module._c, method_name) + supported: bool = _C.check_method_op_support(module._c, method_name) + return supported diff --git a/pyproject.toml b/pyproject.toml index ae130dc08a..3f99213db9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,20 +60,161 @@ Documentation = "https://pytorch.org/tensorrt" Repository = "https://github.com/pytorch/tensorrt.git" Changelog = "https://github.com/pytorch/tensorrt/releases" +[tool.setuptools] +package-dir = {"" = "py"} +include-package-data = false + +[tool.ruff] +# NOTE: Synchronize the ignores with .flake8 +ignore = [ + # these ignores are from flake8-bugbear; please fix! + "B007", "B008", "B017", + "B018", # Useless expression + "B019", "B020", + "B023", "B024", "B026", + "B028", # No explicit `stacklevel` keyword argument found + "B904", "B905", + "E402", + "C408", # C408 ignored because we like the dict keyword argument syntax + "E501", # E501 is not flexible enough, we're using B950 instead + "E721", + "E731", # Assign lambda expression + "E741", + "EXE001", + "F405", + "F821", + "F841", + # these ignores are from flake8-logging-format; please fix! + "G101", "G201", "G202", "G003", "G004", + # these ignores are from RUFF perf; please fix! + "PERF203", "PERF4", + "SIM102", "SIM103", "SIM112", # flake8-simplify code styles + "SIM105", # these ignores are from flake8-simplify. please fix or ignore with commented reason + "SIM108", + "SIM110", + "SIM114", # Combine `if` branches using logical `or` operator + "SIM115", + "SIM116", # Disable Use a dictionary instead of consecutive `if` statements + "SIM117", + "SIM118", +] +#line-length = 120 +select = [ + "B", + "C4", + "G", + "E", + "F", + "SIM1", + "W", + # Not included in flake8 + "PERF", + "PLE", + "TRY302", +] + +# Allow unused variables when underscore-prefixed. +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" +target-version = "py311" + +# Allow autofix for all enabled rules (when `--fix` is provided). +fixable = [ + "A","B","C","D","E","F","G", + "I","N","Q","S","T","W", + "ANN", "ARG", "BLE", "COM", "DJ", + "DTZ", "EM", "ERA", "EXE", "FBT", + "ICN", "INP", "ISC", "NPY", "PD", + "PGH", "PIE", "PL", "PT", "PTH", + "PYI", "RET", "RSE", "RUF", "SIM", + "SLF", "TCH", "TID", "TRY", "UP", "YTT"] +unfixable = [] + +# Exclude a variety of commonly ignored directories.
+exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", + "env", + "py/torch_tensorrt/fx", + ".github", + "examples", + "tests", + "tools", + "docs", + "docsrc", + "setup.py", + "noxfile.py", + "__init__.py" +] + +[tool.ruff.mccabe] +# Unlike Flake8, default to a complexity level of 10. +max-complexity = 10 + +[tool.isort] +profile = "black" +py_version = 311 +skip = [ + "py/torch_tensorrt/fx", +] [tool.black] -# Uncomment if pyproject.toml worked fine to ensure consistency with flake8 -# line-length = 120 +#line-length = 120 target-versions = ["py38", "py39", "py310", "py311", "py312"] force-exclude = """ elu_converter/setup.py """ [tool.mypy] +strict = true +ignore_missing_imports = true show_error_codes = true disable_error_code = "attr-defined" no_implicit_optional = true +exclude = [ + "^py/torch_tensorrt/fx", + "py/torch_tensorrt/fx", + "torch_tensorrt/fx", + "py/torch_tensorrt/_C.so", + "examples", + "docs", + "docsrc", + "tests", + "setup.py", + "noxfile.py" +] +python_version = "3.11" -[tool.setuptools] -package-dir = {"" = "py"} -include-package-data = false \ No newline at end of file +follow_imports = "skip" + +[[tool.mypy.overrides]] +module = "torch_tensorrt.dynamo.conversion.aten_converters" +disable_error_code = "arg-type" + +[[tool.mypy.overrides]] +module = "torch_tensorrt.dynamo.lowering._decompositions" +disallow_untyped_calls = false + +[[tool.mypy.overrides]] +module = "torch_tensorrt.fx.*" +ignore_errors = true +follow_imports = "skip" \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index f6e126f207..f6160b0130 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,3 +1,11 @@ -pre-commit==2.20.0 -black==22.6.0 +pre-commit>=2.20.0 +black>=22.6.0 clang-format==14.0.6 +mypy +isort +ruff +pytest +transformers +timm +parameterized +expecttest diff --git a/setup.py b/setup.py index 7544ea8eb5..f9329cce7e 100644 --- a/setup.py +++ b/setup.py @@ -1,23 +1,22 @@ +import glob import os +import platform +import subprocess import sys -import glob -import yaml +import warnings from dataclasses import dataclass +from distutils.cmd import Command +from shutil import copyfile, rmtree + import setuptools -from setuptools import setup, Extension, find_namespace_packages +import yaml +from setuptools import Extension, find_namespace_packages, setup from setuptools.command.build_ext import build_ext from setuptools.command.develop import develop -from setuptools.command.install import install from setuptools.command.editable_wheel import editable_wheel -from distutils.cmd import Command -from wheel.bdist_wheel import bdist_wheel - +from setuptools.command.install import install from torch.utils import cpp_extension -from shutil import copyfile, rmtree - -import subprocess -import platform -import warnings +from wheel.bdist_wheel import bdist_wheel dir_path = os.path.dirname(os.path.realpath(__file__)) + "/py" diff --git a/tests/core/BUILD b/tests/core/BUILD index 820974397e..19dedd4fe5 100644 --- a/tests/core/BUILD +++ b/tests/core/BUILD @@ -23,7 +23,7 @@ cc_test( "@googletest//:gtest_main", ] + select({ ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], - "//conditions:default": ["@libtorch//:libtorch"], + "//conditions:default": ["@libtorch"], }), ) diff --git a/tests/core/lowering/BUILD
b/tests/core/lowering/BUILD index 30f1fd8e5a..658d66fc46 100644 --- a/tests/core/lowering/BUILD +++ b/tests/core/lowering/BUILD @@ -23,7 +23,7 @@ cc_test( "@googletest//:gtest_main", ] + select({ ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], - "//conditions:default": ["@libtorch//:libtorch"], + "//conditions:default": ["@libtorch"], }), ) diff --git a/tests/core/partitioning/BUILD b/tests/core/partitioning/BUILD index 5aba817bd6..974a9b0e4d 100644 --- a/tests/core/partitioning/BUILD +++ b/tests/core/partitioning/BUILD @@ -55,7 +55,7 @@ cc_test( "@googletest//:gtest_main", ] + select({ ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], - "//conditions:default": ["@libtorch//:libtorch"], + "//conditions:default": ["@libtorch"], }), ) @@ -70,7 +70,7 @@ cc_test( "@googletest//:gtest_main", ] + select({ ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], - "//conditions:default": ["@libtorch//:libtorch"], + "//conditions:default": ["@libtorch"], }), ) @@ -85,7 +85,7 @@ cc_test( "@googletest//:gtest_main", ] + select({ ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], - "//conditions:default": ["@libtorch//:libtorch"], + "//conditions:default": ["@libtorch"], }), ) @@ -100,7 +100,7 @@ cc_test( "@googletest//:gtest_main", ] + select({ ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], - "//conditions:default": ["@libtorch//:libtorch"], + "//conditions:default": ["@libtorch"], }), ) diff --git a/tests/cpp/BUILD b/tests/cpp/BUILD index 709187e1b2..783bf2895e 100644 --- a/tests/cpp/BUILD +++ b/tests/cpp/BUILD @@ -87,7 +87,7 @@ cc_test( "@googletest//:gtest_main", ] + select({ ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], - "//conditions:default": ["@libtorch//:libtorch"], + "//conditions:default": ["@libtorch"], }), ) @@ -125,7 +125,7 @@ cc_test( "@googletest//:gtest_main", ] + select({ ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], - "//conditions:default": ["@libtorch//:libtorch"], + "//conditions:default": ["@libtorch"], }), ) @@ -140,7 +140,7 @@ cc_test( "@googletest//:gtest_main", ] + select({ ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], - "//conditions:default": ["@libtorch//:libtorch"], + "//conditions:default": ["@libtorch"], }), ) @@ -152,7 +152,7 @@ cc_test( "@googletest//:gtest_main", ] + select({ ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], - "//conditions:default": ["@libtorch//:libtorch"], + "//conditions:default": ["@libtorch"], }), ) @@ -167,7 +167,7 @@ cc_test( "@googletest//:gtest_main", ] + select({ ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], - "//conditions:default": ["@libtorch//:libtorch"], + "//conditions:default": ["@libtorch"], }), ) @@ -202,6 +202,6 @@ cc_library( "@googletest//:gtest_main", ] + select({ ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], - "//conditions:default": ["@libtorch//:libtorch"], + "//conditions:default": ["@libtorch"], }), ) diff --git a/tests/py/dynamo/backend/test_backend_compiler.py b/tests/py/dynamo/backend/test_backend_compiler.py index fc9ba5a232..9c0cf18d25 100644 --- a/tests/py/dynamo/backend/test_backend_compiler.py +++ b/tests/py/dynamo/backend/test_backend_compiler.py @@ -74,7 +74,11 @@ def forward(self, x, y): torch.randint(1, 40, (16, 7, 5), dtype=torch.int).cuda(), ] - (unexpected_ops_seen, _, partitioned_graphs,) = lower_graph_testing( + ( + unexpected_ops_seen, + _, + partitioned_graphs, + ) = lower_graph_testing( fx_graph, inputs, unexpected_ops=unexpected_ops, @@ -233,7 +237,11 
@@ def forward(self, x, y): torch.randint(1, 40, (16, 7, 5), dtype=torch.long).cuda(), ] - (unexpected_ops_seen, _, partitioned_graphs,) = lower_graph_testing( + ( + unexpected_ops_seen, + _, + partitioned_graphs, + ) = lower_graph_testing( fx_graph, inputs, unexpected_ops=unexpected_ops, diff --git a/tests/py/dynamo/backend/test_partitioning.py b/tests/py/dynamo/backend/test_partitioning.py index 76645075b9..a5d6495754 100644 --- a/tests/py/dynamo/backend/test_partitioning.py +++ b/tests/py/dynamo/backend/test_partitioning.py @@ -93,7 +93,11 @@ def forward(self, x, y): ] fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp()) - (unexpected_ops_seen, _, partitioned_graphs,) = lower_graph_testing( + ( + unexpected_ops_seen, + _, + partitioned_graphs, + ) = lower_graph_testing( fx_graph, inputs, unexpected_ops=unexpected_ops, diff --git a/py/torch_tensorrt/dynamo/test_utils.py b/tests/py/dynamo/converters/harness.py similarity index 97% rename from py/torch_tensorrt/dynamo/test_utils.py rename to tests/py/dynamo/converters/harness.py index a3d742c70a..5634e37a30 100644 --- a/py/torch_tensorrt/dynamo/test_utils.py +++ b/tests/py/dynamo/converters/harness.py @@ -22,8 +22,8 @@ from torch_tensorrt.fx.passes.pass_utils import chain_passes # Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry -from torch_tensorrt.dynamo.conversion.trt_interpreter import TRTInterpreter -from torch_tensorrt.dynamo.runtime._PythonTorchTRTModule import PythonTorchTRTModule +from torch_tensorrt.dynamo.conversion import TRTInterpreter +from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule from torch_tensorrt import Input @@ -83,7 +83,7 @@ def run_test( interpreter_result = interpreter.run(precision=precision) sec = time.perf_counter() - start _LOGGER.info(f"Interpreter run time(s): {sec}") - trt_mod = PythonTorchTRTModule( + trt_mod = PythonTorchTensorRTModule( interpreter_result.engine, interpreter_result.input_names, interpreter_result.output_names, @@ -159,7 +159,7 @@ def run_test_custom_compare_results( interpreter_result = interpreter.run( precision=torch.half if fp16_mode else torch.float ) - trt_mod = PythonTorchTRTModule( + trt_mod = PythonTorchTensorRTModule( interpreter_result.engine, interpreter_result.input_names, interpreter_result.output_names, diff --git a/tests/py/dynamo/converters/test_adaptive_avgpool_aten.py b/tests/py/dynamo/converters/test_adaptive_avgpool_aten.py index 68ce24c20f..7ef8fe030b 100644 --- a/tests/py/dynamo/converters/test_adaptive_avgpool_aten.py +++ b/tests/py/dynamo/converters/test_adaptive_avgpool_aten.py @@ -1,9 +1,10 @@ import torch from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import DispatchTestCase from torch_tensorrt import Input +from harness import DispatchTestCase + class TestAdaptiveAvgPoolConverter(DispatchTestCase): def test_adaptive_avgpool_mean(self): diff --git a/tests/py/dynamo/converters/test_batchnorm_aten.py b/tests/py/dynamo/converters/test_batchnorm_aten.py index c39f14abfe..bc88cdefe6 100644 --- a/tests/py/dynamo/converters/test_batchnorm_aten.py +++ b/tests/py/dynamo/converters/test_batchnorm_aten.py @@ -1,6 +1,6 @@ import torch from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from harness import DispatchTestCase from torch_tensorrt import Input diff --git a/tests/py/dynamo/converters/test_binary_ops_aten.py 
b/tests/py/dynamo/converters/test_binary_ops_aten.py index 19fa02721c..75823aec80 100644 --- a/tests/py/dynamo/converters/test_binary_ops_aten.py +++ b/tests/py/dynamo/converters/test_binary_ops_aten.py @@ -6,7 +6,7 @@ from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from harness import DispatchTestCase from torch_tensorrt import Input NEED_TEST_BOTH_CONSTANTS_CASE = True diff --git a/tests/py/dynamo/converters/test_cat_aten.py b/tests/py/dynamo/converters/test_cat_aten.py index d9d107de89..f89f77f3c4 100644 --- a/tests/py/dynamo/converters/test_cat_aten.py +++ b/tests/py/dynamo/converters/test_cat_aten.py @@ -2,7 +2,7 @@ import torch.nn as nn from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from harness import DispatchTestCase from torch_tensorrt import Input diff --git a/tests/py/dynamo/converters/test_clamp_aten.py b/tests/py/dynamo/converters/test_clamp_aten.py index 05716c1657..0f2ea0f41f 100644 --- a/tests/py/dynamo/converters/test_clamp_aten.py +++ b/tests/py/dynamo/converters/test_clamp_aten.py @@ -1,7 +1,7 @@ import torch from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from harness import DispatchTestCase from torch_tensorrt import Input diff --git a/tests/py/dynamo/converters/test_convolution_aten.py b/tests/py/dynamo/converters/test_convolution_aten.py index a906d70d43..fbe9d2aa5c 100644 --- a/tests/py/dynamo/converters/test_convolution_aten.py +++ b/tests/py/dynamo/converters/test_convolution_aten.py @@ -1,7 +1,7 @@ import torch from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from harness import DispatchTestCase from torch_tensorrt import Input diff --git a/tests/py/dynamo/converters/test_elu_aten.py b/tests/py/dynamo/converters/test_elu_aten.py index dfaf2db5a6..9935729e65 100644 --- a/tests/py/dynamo/converters/test_elu_aten.py +++ b/tests/py/dynamo/converters/test_elu_aten.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from harness import DispatchTestCase from torch_tensorrt import Input diff --git a/tests/py/dynamo/converters/test_embedding_aten.py b/tests/py/dynamo/converters/test_embedding_aten.py index 4d36478303..c27ee2173e 100644 --- a/tests/py/dynamo/converters/test_embedding_aten.py +++ b/tests/py/dynamo/converters/test_embedding_aten.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from harness import DispatchTestCase from parameterized import param, parameterized from torch_tensorrt import Input diff --git a/tests/py/dynamo/converters/test_expand_aten.py b/tests/py/dynamo/converters/test_expand_aten.py index 1b1f3d1c14..bb5b93304a 100644 --- a/tests/py/dynamo/converters/test_expand_aten.py +++ b/tests/py/dynamo/converters/test_expand_aten.py @@ -2,7 +2,7 @@ import torch.nn as nn from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from 
harness import DispatchTestCase class TestExpandConverter(DispatchTestCase): diff --git a/tests/py/dynamo/converters/test_gelu_aten.py b/tests/py/dynamo/converters/test_gelu_aten.py index c62a028c0e..c46fbdfa08 100644 --- a/tests/py/dynamo/converters/test_gelu_aten.py +++ b/tests/py/dynamo/converters/test_gelu_aten.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from harness import DispatchTestCase from torch_tensorrt import Input diff --git a/tests/py/dynamo/converters/test_hardtanh_aten.py b/tests/py/dynamo/converters/test_hardtanh_aten.py index 8401dd17a9..5f042e6cf8 100644 --- a/tests/py/dynamo/converters/test_hardtanh_aten.py +++ b/tests/py/dynamo/converters/test_hardtanh_aten.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from harness import DispatchTestCase from torch_tensorrt import Input diff --git a/tests/py/dynamo/converters/test_layer_norm_aten.py b/tests/py/dynamo/converters/test_layer_norm_aten.py index a4766bd030..05ae8888ec 100644 --- a/tests/py/dynamo/converters/test_layer_norm_aten.py +++ b/tests/py/dynamo/converters/test_layer_norm_aten.py @@ -1,6 +1,6 @@ import torch from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from harness import DispatchTestCase from torch_tensorrt import Input diff --git a/tests/py/dynamo/converters/test_leaky_relu_aten.py b/tests/py/dynamo/converters/test_leaky_relu_aten.py index aa3d56641b..70e6b30723 100644 --- a/tests/py/dynamo/converters/test_leaky_relu_aten.py +++ b/tests/py/dynamo/converters/test_leaky_relu_aten.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from harness import DispatchTestCase from torch_tensorrt import Input diff --git a/tests/py/dynamo/converters/test_linear_aten.py b/tests/py/dynamo/converters/test_linear_aten.py index b9e3261642..a435313908 100644 --- a/tests/py/dynamo/converters/test_linear_aten.py +++ b/tests/py/dynamo/converters/test_linear_aten.py @@ -1,7 +1,7 @@ import torch from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from harness import DispatchTestCase class TestLinearConverter(DispatchTestCase): diff --git a/tests/py/dynamo/converters/test_matmul_aten.py b/tests/py/dynamo/converters/test_matmul_aten.py index f01325fb10..665578030d 100644 --- a/tests/py/dynamo/converters/test_matmul_aten.py +++ b/tests/py/dynamo/converters/test_matmul_aten.py @@ -2,7 +2,7 @@ import torch.nn as nn from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from harness import DispatchTestCase class TestMatMulConverter(DispatchTestCase): diff --git a/tests/py/dynamo/converters/test_mean_aten.py b/tests/py/dynamo/converters/test_mean_aten.py index fe31d90a24..56a39c7b63 100644 --- a/tests/py/dynamo/converters/test_mean_aten.py +++ b/tests/py/dynamo/converters/test_mean_aten.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from 
harness import DispatchTestCase from torch_tensorrt import Input diff --git a/tests/py/dynamo/converters/test_permutation_aten.py b/tests/py/dynamo/converters/test_permutation_aten.py index f9d614ae68..a5a8849576 100644 --- a/tests/py/dynamo/converters/test_permutation_aten.py +++ b/tests/py/dynamo/converters/test_permutation_aten.py @@ -2,7 +2,7 @@ import torch.nn as nn from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from harness import DispatchTestCase from torch_tensorrt import Input diff --git a/tests/py/dynamo/converters/test_relu_aten.py b/tests/py/dynamo/converters/test_relu_aten.py index 08ab04014d..e32b891315 100644 --- a/tests/py/dynamo/converters/test_relu_aten.py +++ b/tests/py/dynamo/converters/test_relu_aten.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from harness import DispatchTestCase from torch_tensorrt import Input diff --git a/tests/py/dynamo/converters/test_reshape_aten.py b/tests/py/dynamo/converters/test_reshape_aten.py index 1df71abc1a..e1f5167015 100644 --- a/tests/py/dynamo/converters/test_reshape_aten.py +++ b/tests/py/dynamo/converters/test_reshape_aten.py @@ -4,7 +4,7 @@ import torch from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from harness import DispatchTestCase from torch_tensorrt import Input diff --git a/tests/py/dynamo/converters/test_rsqrt_aten.py b/tests/py/dynamo/converters/test_rsqrt_aten.py index 5770e697fc..909aa6e5da 100644 --- a/tests/py/dynamo/converters/test_rsqrt_aten.py +++ b/tests/py/dynamo/converters/test_rsqrt_aten.py @@ -2,7 +2,7 @@ import torch.nn as nn from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from harness import DispatchTestCase from torch_tensorrt import Input diff --git a/tests/py/dynamo/converters/test_select_aten.py b/tests/py/dynamo/converters/test_select_aten.py index 049cd9c7e6..d6a365e30d 100644 --- a/tests/py/dynamo/converters/test_select_aten.py +++ b/tests/py/dynamo/converters/test_select_aten.py @@ -1,7 +1,7 @@ import torch from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from harness import DispatchTestCase from torch_tensorrt import Input diff --git a/tests/py/dynamo/converters/test_selu_aten.py b/tests/py/dynamo/converters/test_selu_aten.py index 7fb6afda76..701c4761d3 100644 --- a/tests/py/dynamo/converters/test_selu_aten.py +++ b/tests/py/dynamo/converters/test_selu_aten.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from harness import DispatchTestCase from torch_tensorrt import Input diff --git a/tests/py/dynamo/converters/test_sigmoid_aten.py b/tests/py/dynamo/converters/test_sigmoid_aten.py index 37bbea1730..fcd31583de 100644 --- a/tests/py/dynamo/converters/test_sigmoid_aten.py +++ b/tests/py/dynamo/converters/test_sigmoid_aten.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import 
DispatchTestCase +from harness import DispatchTestCase from torch_tensorrt import Input diff --git a/tests/py/dynamo/converters/test_slice_aten.py b/tests/py/dynamo/converters/test_slice_aten.py index 86de36d351..a81533ac25 100644 --- a/tests/py/dynamo/converters/test_slice_aten.py +++ b/tests/py/dynamo/converters/test_slice_aten.py @@ -1,7 +1,7 @@ import torch from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from harness import DispatchTestCase from torch_tensorrt import Input diff --git a/tests/py/dynamo/converters/test_softmax_aten.py b/tests/py/dynamo/converters/test_softmax_aten.py index 8d33f3ebe0..3e0095265c 100644 --- a/tests/py/dynamo/converters/test_softmax_aten.py +++ b/tests/py/dynamo/converters/test_softmax_aten.py @@ -1,6 +1,6 @@ import torch from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from harness import DispatchTestCase from torch_tensorrt import Input diff --git a/tests/py/dynamo/converters/test_squeeze_aten.py b/tests/py/dynamo/converters/test_squeeze_aten.py index 152fe86300..e4a9943cba 100644 --- a/tests/py/dynamo/converters/test_squeeze_aten.py +++ b/tests/py/dynamo/converters/test_squeeze_aten.py @@ -2,7 +2,7 @@ import torch.nn as nn from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from harness import DispatchTestCase from torch_tensorrt import Input diff --git a/tests/py/dynamo/converters/test_tanh_aten.py b/tests/py/dynamo/converters/test_tanh_aten.py index f9aa94a7bc..4ff48a8d2e 100644 --- a/tests/py/dynamo/converters/test_tanh_aten.py +++ b/tests/py/dynamo/converters/test_tanh_aten.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from harness import DispatchTestCase from torch_tensorrt import Input diff --git a/tests/py/dynamo/converters/test_unsqueeze_aten.py b/tests/py/dynamo/converters/test_unsqueeze_aten.py index db8ae7151f..4280456d2f 100644 --- a/tests/py/dynamo/converters/test_unsqueeze_aten.py +++ b/tests/py/dynamo/converters/test_unsqueeze_aten.py @@ -3,7 +3,7 @@ import torch.nn as nn from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from harness import DispatchTestCase from torch_tensorrt import Input diff --git a/tests/py/dynamo/converters/test_where_aten.py b/tests/py/dynamo/converters/test_where_aten.py index 39ba0500b9..ddeb269ee9 100644 --- a/tests/py/dynamo/converters/test_where_aten.py +++ b/tests/py/dynamo/converters/test_where_aten.py @@ -2,7 +2,7 @@ import torch.nn as nn from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from harness import DispatchTestCase class TestWhereConverter(DispatchTestCase): diff --git a/tests/py/ts/BUILD b/tests/py/ts/BUILD index 9d19848207..98db68fc44 100644 --- a/tests/py/ts/BUILD +++ b/tests/py/ts/BUILD @@ -14,8 +14,8 @@ config_setting( py_test( name = "test_api", srcs = [ - "test_api.py", "model_test_case.py", + "test_api.py", ] + select({ ":aarch64_linux": [ "test_api_dla.py", diff --git a/tests/py/ts/api/test_classes.py b/tests/py/ts/api/test_classes.py index 
835257fc58..01c805d9a1 100644 --- a/tests/py/ts/api/test_classes.py +++ b/tests/py/ts/api/test_classes.py @@ -104,7 +104,9 @@ def test_infer_from_example_tensor(self): example_tensor = torch.randn(shape).half() i = torchtrt.Input.from_tensor(example_tensor) - ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + ts_i = torchtrt.ts.TorchScriptInput( + shape=i.shape, dtype=i.dtype, format=i.format + ) self.assertTrue(self._verify_correctness(ts_i, target)) def test_static_shape(self): @@ -120,27 +122,39 @@ def test_static_shape(self): } i = torchtrt.Input(shape) - ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + ts_i = torchtrt.ts.TorchScriptInput( + shape=i.shape, dtype=i.dtype, format=i.format + ) self.assertTrue(self._verify_correctness(ts_i, target)) i = torchtrt.Input(tuple(shape)) - ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + ts_i = torchtrt.ts.TorchScriptInput( + shape=i.shape, dtype=i.dtype, format=i.format + ) self.assertTrue(self._verify_correctness(ts_i, target)) i = torchtrt.Input(torch.randn(shape).shape) - ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + ts_i = torchtrt.ts.TorchScriptInput( + shape=i.shape, dtype=i.dtype, format=i.format + ) self.assertTrue(self._verify_correctness(ts_i, target)) i = torchtrt.Input(shape=shape) - ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + ts_i = torchtrt.ts.TorchScriptInput( + shape=i.shape, dtype=i.dtype, format=i.format + ) self.assertTrue(self._verify_correctness(ts_i, target)) i = torchtrt.Input(shape=tuple(shape)) - ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + ts_i = torchtrt.ts.TorchScriptInput( + shape=i.shape, dtype=i.dtype, format=i.format + ) self.assertTrue(self._verify_correctness(ts_i, target)) i = torchtrt.Input(shape=torch.randn(shape).shape) - ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + ts_i = torchtrt.ts.TorchScriptInput( + shape=i.shape, dtype=i.dtype, format=i.format + ) self.assertTrue(self._verify_correctness(ts_i, target)) def test_data_type(self): @@ -156,11 +170,15 @@ def test_data_type(self): } i = torchtrt.Input(shape, dtype=torchtrt.dtype.half) - ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + ts_i = torchtrt.ts.TorchScriptInput( + shape=i.shape, dtype=i.dtype, format=i.format + ) self.assertTrue(self._verify_correctness(ts_i, target)) i = torchtrt.Input(shape, dtype=torch.half) - ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + ts_i = torchtrt.ts.TorchScriptInput( + shape=i.shape, dtype=i.dtype, format=i.format + ) self.assertTrue(self._verify_correctness(ts_i, target)) def test_tensor_format(self): @@ -176,11 +194,15 @@ def test_tensor_format(self): } i = torchtrt.Input(shape, format=torchtrt.TensorFormat.channels_last) - ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + ts_i = torchtrt.ts.TorchScriptInput( + shape=i.shape, dtype=i.dtype, format=i.format + ) self.assertTrue(self._verify_correctness(ts_i, target)) i = torchtrt.Input(shape, format=torch.channels_last) - ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + ts_i = torchtrt.ts.TorchScriptInput( + shape=i.shape, dtype=i.dtype, format=i.format + ) self.assertTrue(self._verify_correctness(ts_i, target)) def test_dynamic_shape(self): @@ -200,7 +222,7 @@ def test_dynamic_shape(self): i = torchtrt.Input( min_shape=min_shape, opt_shape=opt_shape, 
max_shape=max_shape ) - ts_i = torchtrt.ts.TSInput( + ts_i = torchtrt.ts.TorchScriptInput( min_shape=i.shape["min_shape"], opt_shape=i.shape["opt_shape"], max_shape=i.shape["max_shape"], @@ -214,7 +236,7 @@ def test_dynamic_shape(self): opt_shape=tuple(opt_shape), max_shape=tuple(max_shape), ) - ts_i = torchtrt.ts.TSInput( + ts_i = torchtrt.ts.TorchScriptInput( min_shape=i.shape["min_shape"], opt_shape=i.shape["opt_shape"], max_shape=i.shape["max_shape"], @@ -229,7 +251,7 @@ def test_dynamic_shape(self): opt_shape=tensor_shape(opt_shape), max_shape=tensor_shape(max_shape), ) - ts_i = torchtrt.ts.TSInput( + ts_i = torchtrt.ts.TorchScriptInput( min_shape=i.shape["min_shape"], opt_shape=i.shape["opt_shape"], max_shape=i.shape["max_shape"], diff --git a/tests/py/ts/api/test_collections.py b/tests/py/ts/api/test_collections.py index 64f46fa3e9..eab67679ed 100644 --- a/tests/py/ts/api/test_collections.py +++ b/tests/py/ts/api/test_collections.py @@ -23,7 +23,6 @@ def find_repo_root(max_depth=10): class TestStandardTensorInput(unittest.TestCase): def test_compile(self): - self.input = torch.randn((1, 3, 224, 224)).to("cuda") self.model = ( torch.jit.load(MODULE_DIR + "/standard_tensor_input_scripted.jit.pt") @@ -52,7 +51,6 @@ def test_compile(self): class TestStandardTensorInputLong(unittest.TestCase): def test_compile(self): - self.input = torch.randn((1, 3, 224, 224)).to("cuda") self.model = ( torch.jit.load(MODULE_DIR + "/standard_tensor_input_scripted.jit.pt") @@ -82,7 +80,6 @@ def test_compile(self): class TestStandardTensorInputDomain(unittest.TestCase): def test_compile(self): - self.input = torch.randn((1, 3, 224, 224)).to("cuda") self.model = ( torch.jit.load(MODULE_DIR + "/standard_tensor_input_scripted.jit.pt") @@ -111,7 +108,6 @@ def test_compile(self): class TestTupleInput(unittest.TestCase): def test_compile(self): - self.input = torch.randn((1, 3, 224, 224)).to("cuda") self.model = ( torch.jit.load(MODULE_DIR + "/tuple_input_scripted.jit.pt") @@ -140,7 +136,6 @@ def test_compile(self): class TestListInput(unittest.TestCase): def test_compile(self): - self.input = torch.randn((1, 3, 224, 224)).to("cuda") self.model = ( torch.jit.load(MODULE_DIR + "/list_input_scripted.jit.pt").eval().to("cuda") @@ -167,7 +162,6 @@ def test_compile(self): class TestTupleInputOutput(unittest.TestCase): def test_compile(self): - self.input = torch.randn((1, 3, 224, 224)).to("cuda") self.model = ( torch.jit.load(MODULE_DIR + "/tuple_input_output_scripted.jit.pt") @@ -187,7 +181,7 @@ def test_compile(self): trt_mod = torchtrt.ts.compile(self.model, **compile_spec) trt_out = trt_mod((self.input, self.input)) pyt_out = self.model((self.input, self.input)) - for (t, p) in zip(trt_out, pyt_out): + for t, p in zip(trt_out, pyt_out): cos_sim = cosine_similarity(t, p) self.assertTrue( cos_sim > COSINE_THRESHOLD, @@ -215,7 +209,7 @@ def test_compile_full_compilation(self): trt_mod = torchtrt.ts.compile(self.model, **compile_spec) trt_out = trt_mod((self.input, self.input)) pyt_out = self.model((self.input, self.input)) - for (t, p) in zip(trt_out, pyt_out): + for t, p in zip(trt_out, pyt_out): cos_sim = cosine_similarity(t, p) self.assertTrue( cos_sim > COSINE_THRESHOLD, @@ -225,7 +219,6 @@ def test_compile_full_compilation(self): class TestListInputOutput(unittest.TestCase): def test_compile(self): - self.input = torch.randn((1, 3, 224, 224)).to("cuda") self.model = ( torch.jit.load(MODULE_DIR + "/list_input_output_scripted.jit.pt") @@ -246,7 +239,7 @@ def test_compile(self): trt_out = trt_mod((self.input, 
self.input)) pyt_out = self.model((self.input, self.input)) - for (t, p) in zip(trt_out, pyt_out): + for t, p in zip(trt_out, pyt_out): cos_sim = cosine_similarity(t, p) self.assertTrue( cos_sim > COSINE_THRESHOLD, @@ -254,7 +247,6 @@ def test_compile(self): ) def test_compile_full_compilation(self): - self.input = torch.randn((1, 3, 224, 224)).to("cuda") self.model = ( torch.jit.load(MODULE_DIR + "/list_input_output_scripted.jit.pt") @@ -276,7 +268,7 @@ def test_compile_full_compilation(self): trt_out = trt_mod((self.input, self.input)) pyt_out = self.model((self.input, self.input)) - for (t, p) in zip(trt_out, pyt_out): + for t, p in zip(trt_out, pyt_out): cos_sim = cosine_similarity(t, p) self.assertTrue( cos_sim > COSINE_THRESHOLD, @@ -286,7 +278,6 @@ def test_compile_full_compilation(self): class TestListInputTupleOutput(unittest.TestCase): def test_compile(self): - self.input = torch.randn((1, 3, 224, 224)).to("cuda") self.model = ( torch.jit.load(MODULE_DIR + "/list_input_tuple_output_scripted.jit.pt") @@ -306,7 +297,7 @@ def test_compile(self): trt_mod = torchtrt.ts.compile(self.model, **compile_spec) trt_out = trt_mod((self.input, self.input)) pyt_out = self.model((self.input, self.input)) - for (t, p) in zip(trt_out, pyt_out): + for t, p in zip(trt_out, pyt_out): cos_sim = cosine_similarity(t, p) self.assertTrue( cos_sim > COSINE_THRESHOLD, @@ -314,7 +305,6 @@ def test_compile(self): ) def test_compile_full_compilation(self): - self.input = torch.randn((1, 3, 224, 224)).to("cuda") self.model = ( torch.jit.load(MODULE_DIR + "/list_input_tuple_output_scripted.jit.pt") @@ -335,7 +325,7 @@ def test_compile_full_compilation(self): trt_mod = torchtrt.ts.compile(self.model, **compile_spec) trt_out = trt_mod((self.input, self.input)) pyt_out = self.model((self.input, self.input)) - for (t, p) in zip(trt_out, pyt_out): + for t, p in zip(trt_out, pyt_out): cos_sim = cosine_similarity(t, p) self.assertTrue( cos_sim > COSINE_THRESHOLD, diff --git a/tests/py/ts/ptq/test_ptq_dataloader_calibrator.py b/tests/py/ts/ptq/test_ptq_dataloader_calibrator.py index 79c19dadbf..c5a84f301d 100644 --- a/tests/py/ts/ptq/test_ptq_dataloader_calibrator.py +++ b/tests/py/ts/ptq/test_ptq_dataloader_calibrator.py @@ -51,7 +51,6 @@ def compute_accuracy(testing_dataloader, model): class TestAccuracy(unittest.TestCase): def test_compile_script(self): - self.model = ( torch.jit.load(MODULE_DIR + "/trained_vgg16.jit.pt").eval().to("cuda") ) diff --git a/tests/util/BUILD b/tests/util/BUILD index a31cda26f1..87b7da9cda 100644 --- a/tests/util/BUILD +++ b/tests/util/BUILD @@ -29,26 +29,26 @@ cc_library( "util.h", ], deps = [ - "@tensorrt//:nvinfer", "@googletest//:gtest_main", + "@tensorrt//:nvinfer", ] + select({ ":use_pre_cxx11_abi": [ - "@libtorch_pre_cxx11_abi//:libtorch", "@libtorch_pre_cxx11_abi//:caffe2", + "@libtorch_pre_cxx11_abi//:libtorch", ], "//conditions:default": [ - "@libtorch//:libtorch", + "@libtorch", "@libtorch//:caffe2", ], }) + select({ ":ci_build_testing": [ - "@torch_tensorrt//:torch_tensorrt", + "@torch_tensorrt", "@torch_tensorrt//:torch_tensorrt_core_hdrs", ], "//conditions:default": [ - "//cpp:torch_tensorrt", "//core/conversion", "//core/util:prelude", + "//cpp:torch_tensorrt", ], }), ) diff --git a/third_party/libtorch/BUILD b/third_party/libtorch/BUILD index dac3a2eaf2..1284bba477 100644 --- a/third_party/libtorch/BUILD +++ b/third_party/libtorch/BUILD @@ -21,14 +21,14 @@ cc_library( srcs = select({ ":windows": [ "lib/torch.lib", - "lib/torch_cuda.lib", "lib/torch_cpu.lib", + 
"lib/torch_cuda.lib", "lib/torch_global_deps.dll", ], "//conditions:default": [ "lib/libtorch.so", - "lib/libtorch_cuda.so", "lib/libtorch_cpu.so", + "lib/libtorch_cuda.so", "lib/libtorch_global_deps.so", ], }), diff --git a/third_party/tensorrt/local/BUILD b/third_party/tensorrt/local/BUILD index 5d0842507f..9cbe98a41e 100644 --- a/third_party/tensorrt/local/BUILD +++ b/third_party/tensorrt/local/BUILD @@ -134,30 +134,30 @@ cc_library( hdrs = select({ ":aarch64_linux": [ "include/aarch64-linux-gnu/NvCaffeParser.h", + "include/aarch64-linux-gnu/NvOnnxConfig.h", "include/aarch64-linux-gnu/NvOnnxParser.h", "include/aarch64-linux-gnu/NvOnnxParserRuntime.h", - "include/aarch64-linux-gnu/NvOnnxConfig.h", "include/aarch64-linux-gnu/NvUffParser.h", ], ":ci_rhel_x86_64_linux": [ "include/NvCaffeParser.h", + "include/NvOnnxConfig.h", "include/NvOnnxParser.h", "include/NvOnnxParserRuntime.h", - "include/NvOnnxConfig.h", "include/NvUffParser.h", ], ":windows": [ "include/NvCaffeParser.h", + "include/NvOnnxConfig.h", "include/NvOnnxParser.h", "include/NvOnnxParserRuntime.h", - "include/NvOnnxConfig.h", "include/NvUffParser.h", ], "//conditions:default": [ "include/x86_64-linux-gnu/NvCaffeParser.h", + "include/x86_64-linux-gnu/NvOnnxConfig.h", "include/x86_64-linux-gnu/NvOnnxParser.h", "include/x86_64-linux-gnu/NvOnnxParserRuntime.h", - "include/x86_64-linux-gnu/NvOnnxConfig.h", "include/x86_64-linux-gnu/NvUffParser.h", ], }), @@ -197,24 +197,24 @@ cc_library( name = "nvonnxparser_headers", hdrs = select({ ":aarch64_linux": [ + "include/aarch64-linux-gnu/NvOnnxConfig.h", "include/aarch64-linux-gnu/NvOnnxParser.h", "include/aarch64-linux-gnu/NvOnnxParserRuntime.h", - "include/aarch64-linux-gnu/NvOnnxConfig.h", ], ":ci_rhel_x86_64_linux": [ + "include/NvOnnxConfig.h", "include/NvOnnxParser.h", "include/NvOnnxParserRuntime.h", - "include/NvOnnxConfig.h", ], ":windows": [ + "include/NvOnnxConfig.h", "include/NvOnnxParser.h", "include/NvOnnxParserRuntime.h", - "include/NvOnnxConfig.h", ], "//conditions:default": [ + "include/x86_64-linux-gnu/NvOnnxConfig.h", "include/x86_64-linux-gnu/NvOnnxParser.h", "include/x86_64-linux-gnu/NvOnnxParserRuntime.h", - "include/x86_64-linux-gnu/NvOnnxConfig.h", ], }), includes = select({ diff --git a/tools/perf/hub.py b/tools/perf/hub.py index a1f032212b..4cedaabded 100644 --- a/tools/perf/hub.py +++ b/tools/perf/hub.py @@ -145,7 +145,6 @@ def main(): # Creating an empty manifest file for overwriting post setup os.system("touch {}".format(MANIFEST_FILE)) else: - # Load manifest if already exists with open(MANIFEST_FILE, "r") as f: manifest = json.load(f) diff --git a/tools/perf/perf_run.py b/tools/perf/perf_run.py index 00ddbabd22..65729008aa 100644 --- a/tools/perf/perf_run.py +++ b/tools/perf/perf_run.py @@ -33,6 +33,7 @@ WARMUP_ITER = 10 results = [] + # YAML Parser class for parsing the run configurations class ConfigParser: def __init__(self, config_file):