fix: Upgrade Torch version, enable options
- Upgrade the Torch version across the stack
- Update the advanced-usage Dynamo sample to demonstrate the new
`options` argument in `torch.compile`
- Enable the `options` argument in `torch.compile`, including improved
input handling in the default torch_tensorrt backend (see the sketch
below)
- The ResNet example now features `torch_tensorrt.dynamo.compile`, while
the transformers example features `torch_tensorrt.compile(...,
ir="dynamo_compile", ...)`
- Fix bugs in the core runtime and `TRTInterpreter` to address issues
arising with the latest PyTorch distribution
- Add a feature in `TRTInterpreter` to specify output data types
- Add a `pass_through_build_failures` argument to the
`torch_tensorrt.dynamo.torch_compile` frontend
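
As a quick illustration of the new `options` pathway, here is a minimal usage sketch (the model and option values are placeholders, assuming a CUDA-capable environment):

    import torch
    import torch_tensorrt  # importing registers the "torch_tensorrt" backend

    model = torch.nn.Linear(8, 4).eval().cuda()

    # Backend-specific settings ride along in torch.compile's "options" dict
    optimized = torch.compile(
        model,
        backend="torch_tensorrt",
        options={"debug": True, "min_block_size": 2},
    )
    optimized(torch.rand(2, 8).cuda())  # first call triggers compilation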
gs-olive committed Jun 3, 2023
1 parent 6d2a26b commit 4ba253f
Showing 14 changed files with 93 additions and 39 deletions.
2 changes: 1 addition & 1 deletion .bazelrc
@@ -22,7 +22,7 @@
# +------------------------------------------------------------+
# Enable colorful output of GCC
build --cxxopt="-fdiagnostics-color=always"
-build --cxxopt='-std=c++14'
+build --cxxopt='-std=c++17'
#build --linkopt="-Wl,--no-as-needed"


8 changes: 4 additions & 4 deletions .circleci/config.yml
@@ -269,10 +269,10 @@ commands:
parameters:
  torch-build:
    type: string
-    default: "2.1.0.dev20230419+cu118"
+    default: "2.1.0.dev20230601+cu118"
  torchvision-build:
    type: string
-    default: "0.16.0.dev20230419+cu118"
+    default: "0.16.0.dev20230601+cu118"
  torch-build-index:
    type: string
    default: "https://download.pytorch.org/whl/nightly/cu118"
@@ -1352,10 +1352,10 @@ parameters:
  # Nightly platform config
  torch-build:
    type: string
-    default: "2.1.0.dev20230419+cu118"
+    default: "2.1.0.dev20230601+cu118"
  torchvision-build:
    type: string
-    default: "0.16.0.dev20230419+cu118"
+    default: "0.16.0.dev20230601+cu118"
  torch-build-index:
    type: string
    default: "https://download.pytorch.org/whl/nightly/cu118"
4 changes: 2 additions & 2 deletions CMakeLists.txt
@@ -2,8 +2,8 @@
cmake_minimum_required(VERSION 3.17)
project(Torch-TensorRT LANGUAGES CXX)

-# use c++14 like PyTorch
-set(CMAKE_CXX_STANDARD 14)
+# use c++17 like PyTorch
+set(CMAKE_CXX_STANDARD 17)

# Build the libraries with -fPIC
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
2 changes: 1 addition & 1 deletion README.md
@@ -116,7 +116,7 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts") # save the TRT embedd
These are the following dependencies used to verify the testcases. Torch-TensorRT can work with other versions, but the tests are not guaranteed to pass.

- Bazel 5.2.0
-- Libtorch 2.1.0.dev20230419 (built with CUDA 11.8)
+- Libtorch 2.1.0.dev20230601 (built with CUDA 11.8)
- CUDA 11.8
- cuDNN 8.8.0
- TensorRT 8.6.1
8 changes: 4 additions & 4 deletions WORKSPACE
@@ -51,17 +51,17 @@ new_local_repository(
http_archive(
    name = "libtorch",
    build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "1a526a9cd19c1015674d26921dbb94bcd2d632a6f9c431a21c43f4e24768d834",
+    sha256 = "c8407ae3462c344ae3814e82023e22ece759ebe75023f35bdf62e9c0a7e79035",
    strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230419%2Bcu118.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230601%2Bcu118.zip"],
)

http_archive(
    name = "libtorch_pre_cxx11_abi",
    build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "60c5912a5085a6a7073b3804b10d41d6cc054693bbeb7a45e0247050c2837bac",
+    sha256 = "76f983bd6d784cc0a95c679034d297abe36911c16b2188498b13a9028177e28e",
    strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-shared-with-deps-2.1.0.dev20230419%2Bcu118.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-shared-with-deps-2.1.0.dev20230601%2Bcu118.zip"],
)

# Download these tarballs manually from the NVIDIA website
3 changes: 2 additions & 1 deletion core/runtime/TRTEngine.cpp
@@ -111,7 +111,8 @@ TRTEngine::TRTEngine(
  for (size_t pyt_idx = 0; pyt_idx < inputs_size; pyt_idx++) {
    auto binding_name = _in_binding_names[pyt_idx];
    auto trt_idx = cuda_engine->getBindingIndex(binding_name.c_str());
-    std::string engine_binded_name = cuda_engine->getIOTensorName(pyt_idx);
+    std::string engine_binded_name = cuda_engine->getIOTensorName(trt_idx);
+
    TORCHTRT_CHECK(
        (binding_name == engine_binded_name),
        "Could not find a TensorRT engine binding for input named " << binding_name);
38 changes: 23 additions & 15 deletions examples/dynamo/dynamo_compile_advanced_usage.py
@@ -9,7 +9,6 @@
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import torch
-from torch_tensorrt.dynamo.backend import create_backend
from torch_tensorrt.fx.lower_setting import LowerPrecision

# %%
@@ -39,15 +38,19 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):

# Next, we compile the model using torch.compile
# For the default settings, we can simply call torch.compile
# with the backend "tensorrt", and run the model on an
# with the backend "torch_tensorrt", and run the model on an
# input to cause compilation, as so:
optimized_model = torch.compile(model, backend="tensorrt")
optimized_model = torch.compile(model, backend="torch_tensorrt")
optimized_model(*sample_inputs)

# %%
# Compilation with `torch.compile` Using Custom Settings
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

+# First, we use Torch utilities to clean up the workspace
+# after the previous compile invocation
+torch._dynamo.reset()

# Define sample half inputs and initialize model
sample_inputs_half = [
    torch.rand((5, 7)).half().cuda(),
@@ -58,20 +61,25 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
# %%

# If we want to customize certain options in the backend,
-# but still use the torch.compile call directly, we can call the
-# convenience/helper function create_backend to create a custom backend
-# which has been pre-populated with certain keys
-custom_backend = create_backend(
-    lower_precision=LowerPrecision.FP16,
-    debug=True,
-    min_block_size=2,
-    torch_executed_ops={},
-    optimization_level=4,
-    use_experimental_rt=True,
-)
+# but still use the torch.compile call directly, we can provide
+# custom options to the backend via the "options" keyword
+# which takes in a dictionary mapping options to values.
+#
+# For accepted backend options, see the CompilationSettings dataclass:
+# py/torch_tensorrt/dynamo/backend/_settings.py
+backend_kwargs = {
+    "lower_precision": LowerPrecision.FP16,
+    "debug": True,
+    "min_block_size": 2,
+    "torch_executed_ops": {"torch.ops.aten.sub.Tensor"},
+    "optimization_level": 4,
+    "use_experimental_rt": True,
+}

# Run the model on an input to cause compilation, as so:
-optimized_model_custom = torch.compile(model_half, backend=custom_backend)
+optimized_model_custom = torch.compile(
+    model_half, backend="torch_tensorrt", options=backend_kwargs
+)
optimized_model_custom(*sample_inputs_half)

# %%
4 changes: 2 additions & 2 deletions py/requirements.txt
@@ -2,7 +2,7 @@ numpy
packaging
pybind11==2.6.2
--extra-index-url https://download.pytorch.org/whl/nightly/cu118
-torch==2.1.0.dev20230419+cu118
-torchvision==0.16.0.dev20230419+cu118
+torch==2.1.0.dev20230601+cu118
+torchvision==0.16.0.dev20230601+cu118
--extra-index-url https://pypi.ngc.nvidia.com
tensorrt==8.6.1
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/backend/__init__.py
@@ -49,6 +49,7 @@ def compile(
    min_block_size=MIN_BLOCK_SIZE,
    torch_executed_ops=[],
    torch_executed_modules=[],
+    pass_through_build_failures=PASS_THROUGH_BUILD_FAILURES,
    max_aux_streams=MAX_AUX_STREAMS,
    version_compatible=VERSION_COMPATIBLE,
    optimization_level=OPTIMIZATION_LEVEL,
@@ -94,6 +95,7 @@ def compile(
        workspace_size=workspace_size,
        min_block_size=min_block_size,
        torch_executed_ops=torch_executed_ops,
+        pass_through_build_failures=pass_through_build_failures,
        max_aux_streams=max_aux_streams,
        version_compatible=version_compatible,
        optimization_level=optimization_level,
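
A hedged usage sketch of the new argument as surfaced through `torch.compile` (the model here is a placeholder; the flag makes TRT conversion errors propagate rather than silently falling back to Torch execution):

    import torch
    import torch_tensorrt

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval().cuda()

    # Surface build/conversion failures instead of falling back to eager Torch
    optimized = torch.compile(
        model,
        backend="torch_tensorrt",
        options={"pass_through_build_failures": True},
    )
    optimized(torch.rand(2, 4).cuda())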
18 changes: 18 additions & 0 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -2,6 +2,7 @@
from typing import Sequence
import torch
from functools import partial
+from dataclasses import replace, fields
import torch._dynamo as td

from torch_tensorrt.dynamo.backend._settings import CompilationSettings
@@ -28,7 +29,24 @@ def torch_tensorrt_backend(
    gm: torch.fx.GraphModule,
    sample_inputs: Sequence[torch.Tensor],
    settings: CompilationSettings = CompilationSettings(),
+    **kwargs
):
+    # If the user specifies keyword args, overwrite those fields in settings
+    # Validate all specified kwargs to ensure they are true fields of the dataclass
+    #
+    # Note: kwargs provided by torch.compile are wrapped in the "options" key
+    if kwargs:
+        if "options" in kwargs and len(kwargs) == 1:
+            kwargs = kwargs["options"]
+
+        valid_attrs = {attr.name for attr in fields(settings)}
+        valid_kwargs = {k: v for k, v in kwargs.items() if k in valid_attrs}
+        settings = replace(settings, **valid_kwargs)
+
+    # Enable debug/verbose mode if requested
+    if settings.debug:
+        logger.setLevel(logging.DEBUG)
+
    DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend

    return DEFAULT_BACKEND(gm, sample_inputs, settings=settings)
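
The kwargs-merging pattern above can be exercised in isolation with a self-contained sketch (`Settings` is a stand-in for `CompilationSettings`):

    from dataclasses import dataclass, fields, replace

    @dataclass(frozen=True)
    class Settings:  # stand-in for CompilationSettings
        debug: bool = False
        min_block_size: int = 1

    def apply_kwargs(settings: Settings, **kwargs) -> Settings:
        # torch.compile wraps backend kwargs under a single "options" key
        if "options" in kwargs and len(kwargs) == 1:
            kwargs = kwargs["options"]
        # Keep only keys that are real fields of the dataclass
        valid = {f.name for f in fields(settings)}
        return replace(settings, **{k: v for k, v in kwargs.items() if k in valid})

    print(apply_kwargs(Settings(), options={"debug": True, "bogus": 5}))
    # Settings(debug=True, min_block_size=1)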
10 changes: 10 additions & 0 deletions py/torch_tensorrt/dynamo/backend/conversion.py
@@ -27,11 +27,21 @@ def convert_module(
    Returns:
        TRTModule or TRTModuleNext
    """
+    # Specify module output data types to ensure TRT output types agree with
+    # that of the equivalent Torch module
+    module_outputs = module(*inputs)
+
+    if not isinstance(module_outputs, (list, tuple)):
+        module_outputs = [module_outputs]
+
+    output_dtypes = list(output.dtype for output in module_outputs)
+
    interpreter = TRTInterpreter(
        module,
        InputTensorSpec.from_tensors(inputs),
        explicit_batch_dimension=True,
        logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
+        output_dtypes=output_dtypes,
    )

    interpreter_result = interpreter.run(
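
The dry-run dtype inference added above can be illustrated standalone (a minimal sketch, assuming a CUDA device; `mod` stands in for a partitioned submodule):

    import torch

    mod = torch.nn.Linear(4, 2).cuda().half()
    sample_inputs = [torch.rand(8, 4, device="cuda", dtype=torch.half)]

    # Run the module once to observe its output dtypes
    outputs = mod(*sample_inputs)
    if not isinstance(outputs, (list, tuple)):
        outputs = [outputs]

    output_dtypes = [out.dtype for out in outputs]
    print(output_dtypes)  # [torch.float16]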
17 changes: 16 additions & 1 deletion py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py
@@ -41,6 +41,7 @@ def __init__(
        explicit_batch_dimension: bool = True,
        explicit_precision: bool = False,
        logger_level=None,
+        output_dtypes=None,
    ):
        super().__init__(module)

@@ -79,6 +80,9 @@ def __init__(
            trt.tensorrt.ITensor, TensorMetadata
        ] = dict()

+        # Data types for TRT Module output Tensors
+        self.output_dtypes = output_dtypes

    def validate_input_specs(self):
        for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
            if not self.network.has_implicit_batch_dimension:
@@ -179,13 +183,17 @@ def run(
            algorithm_selector: set up algorithm selection for certain layer
            timing_cache: enable timing cache for TensorRT
            profiling_verbosity: TensorRT logging level
+            max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
+            version_compatible: Provide version forward-compatibility for engine plan files
+            optimization_level: Builder optimization 0-5, higher levels imply longer build time,
+                searching for more optimization options. TRT defaults to 3
        Return:
            TRTInterpreterResult
        """
        TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module)

        # For float outputs, we set their dtype to fp16 only if lower_precision == LowerPrecision.FP16 and
-        # force_fp32_output=False.
+        # force_fp32_output=False. Overriden by specifying output_dtypes
        self.output_fp16 = (
            not force_fp32_output and lower_precision == LowerPrecision.FP16
        )
@@ -373,6 +381,11 @@ def output(self, target, args, kwargs):
        if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs):
            raise RuntimeError("TensorRT requires all outputs to be Tensor!")

+        if self.output_dtypes is not None and len(self.output_dtypes) != len(outputs):
+            raise RuntimeError(
+                f"Specified output dtypes ({len(self.output_dtypes)}) differ from number of outputs ({len(outputs)})"
+            )
+
        for i, output in enumerate(outputs):
            if any(
                op_name in output.name.split("_")
@@ -397,6 +410,8 @@ def output(self, target, args, kwargs):
            self.network.mark_output(output)
            if output_bool:
                output.dtype = trt.bool
+            elif self.output_dtypes is not None:
+                output.dtype = torch_dtype_to_trt(self.output_dtypes[i])
            elif self.output_fp16 and output.dtype == trt.float32:
                output.dtype = trt.float16
            self._output_names.append(name)
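
For reference, a rough sketch of what a torch-to-TRT dtype mapping like `torch_dtype_to_trt` does; the table below is illustrative, not the library's exact implementation:

    import tensorrt as trt
    import torch

    # Illustrative subset; the real torch_dtype_to_trt covers more cases
    _TORCH_TO_TRT = {
        torch.float32: trt.float32,
        torch.float16: trt.float16,
        torch.int32: trt.int32,
        torch.bool: trt.bool,
    }

    def torch_dtype_to_trt_sketch(dtype: torch.dtype) -> trt.DataType:
        if dtype not in _TORCH_TO_TRT:
            raise TypeError(f"Unsupported output dtype: {dtype}")
        return _TORCH_TO_TRT[dtype]

    print(torch_dtype_to_trt_sketch(torch.float16))  # DataType.HALF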
8 changes: 4 additions & 4 deletions toolchains/ci_workspaces/WORKSPACE.x86_64.release.rhel
@@ -56,17 +56,17 @@ new_local_repository(
http_archive(
    name = "libtorch",
    build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "1a526a9cd19c1015674d26921dbb94bcd2d632a6f9c431a21c43f4e24768d834",
+    sha256 = "c8407ae3462c344ae3814e82023e22ece759ebe75023f35bdf62e9c0a7e79035",
    strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230419%2Bcu118.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230601%2Bcu118.zip"],
)

http_archive(
    name = "libtorch_pre_cxx11_abi",
    build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "60c5912a5085a6a7073b3804b10d41d6cc054693bbeb7a45e0247050c2837bac",
+    sha256 = "76f983bd6d784cc0a95c679034d297abe36911c16b2188498b13a9028177e28e",
    strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-shared-with-deps-2.1.0.dev20230419%2Bcu118.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-shared-with-deps-2.1.0.dev20230601%2Bcu118.zip"],
)

####################################################################################
8 changes: 4 additions & 4 deletions toolchains/ci_workspaces/WORKSPACE.x86_64.release.ubuntu
@@ -56,17 +56,17 @@ new_local_repository(
http_archive(
    name = "libtorch",
    build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "1a526a9cd19c1015674d26921dbb94bcd2d632a6f9c431a21c43f4e24768d834",
+    sha256 = "c8407ae3462c344ae3814e82023e22ece759ebe75023f35bdf62e9c0a7e79035",
    strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230419%2Bcu118.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230601%2Bcu118.zip"],
)

http_archive(
    name = "libtorch_pre_cxx11_abi",
    build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "60c5912a5085a6a7073b3804b10d41d6cc054693bbeb7a45e0247050c2837bac",
+    sha256 = "76f983bd6d784cc0a95c679034d297abe36911c16b2188498b13a9028177e28e",
    strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-shared-with-deps-2.1.0.dev20230419%2Bcu118.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-shared-with-deps-2.1.0.dev20230601%2Bcu118.zip"],
)

####################################################################################
