diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py
index 4c8642b7b1..6e7ba66634 100644
--- a/py/torch_tensorrt/_compile.py
+++ b/py/torch_tensorrt/_compile.py
@@ -154,13 +154,13 @@ def compile(
             dynamic_batch=False,
             **kwargs,
         )
-    elif target_ir == _IRType.dynamo:
+    elif target_ir == _IRType.dynamo or target_ir == _IRType.torch_compile:
         return torch_tensorrt.dynamo.compile(
-            module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
-        )
-    elif target_ir == _IRType.torch_compile:
-        return torch_tensorrt.dynamo.backend.compile(
-            module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
+            module,
+            inputs=inputs,
+            enabled_precisions=enabled_precisions,
+            ir=target_ir.name,
+            **kwargs,
         )
     else:
         raise RuntimeError("Module is an unknown format or the ir requested is unknown")
diff --git a/py/torch_tensorrt/dynamo/backend/__init__.py b/py/torch_tensorrt/dynamo/backend/__init__.py
index abda437bb3..c88162a31a 100644
--- a/py/torch_tensorrt/dynamo/backend/__init__.py
+++ b/py/torch_tensorrt/dynamo/backend/__init__.py
@@ -1,167 +1,2 @@
-import torch
-import logging
-import collections.abc
-import torch_tensorrt
-from functools import partial
-
-from typing import Any, Optional, Sequence
-from torch_tensorrt import EngineCapability, Device
-from torch_tensorrt.fx.utils import LowerPrecision
-
-from torch_tensorrt.dynamo.utils import prepare_inputs, prepare_device
-from torch_tensorrt.dynamo.backend.backends import torch_tensorrt_backend
-from torch_tensorrt.dynamo._defaults import (
-    PRECISION,
-    DEBUG,
-    WORKSPACE_SIZE,
-    MIN_BLOCK_SIZE,
-    PASS_THROUGH_BUILD_FAILURES,
-    MAX_AUX_STREAMS,
-    VERSION_COMPATIBLE,
-    OPTIMIZATION_LEVEL,
-    USE_PYTHON_RUNTIME,
-)
-
-
-logger = logging.getLogger(__name__)
-
-
-def compile(
-    gm: torch.nn.Module,
-    inputs: Any,
-    *,
-    device=Device._current_device(),
-    disable_tf32=False,
-    sparse_weights=False,
-    enabled_precisions=set(),
-    refit=False,
-    debug=DEBUG,
-    capability=EngineCapability.default,
-    num_avg_timing_iters=1,
-    workspace_size=WORKSPACE_SIZE,
-    dla_sram_size=1048576,
-    dla_local_dram_size=1073741824,
-    dla_global_dram_size=536870912,
-    calibrator=None,
-    truncate_long_and_double=False,
-    require_full_compilation=False,
-    min_block_size=MIN_BLOCK_SIZE,
-    torch_executed_ops=[],
-    torch_executed_modules=[],
-    pass_through_build_failures=PASS_THROUGH_BUILD_FAILURES,
-    max_aux_streams=MAX_AUX_STREAMS,
-    version_compatible=VERSION_COMPATIBLE,
-    optimization_level=OPTIMIZATION_LEVEL,
-    use_python_runtime=USE_PYTHON_RUNTIME,
-    **kwargs,
-):
-    if debug:
-        logger.setLevel(logging.DEBUG)
-
-    logger.warn(
-        "The Dynamo backend is an experimental feature, for which only the "
-        + "following arguments are supported: "
-        + "{enabled_precisions, debug, workspace_size, min_block_size, "
-        + "torch_executed_ops, pass_through_build_failures}"
-    )
-
-    if not isinstance(inputs, collections.abc.Sequence):
-        inputs = [inputs]
-
-    inputs = prepare_inputs(inputs, prepare_device(device))
-
-    if not isinstance(enabled_precisions, collections.abc.Collection):
-        enabled_precisions = [enabled_precisions]
-
-    # Parse user-specified enabled precisions
-    if (
-        torch.float16 in enabled_precisions
-        or torch_tensorrt.dtype.half in enabled_precisions
-    ):
-        lower_precision = LowerPrecision.FP16
-    elif (
-        torch.float32 in enabled_precisions
-        or torch_tensorrt.dtype.float in enabled_precisions
-    ):
-        lower_precision = LowerPrecision.FP32
-    elif len(enabled_precisions) == 0:
-        logger.info(f"No precision specified, defaulting to {PRECISION}")
-        lower_precision = PRECISION
-    else:
-        raise ValueError(
-            f"Precision {enabled_precisions} not supported in the Dynamo Path"
-        )
-
-    custom_backend = create_backend(
-        precision=lower_precision,
-        debug=debug,
-        workspace_size=workspace_size,
-        min_block_size=min_block_size,
-        torch_executed_ops=torch_executed_ops,
-        pass_through_build_failures=pass_through_build_failures,
-        max_aux_streams=max_aux_streams,
-        version_compatible=version_compatible,
-        optimization_level=optimization_level,
-        use_python_runtime=use_python_runtime,
-        **kwargs,
-    )
-
-    model = torch.compile(gm, backend=custom_backend)
-
-    # Ensure compilation occurs by calling the function with provided inputs
-    model(*inputs)
-
-    return model
-
-
-from torch_tensorrt.fx.utils import LowerPrecision
-
-logger = logging.getLogger(__name__)
-
-
-def create_backend(
-    precision: LowerPrecision = PRECISION,
-    debug: bool = DEBUG,
-    workspace_size: int = WORKSPACE_SIZE,
-    min_block_size: int = MIN_BLOCK_SIZE,
-    torch_executed_ops: Sequence[str] = set(),
-    pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
-    max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
-    version_compatible: bool = VERSION_COMPATIBLE,
-    optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
-    use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME,
-    **kwargs,
-):
-    """Create torch.compile backend given specified arguments
-
-    Args:
-        precision: Model Layer precision
-        debug: Whether to print out verbose debugging information
-        workspace_size: Workspace TRT is allowed to use for the module (0 is default)
-        min_block_size: Minimum number of operators per TRT-Engine Block
-        torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
-        pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False)
-        max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
-        version_compatible: Provide version forward-compatibility for engine plan files
-        optimization_level: Builder optimization 0-5, higher levels imply longer build time,
-            searching for more optimization options. TRT defaults to 3
-        use_python_runtime: Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
-            based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
-            argument as None
-    Returns:
-        Backend for torch.compile
-    """
-    return partial(
-        torch_tensorrt_backend,
-        debug=debug,
-        precision=precision,
-        workspace_size=workspace_size,
-        min_block_size=min_block_size,
-        torch_executed_ops=torch_executed_ops,
-        pass_through_build_failures=pass_through_build_failures,
-        max_aux_streams=max_aux_streams,
-        version_compatible=version_compatible,
-        optimization_level=optimization_level,
-        use_python_runtime=use_python_runtime,
-        **kwargs,
-    )
+from .backends import torch_tensorrt_backend
+from .compile import compile
diff --git a/py/torch_tensorrt/dynamo/compile.py b/py/torch_tensorrt/dynamo/compile.py
index c532fa964a..2d0c231f1c 100644
--- a/py/torch_tensorrt/dynamo/compile.py
+++ b/py/torch_tensorrt/dynamo/compile.py
@@ -17,7 +17,7 @@
 )
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.utils import prepare_inputs, prepare_device
-from torch_tensorrt.dynamo.backend.backends import torch_tensorrt_backend
+from torch_tensorrt.dynamo.backend import torch_tensorrt_backend
 from torch_tensorrt.dynamo.conversion import convert_module
 
 from torch_tensorrt.dynamo._defaults import (
@@ -58,6 +58,7 @@ def compile(
     min_block_size=MIN_BLOCK_SIZE,
     torch_executed_ops=[],
     torch_executed_modules=[],
+    pass_through_build_failures=PASS_THROUGH_BUILD_FAILURES,
     max_aux_streams=MAX_AUX_STREAMS,
     version_compatible=VERSION_COMPATIBLE,
     optimization_level=OPTIMIZATION_LEVEL,
@@ -97,30 +98,98 @@ def compile(
             f"Precision {enabled_precisions} not supported in the Dynamo Path"
         )
 
-    settings = CompilationSettings(
+    if kwargs.get("ir", "dynamo") == "torch_compile":
+        custom_backend = create_backend(
+            precision=lower_precision,
+            debug=debug,
+            workspace_size=workspace_size,
+            min_block_size=min_block_size,
+            torch_executed_ops=torch_executed_ops,
+            pass_through_build_failures=pass_through_build_failures,
+            max_aux_streams=max_aux_streams,
+            version_compatible=version_compatible,
+            optimization_level=optimization_level,
+            use_python_runtime=use_python_runtime,
+            **kwargs,
+        )
+        model = torch.compile(gm, backend=custom_backend)
+        # Ensure compilation occurs by calling the function with provided inputs
+        model(*inputs)
+        return model
+
+    else:
+        settings = CompilationSettings(
+            debug=debug,
+            precision=lower_precision,
+            workspace_size=workspace_size,
+            min_block_size=min_block_size,
+            torch_executed_ops=torch_executed_ops,
+            pass_through_build_failures=pass_through_build_failures,
+            max_aux_streams=max_aux_streams,
+            version_compatible=version_compatible,
+            optimization_level=optimization_level,
+            use_python_runtime=use_python_runtime,
+        )
+
+        model = trace(gm, inputs, **kwargs)
+
+        if kwargs.get("use_capability_partitioner", None):
+            model = lower_model(model, inputs)
+            return _compile_module(model, inputs, settings)
+        else:
+            split_result = lower_model_using_trt_splitter(model, inputs)
+            trt_module = _compile_graph(split_result, inputs, settings)
+
+            return trt_module
+
+
+def create_backend(
+    precision: LowerPrecision = PRECISION,
+    debug: bool = DEBUG,
+    workspace_size: int = WORKSPACE_SIZE,
+    min_block_size: int = MIN_BLOCK_SIZE,
+    torch_executed_ops: Sequence[str] = set(),
+    pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
+    max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
+    version_compatible: bool = VERSION_COMPATIBLE,
+    optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
+    use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME,
+    **kwargs,
+):
+    """Create torch.compile backend given specified arguments
+
+    Args:
+        precision: Model Layer precision
+        debug: Whether to print out verbose debugging information
+        workspace_size: Workspace TRT is allowed to use for the module (0 is default)
+        min_block_size: Minimum number of operators per TRT-Engine Block
+        torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
+        pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False)
+        max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
+        version_compatible: Provide version forward-compatibility for engine plan files
+        optimization_level: Builder optimization 0-5, higher levels imply longer build time,
+            searching for more optimization options. TRT defaults to 3
+        use_python_runtime: Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
+            based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
+            argument as None
+    Returns:
+        Backend for torch.compile
+    """
+    return partial(
+        torch_tensorrt_backend,
         debug=debug,
-        precision=lower_precision,
+        precision=precision,
         workspace_size=workspace_size,
         min_block_size=min_block_size,
         torch_executed_ops=torch_executed_ops,
-        pass_through_build_failures=False,
+        pass_through_build_failures=pass_through_build_failures,
         max_aux_streams=max_aux_streams,
         version_compatible=version_compatible,
         optimization_level=optimization_level,
         use_python_runtime=use_python_runtime,
+        **kwargs,
     )
 
-    model = trace(gm, inputs, **kwargs)
-
-    if kwargs.get("use_capability_partitioner", None):
-        model = lower_model(model, inputs)
-        return _compile_module(model, inputs, settings)
-    else:
-        split_result = lower_model_using_trt_splitter(model, inputs)
-        trt_module = _compile_graph(split_result, inputs, settings)
-
-        return trt_module
-
 
 def _compile_graph(
     split_result: TRTSplitter,
diff --git a/py/torch_tensorrt/dynamo/conversion/conversion.py b/py/torch_tensorrt/dynamo/conversion/conversion.py
index 2a3c4a4ebb..4144e922ac 100644
--- a/py/torch_tensorrt/dynamo/conversion/conversion.py
+++ b/py/torch_tensorrt/dynamo/conversion/conversion.py
@@ -1,7 +1,7 @@
 from typing import Sequence, Union
 import torch
 import io
-from torch_tensorrt.fx.trt_module import TRTModule
+from torch_tensorrt.dynamo.runtime import TRTModule
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt import Input
 from torch_tensorrt.dynamo.conversion import TRTInterpreter
@@ -60,7 +60,7 @@ def convert_module(
         )
 
     else:
-        from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule
+        from torch_tensorrt.dynamo.runtime import TorchTensorRTModule
 
         with io.BytesIO() as engine_bytes:
             engine_bytes.write(interpreter_result.engine.serialize())
diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTRTModule.py
new file mode 100644
index 0000000000..6f401ec991
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/runtime/_PythonTRTModule.py
@@ -0,0 +1,243 @@
+from typing import Any, List, Sequence
+
+# @manual=//deeplearning/trt/python:py_tensorrt
+import tensorrt as trt
+import torch
+from torch_tensorrt.fx.utils import unified_dtype_converter, Frameworks
+
+class TRTModule(torch.nn.Module):
+    def __init__(
+        self, engine=None, input_names=None, output_names=None, cuda_graph_batch_size=-1
+    ):
+        super(TRTModule, self).__init__()
+        self._register_state_dict_hook(TRTModule._on_state_dict)
+        self.engine = engine
+        self.input_names = input_names
+        self.output_names = output_names
+        self.cuda_graph_batch_size = cuda_graph_batch_size
+        self.initialized = False
+
+        if engine:
+            self._initialize()
+
+    def _initialize(self):
+        self.initialized = True
+        self.context = self.engine.create_execution_context()
+
+        # Indices of inputs/outputs in the trt engine bindings, in the order
+        # as they are in the original PyTorch model.
+        self.input_binding_indices_in_order: Sequence[int] = [
+            self.engine.get_binding_index(name) for name in self.input_names
+        ]
+        self.output_binding_indices_in_order: Sequence[int] = [
+            self.engine.get_binding_index(name) for name in self.output_names
+        ]
+        primary_input_outputs = set()
+        primary_input_outputs.update(self.input_binding_indices_in_order)
+        primary_input_outputs.update(self.output_binding_indices_in_order)
+        self.hidden_output_binding_indices_in_order: Sequence[int] = []
+        self.hidden_output_names: Sequence[str] = []
+        for i in range(
+            self.engine.num_bindings // self.engine.num_optimization_profiles
+        ):
+            if i not in primary_input_outputs:
+                self.hidden_output_binding_indices_in_order.append(i)
+                self.hidden_output_names.append(self.engine.get_binding_name(i))
+
+        assert (self.engine.num_bindings // self.engine.num_optimization_profiles) == (
+            len(self.input_names)
+            + len(self.output_names)
+            + len(self.hidden_output_names)
+        )
+
+        self.input_dtypes: Sequence[torch.dtype] = [
+            unified_dtype_converter(
+                self.engine.get_binding_dtype(idx), Frameworks.TORCH
+            )
+            for idx in self.input_binding_indices_in_order
+        ]
+        self.input_shapes: Sequence[Sequence[int]] = [
+            tuple(self.engine.get_binding_shape(idx))
+            for idx in self.input_binding_indices_in_order
+        ]
+        self.output_dtypes: Sequence[torch.dtype] = [
+            unified_dtype_converter(
+                self.engine.get_binding_dtype(idx), Frameworks.TORCH
+            )
+            for idx in self.output_binding_indices_in_order
+        ]
+        self.output_shapes = [
+            tuple(self.engine.get_binding_shape(idx))
+            if self.engine.has_implicit_batch_dimension
+            else tuple()
+            for idx in self.output_binding_indices_in_order
+        ]
+        self.hidden_output_dtypes: Sequence[torch.dtype] = [
+            unified_dtype_converter(
+                self.engine.get_binding_dtype(idx), Frameworks.TORCH
+            )
+            for idx in self.hidden_output_binding_indices_in_order
+        ]
+        self.hidden_output_shapes = [
+            tuple(self.engine.get_binding_shape(idx))
+            if self.engine.has_implicit_batch_dimension
+            else tuple()
+            for idx in self.hidden_output_binding_indices_in_order
+        ]
+
+    def _check_initialized(self):
+        if not self.initialized:
+            raise RuntimeError("TRTModule is not initialized.")
+
+    def _on_state_dict(self, state_dict, prefix, local_metadata):
+        self._check_initialized()
+        state_dict[prefix + "engine"] = bytearray(self.engine.serialize())
+        state_dict[prefix + "input_names"] = self.input_names
+        state_dict[prefix + "output_names"] = self.output_names
+        state_dict[prefix + "cuda_graph_batch_size"] = self.cuda_graph_batch_size
+
+    def _load_from_state_dict(
+        self,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+        engine_bytes = state_dict[prefix + "engine"]
+
+        logger = trt.Logger()
+        runtime = trt.Runtime(logger)
+        self.engine = runtime.deserialize_cuda_engine(engine_bytes)
+
+        self.input_names = state_dict[prefix + "input_names"]
+        self.output_names = state_dict[prefix + "output_names"]
+        self._initialize()
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["engine"] = bytearray(self.engine.serialize())
+        state.pop("context", None)
+        return state
+
+    def __setstate__(self, state):
+        logger = trt.Logger()
+        runtime = trt.Runtime(logger)
+        state["engine"] = runtime.deserialize_cuda_engine(state["engine"])
+        self.__dict__.update(state)
+        if self.engine:
+            self.context = self.engine.create_execution_context()
+
+    def forward(self, *inputs):
+        with torch.autograd.profiler.record_function("TRTModule:Forward"):
+            self._check_initialized()
+
+            with torch.autograd.profiler.record_function("TRTModule:ProcessInputs"):
+                assert len(inputs) == len(
+                    self.input_names
+                ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."
+
+                # This is only used when the trt engine is using implicit batch dim.
+                batch_size = inputs[0].shape[0]
+                contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
+                bindings: List[Any] = [None] * (
+                    len(self.input_names)
+                    + len(self.output_names)
+                    + len(self.hidden_output_names)
+                )
+
+                for i, input_name in enumerate(self.input_names):
+                    assert inputs[
+                        i
+                    ].is_cuda, f"{i}th input({input_name}) is not on cuda device."
+                    assert (
+                        inputs[i].dtype == self.input_dtypes[i]
+                    ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {inputs[i].dtype}."
+
+                    idx = self.input_binding_indices_in_order[i]
+                    bindings[idx] = contiguous_inputs[i].data_ptr()
+
+                    if not self.engine.has_implicit_batch_dimension:
+                        self.context.set_binding_shape(
+                            idx, tuple(contiguous_inputs[i].shape)
+                        )
+                    else:
+                        assert inputs[i].size()[1:] == self.input_shapes[i], (
+                            f"Shape mismatch for {i}th input({input_name}). "
+                            f"Expect {self.input_shapes[i]}, got {inputs[i].size()[1:]}."
+                        )
+
+            with torch.autograd.profiler.record_function("TRTModule:ProcessOutputs"):
+                # create output tensors
+                outputs: List[torch.Tensor] = []
+
+                for i, idx in enumerate(self.output_binding_indices_in_order):
+                    if self.engine.has_implicit_batch_dimension:
+                        shape = (batch_size,) + self.output_shapes[i]
+                    else:
+                        shape = tuple(self.context.get_binding_shape(idx))
+
+                    output = torch.empty(  # type: ignore[call-overload]
+                        size=shape,
+                        dtype=self.output_dtypes[i],
+                        device=torch.cuda.current_device(),
+                    )
+                    outputs.append(output)
+                    bindings[idx] = output.data_ptr()
+
+                for i, idx in enumerate(self.hidden_output_binding_indices_in_order):
+                    if self.engine.has_implicit_batch_dimension:
+                        shape = (batch_size,) + self.hidden_output_shapes[i]
+                    else:
+                        shape = tuple(self.context.get_binding_shape(idx))
+
+                    output = torch.empty(  # type: ignore[call-overload]
+                        size=shape,
+                        dtype=self.hidden_output_dtypes[i],
+                        device=torch.cuda.current_device(),
+                    )
+                    bindings[idx] = output.data_ptr()
+
+            with torch.autograd.profiler.record_function("TRTModule:TensorRTRuntime"):
+                if self.engine.has_implicit_batch_dimension:
+                    self.context.execute_async(
+                        batch_size, bindings, torch.cuda.current_stream().cuda_stream
+                    )
+                else:
+                    self.context.execute_async_v2(
+                        bindings, torch.cuda.current_stream().cuda_stream
+                    )
+
+            if len(outputs) == 1:
+                return outputs[0]
+
+            return tuple(outputs)
+
+    def enable_profiling(self, profiler: "trt.IProfiler" = None):
+        """
+        Enable TensorRT profiling. After calling this function, TensorRT will report
+        time spent on each layer in stdout for each forward run.
+        """
+        self._check_initialized()
+
+        if not self.context.profiler:
+            self.context.profiler = trt.Profiler() if profiler is None else profiler
+
+    def disable_profiling(self):
+        """
+        Disable TensorRT profiling.
+        """
+        self._check_initialized()
+
+        torch.cuda.synchronize()
+        del self.context
+        self.context = self.engine.create_execution_context()
+
+    def get_layer_info(self) -> str:
+        """
+        Get layer info of the engine. Only support for TRT > 8.2.
+        """
+        inspector = self.engine.create_engine_inspector()
+        return inspector.get_engine_information(trt.LayerInformationFormat.JSON)
diff --git a/py/torch_tensorrt/dynamo/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
similarity index 100%
rename from py/torch_tensorrt/dynamo/_TorchTensorRTModule.py
rename to py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
diff --git a/py/torch_tensorrt/dynamo/runtime/__init__.py b/py/torch_tensorrt/dynamo/runtime/__init__.py
new file mode 100644
index 0000000000..4f170115b1
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/runtime/__init__.py
@@ -0,0 +1,2 @@
+from ._PythonTRTModule import TRTModule
+from ._TorchTensorRTModule import TorchTensorRTModule
diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index dd309e2bf1..716aec79b9 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -27,7 +27,7 @@ def use_python_runtime_parser(use_python_runtime: Optional[bool] = None) -> bool
     # Runtime was not manually specified by the user, automatically detect runtime
     else:
         try:
-            from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule
+            from torch_tensorrt.dynamo.runtime import TorchTensorRTModule
 
             using_python_runtime = False
             reason = "since C++ dependency was detected as present"
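
Not part of the patch: a minimal usage sketch of the consolidated entry point above, assuming a CUDA-capable environment with torch, torchvision, and torch_tensorrt installed; the model choice (resnet18) and the input shape are illustrative only. With the `_compile.py` change, both `ir="dynamo"` and `ir="torch_compile"` route through `torch_tensorrt.dynamo.compile`, which dispatches on the `ir` keyword.

```python
import torch
import torchvision.models as models
import torch_tensorrt

model = models.resnet18().eval().cuda()
inputs = [torch.randn(1, 3, 224, 224).cuda()]

# Ahead-of-time path: traces, partitions, and converts the module up front.
trt_aot = torch_tensorrt.compile(
    model, ir="dynamo", inputs=inputs, enabled_precisions={torch.half}
)

# torch.compile path: builds the create_backend(...) partial and, per the
# model(*inputs) warm-up call in compile.py, triggers engine construction
# on the first invocation with the provided inputs.
trt_jit = torch_tensorrt.compile(
    model, ir="torch_compile", inputs=inputs, enabled_precisions={torch.half}
)

with torch.no_grad():
    print(trt_jit(*inputs).shape)
```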
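
Also not part of the patch: a sketch of the serialization round trip the state-dict hooks in `_PythonTRTModule.py` appear designed to support. Here `trt_mod` is assumed to be an engine-backed `TRTModule` instance produced by the Python-runtime path, and the file name is arbitrary.

```python
import torch
from torch_tensorrt.dynamo.runtime import TRTModule

# _on_state_dict embeds the serialized engine plus binding names in the state dict.
torch.save(trt_mod.state_dict(), "trt_module_state.pt")

# _load_from_state_dict deserializes the engine and re-runs _initialize().
restored = TRTModule()
restored.load_state_dict(torch.load("trt_module_state.pt"))
```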