From dc98f235b3f2f369bb87384c955893f679680947 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Mon, 3 Jun 2024 21:55:12 -0700 Subject: [PATCH 01/70] Implemented basic pipeline for Refitting --- py/torch_tensorrt/dynamo/_compiler.py | 124 ++++ .../conversion/_TRTRefittingInterpreter.py | 560 ++++++++++++++++++ .../dynamo/conversion/__init__.py | 2 +- .../dynamo/conversion/_conversion.py | 62 +- .../dynamo/refitting/refit_engine.py | 150 +++++ 5 files changed, 894 insertions(+), 4 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/conversion/_TRTRefittingInterpreter.py create mode 100644 py/torch_tensorrt/dynamo/refitting/refit_engine.py diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 32b0ca65d7..be9c57a50b 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -22,6 +22,7 @@ CompilationSettings, UnsupportedOperatorException, convert_module, + get_refit_mapping, interpret_module_to_result, repair_double_inputs, ) @@ -609,3 +610,126 @@ def convert_module_to_trt_engine( engine_bytearray = engine_bytes.getvalue() return engine_bytearray + + +def refit_trt_engine_from_module( + exported_program: ExportedProgram, + inputs: Tuple[Any, ...], + engine: object, + *, + enabled_precisions: ( + Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype] + ) = _defaults.ENABLED_PRECISIONS, + debug: bool = _defaults.DEBUG, + workspace_size: int = _defaults.WORKSPACE_SIZE, + min_block_size: int = _defaults.MIN_BLOCK_SIZE, + torch_executed_ops: Optional[Set[str]] = None, + pass_through_build_failures: bool = _defaults.PASS_THROUGH_BUILD_FAILURES, + max_aux_streams: Optional[int] = _defaults.MAX_AUX_STREAMS, + version_compatible: bool = _defaults.VERSION_COMPATIBLE, + optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL, + use_python_runtime: Optional[bool] = _defaults.USE_PYTHON_RUNTIME, + truncate_double: bool = _defaults.TRUNCATE_DOUBLE, + use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER, + enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS, + device: Device = Device._current_device(), + require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION, + disable_tf32: bool = _defaults.DISABLE_TF32, + sparse_weights: bool = _defaults.SPARSE_WEIGHTS, + refit: bool = _defaults.REFIT, + engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, + num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, + dla_sram_size: int = _defaults.DLA_SRAM_SIZE, + dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE, + dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE, + **kwargs: Any, +) -> None: + + if "truncate_long_and_double" in kwargs.keys(): + if truncate_double is not _defaults.TRUNCATE_DOUBLE: + raise ValueError( + 'Provided configuration for "truncate_double" and deprecated API "truncate_long_and_double", please only use "truncate_double"' + ) + else: + truncate_double = kwargs["truncate_long_and_double"] + warnings.warn( + 'Compiler option "truncate_long_and_double" is deprecated in favor of "truncate_double" as int64 is now natively supported, this option will be removed in the next version', + DeprecationWarning, + stacklevel=2, + ) + + input_list = list(inputs) if inputs is not None else [] + torch_executed_ops = torch_executed_ops if torch_executed_ops is not None else set() + # Prepare torch_trt inputs + input_list = prepare_inputs(input_list) + device = to_torch_tensorrt_device(device) + + enabled_precisions = {dtype._from(e) for 
e in enabled_precisions} + + compilation_options = { + "enabled_precisions": enabled_precisions, + "debug": debug, + "workspace_size": workspace_size, + "min_block_size": min_block_size, + "torch_executed_ops": torch_executed_ops, + "pass_through_build_failures": pass_through_build_failures, + "max_aux_streams": max_aux_streams, + "version_compatible": version_compatible, + "optimization_level": optimization_level, + "use_python_runtime": use_python_runtime, + "truncate_double": truncate_double, + "use_fast_partitioner": use_fast_partitioner, + "enable_experimental_decompositions": enable_experimental_decompositions, + "device": device, + "require_full_compilation": require_full_compilation, + "disable_tf32": disable_tf32, + "sparse_weights": sparse_weights, + "refit": refit, + "engine_capability": engine_capability, + "num_avg_timing_iters": num_avg_timing_iters, + "dla_sram_size": dla_sram_size, + "dla_local_dram_size": dla_local_dram_size, + "dla_global_dram_size": dla_global_dram_size, + } + + # Decompose the exported program + exported_program = exported_program.run_decompositions( + get_decompositions(enable_experimental_decompositions) + ) + gm = exported_program.module() + logger.debug("Input graph: " + str(gm.graph)) + + # Apply lowering on the graph module + torch_inputs = get_torch_inputs(input_list, device) + gm = apply_lowering_passes(gm, torch_inputs) + logger.debug("Lowered Input graph: " + str(gm.graph)) + + settings = CompilationSettings(**compilation_options) + logger.info("Compilation Settings: %s\n", settings) + + # Get the refitting mapping + import tensorrt as trt + + mapping = get_refit_mapping(gm, input_list, settings) + + TRT_LOGGER = trt.Logger(trt.Logger.ERROR) + trt_wt_location = trt.TensorLocation.HOST + + refitter = trt.Refitter(engine, TRT_LOGGER) + + weight_list = refitter.get_all_weights() + + for layer_name in weight_list: + if layer_name not in mapping: + print(f"{layer_name} is not found in weight mapping") + + # Use Numpy to create weights + weight = mapping[layer_name] + trt_wt_tensor = trt.Weights(trt.DataType.FLOAT, weight.ctypes.data, weight.size) + refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) + + if not refitter.refit_cuda_engine(): + print("Error: failed to refit new weights.") + exit(0) + + print("Refit Successful") diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTRefittingInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTRefittingInterpreter.py new file mode 100644 index 0000000000..624ce25903 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/_TRTRefittingInterpreter.py @@ -0,0 +1,560 @@ +import logging +import warnings +from datetime import datetime +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set + +import numpy as np +import tensorrt as trt +import torch +import torch.fx +from torch.fx.node import _get_qualified_name +from torch.fx.passes.shape_prop import TensorMetadata +from torch.utils._python_dispatch import _disable_current_modes +from torch_tensorrt._enums import dtype +from torch_tensorrt._Input import Input +from torch_tensorrt.dynamo import _defaults +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( + DYNAMO_CONVERTERS as CONVERTERS, +) +from torch_tensorrt.dynamo.conversion._ConverterRegistry import CallingConvention +from torch_tensorrt.dynamo.conversion.converter_utils import ( + 
get_node_name, + get_trt_tensor, +) +from torch_tensorrt.fx.observer import Observer +from torch_tensorrt.logging import TRT_LOGGER + +from packaging import version + +_LOGGER: logging.Logger = logging.getLogger(__name__) + +TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = ( + Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER") +) + + +class UnsupportedOperatorException(RuntimeError): + pass + + +class TRTInterpreterResult(NamedTuple): + engine: Any + input_names: Sequence[str] + output_names: Sequence[str] + serialized_cache: bytearray + + +class TRTInterpreter(torch.fx.Interpreter): # type: ignore[misc] + def __init__( + self, + module: torch.fx.GraphModule, + input_specs: Sequence[Input], + logger_level: trt.ILogger.Severity = trt.ILogger.Severity.WARNING, + output_dtypes: Optional[Sequence[dtype]] = None, + compilation_settings: CompilationSettings = CompilationSettings(), + ): + super().__init__(module) + + self.logger = TRT_LOGGER + self.builder = trt.Builder(self.logger) + flag = 0 + + # It is deprecated to not use this flag + EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + flag |= EXPLICIT_BATCH + + self.ctx = ConversionContext( + self.builder.create_network(flag), compilation_settings + ) + + assert TRTInterpreter._all_precisions_supported( + compilation_settings.enabled_precisions + ), f"Attempted to enable kernel precisions that are not supported (got: {compilation_settings.enabled_precisions}, support: {_defaults.SUPPORTED_KERNEL_PRECISIONS})" + missing_ops = self.validate_conversion() + if missing_ops: + warnings.warn( + "Interpretation will fail due to missing operations \n" + + "\n".join(f"{i}" for i in missing_ops) + ) + + self.optimization_profiles: Optional[List[trt.IOptimizationProfile]] = ( + [self.builder.create_optimization_profile()] + if any( + input_spec.shape_mode == Input._ShapeMode.DYNAMIC + for input_spec in input_specs + ) + else None + ) + + self.input_specs = input_specs + self.input_specs_iter = 0 + self._cur_node_name: Optional[str] = None + self._cur_node: Optional[torch.fx.Node] = None + self._input_names: List[str] = [] + self._output_names: List[str] = [] + self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = ( + dict() + ) + self.compilation_settings = compilation_settings + + # Data types for TRT Module output Tensors + self.output_dtypes = ( + [dtype._from(o) for o in output_dtypes] if output_dtypes else None + ) + + _LOGGER.debug(f"Graph to be compiled to TensorRT: {self.module.graph}") + + def validate_conversion(self) -> Set[str]: + missing_converters: Set[str] = set() + + for node in self.module.graph.nodes: + if node.op == "call_function" and CONVERTERS.get(node) is None: + missing_converters.add(f"{node.op} {_get_qualified_name(node.target)}") + elif node.op == "call_method" and CONVERTERS.get(node) is None: + missing_converters.add(f"{node.op} torch.Tensor.{node.target}") + elif node.op == "call_module": + submod = self.fetch_attr(node.target) + submod_type = getattr(submod, "_base_class_origin", type(submod)) + if CONVERTERS.get(node) is None: + missing_converters.add(f"{node.op} {torch.typename(submod_type)}") + + return missing_converters + + @staticmethod + def _args_str(args: List[Any]) -> str: + def clean_repr(x: Any, depth: int = 0) -> Any: + if isinstance(x, trt.ITensor): + return f"{x.name} <tensorrt.ITensor [shape={x.shape}, dtype={x.dtype}]>" + elif isinstance(x, torch.Tensor): + return f"<torch.Tensor [shape={x.shape}, dtype={x.dtype}]>" + elif isinstance(x, np.ndarray): + return ( + f"<torch.Tensor as np.ndarray [shape={x.shape}, dtype={x.dtype}]>" + ) + elif isinstance(x, Sequence) and not isinstance(x,
str): + if depth < 3: + return type(x)([clean_repr(i, depth=depth + 1) for i in x]) # type: ignore[call-arg] + else: + return "(...)" + else: + return x + + str_args = [clean_repr(a) for a in args] + return repr(tuple(str_args)) + + @staticmethod + def _all_precisions_supported(enabled_precisions: Set[dtype]) -> bool: + return enabled_precisions.issubset(_defaults.SUPPORTED_KERNEL_PRECISIONS) + + def validate_compile_settings(self) -> None: + if ( + dtype.i8 in self.compilation_settings.enabled_precisions + and not self.builder.platform_has_fast_int8 + ): + raise RuntimeError("Current platform doesn't support fast native int8!") + + if ( + dtype.f16 in self.compilation_settings.enabled_precisions + and not self.builder.platform_has_fast_fp16 + ): + warnings.warn("Current platform doesn't support fast native fp16!") + + def _populate_trt_builder_config( + self, + strict_type_constraints: bool = False, + algorithm_selector: Optional[trt.IAlgorithmSelector] = None, + tactic_sources: Optional[int] = None, + ) -> trt.IBuilderConfig: + + builder_config = self.builder.create_builder_config() + if self.compilation_settings.workspace_size != 0: + builder_config.set_memory_pool_limit( + trt.MemoryPoolType.WORKSPACE, self.compilation_settings.workspace_size + ) + + if version.parse(trt.__version__) >= version.parse("8.2"): + builder_config.profiling_verbosity = ( + trt.ProfilingVerbosity.DETAILED + if self.compilation_settings.debug + else trt.ProfilingVerbosity.LAYER_NAMES_ONLY + ) + + if version.parse(trt.__version__) >= version.parse("8.6"): + if self.compilation_settings.max_aux_streams is not None: + _LOGGER.info( + f"Setting max aux streams to {self.compilation_settings.max_aux_streams}" + ) + builder_config.max_aux_streams = ( + self.compilation_settings.max_aux_streams + ) + if self.compilation_settings.version_compatible: + _LOGGER.info("Using version compatible") + builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE) + builder_config.set_flag(trt.BuilderFlag.EXCLUDE_LEAN_RUNTIME) + if self.compilation_settings.hardware_compatible: + _LOGGER.info("Using hardware compatible") + builder_config.hardware_compatibility_level = ( + trt.HardwareCompatibilityLevel.AMPERE_PLUS + ) + if self.compilation_settings.optimization_level is not None: + _LOGGER.info( + f"Using optimization level {self.compilation_settings.optimization_level}" + ) + builder_config.builder_optimization_level = ( + self.compilation_settings.optimization_level + ) + + builder_config.engine_capability = ( + self.compilation_settings.engine_capability.to(trt.EngineCapability) + ) + builder_config.avg_timing_iterations = ( + self.compilation_settings.num_avg_timing_iters + ) + + if self.compilation_settings.device.device_type == trt.DeviceType.DLA: + device_info = torch.cuda.get_device_properties( + self.compilation_settings.device.gpu_id + ) + assert (device_info.major == 8 and device_info.minor == 7) or ( + device_info.major == 7 and device_info.minor == 2 + ), "DLA is not available on non AGX systems" + builder_config.DLA_core = self.compilation_settings.device.dla_core + _LOGGER.info(f"Using DLA core {self.compilation_settings.device.dla_core}") + builder_config.set_memory_pool_limit( + trt.MemoryPoolType.DLA_MANAGED_SRAM, + self.compilation_settings.dla_sram_size, + ) + builder_config.set_memory_pool_limit( + trt.MemoryPoolType.DLA_LOCAL_DRAM, + self.compilation_settings.dla_local_dram_size, + ) + builder_config.set_memory_pool_limit( + trt.MemoryPoolType.DLA_GLOBAL_DRAM, + self.compilation_settings.dla_global_dram_size, 
+ ) + + if dtype.float16 in self.compilation_settings.enabled_precisions: + builder_config.set_flag(trt.BuilderFlag.FP16) + + if dtype.int8 in self.compilation_settings.enabled_precisions: + builder_config.set_flag(trt.BuilderFlag.INT8) + + if dtype.bfloat16 in self.compilation_settings.enabled_precisions: + builder_config.set_flag(trt.BuilderFlag.BF16) + + if self.compilation_settings.sparse_weights: + builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS) + + if self.compilation_settings.disable_tf32: + builder_config.clear_flag(trt.BuilderFlag.TF32) + + if self.compilation_settings.refit: + builder_config.set_flag(trt.BuilderFlag.REFIT) + + if strict_type_constraints: + builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES) + + if self.optimization_profiles is not None: + if len(self.optimization_profiles) > 0: + for optimization_profile in self.optimization_profiles: + builder_config.add_optimization_profile(optimization_profile) + + if algorithm_selector: + builder_config.set_flag(trt.BuilderFlag.DISABLE_TIMING_CACHE) + builder_config.algorithm_selector = algorithm_selector + + if tactic_sources is not None: + builder_config.set_tactic_sources(tactic_sources=tactic_sources) + + return builder_config + + def _create_timing_cache( + self, + builder_config: trt.IBuilderConfig, + existing_cache: Optional[trt.ITimingCache] = None, + ) -> trt.ITimingCache: + cache = None + if existing_cache: + cache_file = np.array(existing_cache) + cache = builder_config.create_timing_cache(cache_file.tobytes()) + else: + cache = builder_config.create_timing_cache(b"") + builder_config.set_timing_cache(cache, False) + return cache + + def run( + self, + strict_type_constraints: bool = False, + algorithm_selector: Optional[trt.IAlgorithmSelector] = None, + existing_cache: Optional[trt.ITimingCache] = None, + tactic_sources: Optional[int] = None, + ) -> TRTInterpreterResult: + """ + Build TensorRT engine with some configs. + Args: + strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons. + algorithm_selector: set up algorithm selection for certain layer + existing_cache: enable timing cache for TensorRT + Return: + TRTInterpreterResult + """ + TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module) + + self.input_specs_iter = 0 + run_module_start_time = datetime.now() + super().run() + _LOGGER.info( + f"TRT INetwork construction elapsed time: {datetime.now() - run_module_start_time}" + ) + build_engine_start_time = datetime.now() + + builder_config = self._populate_trt_builder_config( + strict_type_constraints, algorithm_selector, tactic_sources + ) + timing_cache = self._create_timing_cache(builder_config, existing_cache) + + serialized_engine = self.builder.build_serialized_network( + self.ctx.net, builder_config + ) + assert serialized_engine + + serialized_cache = ( + bytearray(timing_cache.serialize()) + if builder_config.get_timing_cache() + else bytearray() + ) + _LOGGER.info( + f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}" + ) + _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory") + + return TRTInterpreterResult( + serialized_engine, self._input_names, self._output_names, serialized_cache + ) + + def get_network_to_refit( + self, + ) -> trt.INetworkDefinition: + """ + Build INetworkDefinition. + Args: + strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons. 
+ algorithm_selector: set up algorithm selection for certain layer + existing_cache: enable timing cache for TensorRT + Return: + TRTInterpreterResult + """ + + super().run() + return self.ctx.net + + def run_node(self, n: torch.fx.Node) -> torch.fx.Node: + self._cur_node_name = get_node_name(n) + self._cur_node = n + # add "_itensor_to_tensor_meta" + kwargs = dict(n.kwargs) + kwargs["_itensor_to_tensor_meta"] = self._itensor_to_tensor_meta + n.kwargs = kwargs + + # run the node + trt_node: torch.fx.Node = super().run_node(n) + + # remove "_itensor_to_tensor_meta" + kwargs = dict(n.kwargs) + del kwargs["_itensor_to_tensor_meta"] + n.kwargs = kwargs + + if isinstance(trt_node, trt.ITensor): + self._itensor_to_tensor_meta[trt_node] = n.meta.get("tensor_meta") + + return trt_node + + def placeholder(self, target: str, args: Any, kwargs: Any) -> trt.ITensor: + self._input_names.append(target) + current_input = self.input_specs[self.input_specs_iter] + self.input_specs_iter += 1 + # Set optimization profile for dynamic input shape + shape = None + if current_input.shape_mode == Input._ShapeMode.DYNAMIC: + assert isinstance(current_input.shape, dict) + shape = [] + min_shape = current_input.shape["min_shape"] + opt_shape = current_input.shape["opt_shape"] + max_shape = current_input.shape["max_shape"] + # TODO: Does not support disjoint optimization profiles? + assert self.optimization_profiles is not None + self.optimization_profiles[0].set_shape( + target, min_shape, opt_shape, max_shape + ) + + assert len(min_shape) == len(opt_shape) == len(max_shape) + for i in range(len(min_shape)): + if min_shape[i] == opt_shape[i] == max_shape[i]: + shape.append(min_shape[i]) + else: + # -1 to represent the dynamic dimension + shape.append(-1) + elif current_input.shape_mode == Input._ShapeMode.STATIC: + assert isinstance(current_input.shape, tuple) + shape = list(current_input.shape) + else: + raise RuntimeError( + f"Unable to access shape spec for input: {target} (got: {current_input})" + ) + + trt_input_dtype = current_input.dtype.to(trt.DataType, use_default=True) + _LOGGER.debug( + f"Adding input to in-progress INetwork: {target} [shape={shape}, dtype={trt_input_dtype}]" + ) + return self.ctx.net.add_input( + name=target, + shape=tuple(shape), + dtype=trt_input_dtype, + ) + + def call_module( + self, target: str, args: Any, kwargs: Any + ) -> Any: # Probably should be Tuple[trt.ITensor]? Case for Any? + assert isinstance(target, str) + submod = self.fetch_attr(target) + submod_type = getattr(submod, "_base_class_origin", type(submod)) + converter_packet = CONVERTERS.get(self._cur_node) + + if converter_packet is None: + raise UnsupportedOperatorException( + f"Conversion of module of type {submod_type} not currently supported!" + ) + + converter, calling_convention = converter_packet + + assert self._cur_node_name is not None + _LOGGER.debug( + f"Converting node {self._cur_node_name} (kind: {target}, args: {TRTInterpreter._args_str(args)})" + ) + if calling_convention is CallingConvention.LEGACY: + return converter(self.ctx.net, submod, args, kwargs, self._cur_node_name) + else: + return converter(self.ctx, submod, args, kwargs, self._cur_node_name) + + def call_function(self, target: str, args: Any, kwargs: Any) -> Any: + # TODO: Why is this stateful? We should be able to take in the inputs + converter_packet = CONVERTERS.get(self._cur_node) + if converter_packet is None: + raise UnsupportedOperatorException( + f"Conversion of function {torch.typename(target)} not currently supported!" 
+ ) + + converter, calling_convention = converter_packet + + assert self._cur_node_name is not None + _LOGGER.debug( + f"Converting node {self._cur_node_name} (kind: {target}, args: {TRTInterpreter._args_str(args)})" + ) + if calling_convention is CallingConvention.LEGACY: + return converter(self.ctx.net, target, args, kwargs, self._cur_node_name) + else: + return converter(self.ctx, target, args, kwargs, self._cur_node_name) + + def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray: + with _disable_current_modes(): + from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy + + frozen_attr = self.fetch_attr(target) + + if isinstance(frozen_attr, torch.nn.Parameter): + constant_tensor = frozen_attr.data + else: + constant_tensor = frozen_attr + + network_constant = to_numpy(constant_tensor) + + return network_constant + + def call_method(self, target: str, args: Any, kwargs: Any) -> Any: + assert isinstance(target, str) + converter_packet = CONVERTERS.get(self._cur_node) + + if converter_packet is None: + raise UnsupportedOperatorException( + f"Conversion of method {target} not currently supported!" + ) + converter, calling_convention = converter_packet + + assert self._cur_node_name is not None + _LOGGER.debug( + f"Converting node {self._cur_node_name} (kind: {target}, args: {TRTInterpreter._args_str(args)})" + ) + if calling_convention is CallingConvention.LEGACY: + return converter(self.ctx.net, target, args, kwargs, self._cur_node_name) + else: + return converter(self.ctx, target, args, kwargs, self._cur_node_name) + + def output(self, target: str, args: Any, kwargs: Any) -> List[Any]: + assert len(args) == 1 + if isinstance(args[0], tuple): + outputs = args[0] + elif isinstance(args[0], list): + outputs = tuple(args[0]) + else: + outputs = (args[0],) + + for output_idx in range(len(outputs)): + output = outputs[output_idx] + + if not isinstance(output, trt.ITensor): + new_output = get_trt_tensor(self.ctx, output, target) + outputs = ( + outputs[:output_idx] + (new_output,) + outputs[output_idx + 1 :] + ) + + if not all(isinstance(output, trt.ITensor) for output in outputs): + raise RuntimeError("TensorRT requires all outputs to be Tensor!") + + if self.output_dtypes is not None and len(self.output_dtypes) != len(outputs): + raise RuntimeError( + f"Specified output dtypes ({len(self.output_dtypes)}) differ from number of outputs ({len(outputs)})" + ) + + for i, output in enumerate(outputs): + name = f"output{i}" + + output_dtype = dtype.unknown + if any( + op_name in output.name.split("_") + for op_name in ( + "eq", + "gt", + "lt", + "or", + "xor", + "and", + "not", + "ne", + "isinf", + "isnan", + "any", + ) + ): + output_dtype = dtype.b + elif self.output_dtypes is not None: + if self.output_dtypes[i] == dtype.i64: + output = self.ctx.net.add_cast( + output, dtype.i64.to(trt.DataType) + ).get_output(0) + output_dtype = dtype.i64 + else: + output_dtype = self.output_dtypes[i] + + self.ctx.net.mark_output(output) + if output_dtype is not dtype.unknown: + output.dtype = output_dtype.to(trt.DataType, use_default=True) + output.name = name + + self._output_names.append(name) + _LOGGER.debug( + f"Marking output {name} [shape={output.shape}, dtype={output.dtype}]" + ) + + return list(outputs) diff --git a/py/torch_tensorrt/dynamo/conversion/__init__.py b/py/torch_tensorrt/dynamo/conversion/__init__.py index 5351f02bb6..c9e4d14486 100644 --- a/py/torch_tensorrt/dynamo/conversion/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/__init__.py @@ -1,5 +1,5 @@ from 
. import aten_ops_converters, ops_evaluators, prims_ops_converters -from ._conversion import convert_module, interpret_module_to_result +from ._conversion import convert_module, get_refit_mapping, interpret_module_to_result from ._ConversionContext import ConversionContext from ._ConverterRegistry import * # noqa: F403 from ._TRTInterpreter import * # noqa: F403 diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index ea078c7d64..b12228e09d 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -2,23 +2,24 @@ import io import logging +import warnings from typing import List, Sequence +import numpy as np +import tensorrt as trt import torch from torch_tensorrt._Device import Device from torch_tensorrt._enums import dtype from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt._Input import Input from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.conversion._TRTInterpreter import ( +from torch_tensorrt.dynamo.conversion._TRTRefittingInterpreter import ( TRTInterpreter, TRTInterpreterResult, ) from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule from torch_tensorrt.dynamo.utils import get_torch_inputs -import tensorrt as trt - logger = logging.getLogger(__name__) @@ -88,6 +89,61 @@ def interpret_module_to_result( return interpreter_result +def get_refit_mapping( + module: torch.fx.GraphModule, + inputs: Sequence[Input], + settings: CompilationSettings = CompilationSettings(), +) -> dict[str, np.ndarray]: + """Interpret an FX module and map refittable TRT engine weight names to their values + Args: + module: FX GraphModule to interpret + inputs: Sequence of Tensors representing inputs to the module + settings: Compilation settings + Returns: + A dict mapping TRT layer weight names to numpy weight arrays + """ + output_dtypes = infer_module_output_dtypes( + module, + inputs, + settings.device, + truncate_double=settings.truncate_double, + ) + + # TRT layer types with refittable weights, and the attributes that hold them + module_map = { + "SCALE": (trt.IScaleLayer, [("scale", "SCALE"), ("shift", "SHIFT")]), + "CONVOLUTION": ( + trt.IConvolutionLayer, + [("kernel", "KERNEL"), ("bias", "BIAS")], + ), + "CONSTANT": (trt.IConstantLayer, [("weights", "CONSTANT")]), + } + weight_map = {} + interpreter = TRTInterpreter( + module, + inputs, + logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING), + output_dtypes=output_dtypes, + compilation_settings=settings, + ) + + net = interpreter.get_network_to_refit() + for i in range(net.num_layers): + layer = net[i] + layer_type: str = layer.type.name + if layer_type in module_map: + layer.__class__ = module_map[layer_type][0] + for weight_type, weight_name in module_map[layer_type][1]: + weight_map[f"{layer.name} {weight_name}"] = layer.__getattribute__( + weight_type + ).copy() + + else: + warnings.warn(f"{layer_type} is not supported yet") + + return weight_map + + def convert_module( module: torch.fx.GraphModule, inputs: Sequence[Input], diff --git a/py/torch_tensorrt/dynamo/refitting/refit_engine.py b/py/torch_tensorrt/dynamo/refitting/refit_engine.py new file mode 100644 index 0000000000..576f65d3bf --- /dev/null +++ b/py/torch_tensorrt/dynamo/refitting/refit_engine.py @@ -0,0 +1,150 @@ +import numpy as np +import torch +import torchvision.models as models +from torch_tensorrt._Device import Device +from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import ( + PythonTorchTensorRTModule, +) + +np.random.seed(0)
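+# Demo sketch (assumes a CUDA device and torchvision are available): build a +# refittable TRT engine from a randomly initialized ResNet18, refit it in +# place with a pretrained ResNet18's weights, then compare the refitted +# engine's outputs against PyTorch.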
+torch.manual_seed(0) + + +# Example inputs; kept in a list so tuple(inputs) and *inputs below do the right thing +inputs = [torch.rand((1, 3, 224, 224)).to("cuda")] + + +# class net(nn.Module): +# def __init__(self): +# super().__init__() +# self.conv1 = nn.Conv2d(3, 12, 3, padding=1) +# self.bn = nn.BatchNorm2d(12) +# self.relu = nn.ReLU() + +# def forward(self, x): +# x = self.conv1(x) +# x = self.bn(x) +# x = self.relu(x) +# return x + + +# model = net().eval().to("cuda") +np.random.seed(1) +# model2 = net().eval().to("cuda") +model = models.resnet18(pretrained=False).eval().to("cuda") +model2 = models.resnet18(pretrained=True).eval().to("cuda") +enabled_precisions = {torch.float} +debug = True +workspace_size = 20 << 30 +min_block_size = 1 + + +exp_program = torch.export.export(model, tuple(inputs)) +exp_program2 = torch.export.export(model2, tuple(inputs)) + +from torch_tensorrt.dynamo._compiler import ( + convert_module_to_trt_engine, + refit_trt_engine_from_module, +) + +serialized_engine = convert_module_to_trt_engine( + exported_program=exp_program, + inputs=inputs, + enabled_precisions=enabled_precisions, + debug=debug, + min_block_size=min_block_size, + refit=True, +) + +trt_module = PythonTorchTensorRTModule( + engine=serialized_engine, + input_names=["x"], + output_names=["output0"], + target_device=Device._current_device(), + profiling_enabled=False, +) + +output = trt_module.forward(*inputs) +print(output[0].sum().cpu().item()) +engine = trt_module.engine +print(model(*inputs)[0].sum().cpu().item()) + +# ----------------------Refitting------------------------------------ +weights_to_be_fitted = model2.state_dict() +# refit_dict = { +# '[CONVOLUTION]-[aten_ops.convolution.default]-[/conv2/convolution_1] BIAS': weights_to_be_fitted['conv2.bias'] +# , +# '[CONVOLUTION]-[aten_ops.convolution.default]-[/conv2/convolution_1] KERNEL': weights_to_be_fitted['conv2.weight'] +# , +# '[CONVOLUTION]-[aten_ops.convolution.default]-[/conv1/convolution] BIAS': weights_to_be_fitted['conv1.bias'] +# , +# '[CONVOLUTION]-[aten_ops.convolution.default]-[/conv1/convolution] KERNEL': weights_to_be_fitted['conv1.weight'] +# } + + +# refit_dict = { +# '[CONVOLUTION]-[aten_ops.convolution.default]-[/conv1/convolution] BIAS': weights_to_be_fitted['conv1.bias'] +# , +# '[SCALE]-[aten_ops._native_batch_norm_legit_no_training.default]-[/bn/_native_batch_norm_legit_no_training] SCALE': weights_to_be_fitted['bn.weight'] +# , +# '[SCALE]-[aten_ops._native_batch_norm_legit_no_training.default]-[/bn/_native_batch_norm_legit_no_training] SHIFT': weights_to_be_fitted['bn.bias'] +# , +# '[CONVOLUTION]-[aten_ops.convolution.default]-[/conv1/convolution] KERNEL': weights_to_be_fitted['conv1.weight'] +# } + +# refit_dict = { +# '[CONVOLUTION]-[aten_ops.convolution.default]-[/conv1/convolution] BIAS': exp_program2.module().conv1.state_dict()['bias'] +# , +# '[SCALE]-[aten_ops._native_batch_norm_legit_no_training.default]-[/bn/_native_batch_norm_legit_no_training] SCALE': exp_program2.module().bn.state_dict()['weight'] +# , +# '[SCALE]-[aten_ops._native_batch_norm_legit_no_training.default]-[/bn/_native_batch_norm_legit_no_training] SHIFT': exp_program2.module().bn.state_dict()['bias'] +# , +# '[CONVOLUTION]-[aten_ops.convolution.default]-[/conv1/convolution] KERNEL': exp_program2.module().conv1.state_dict()['weight'] +# } + + +# trt_wt_location = trt.TensorLocation.DEVICE +# TRT_LOGGER = trt.Logger(trt.Logger.ERROR) + +# refitter = trt.Refitter(engine, TRT_LOGGER) + + +# for layer_name in refitter.get_all_weights(): +# v = refit_dict[layer_name] +# trt_wt_tensor = trt.Weights(trt.DataType.FLOAT,
v.data_ptr(), torch.numel(v)) +# refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) + + +# if not refitter.refit_cuda_engine(): +# print("Error: failed to refit new weights.") +# exit(0) + + +# output = trt_module.forward(*inputs) +# print(output[0].sum().cpu().item()) +# engine = trt_module.engine +# print(model2(*inputs)[0].sum().cpu().item()) +# print() + + +refit_trt_engine_from_module( + exported_program=exp_program2, + inputs=inputs, + engine=engine, + enabled_precisions=enabled_precisions, + debug=debug, + min_block_size=min_block_size, +) + +output = trt_module.forward(*inputs) +print(output[0].sum().cpu().item()) +engine = trt_module.engine +pytorch_output = model2(*inputs)[0] +print(pytorch_output.sum().cpu().item()) +print((output - pytorch_output).mean()) +print() + + +# Iterate over all layers and print weights +# for layer_name in refitter.get_all_weights(): +# # Print kernel weights +# print_layer_weights(refitter, layer_name) From 74d458ee12d66f96f8c9ffbe45fdbe50dc2c46d2 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Tue, 4 Jun 2024 13:34:22 -0700 Subject: [PATCH 02/70] Organized code for refitting --- py/torch_tensorrt/dynamo/_compiler.py | 124 ---- py/torch_tensorrt/dynamo/_refit.py | 200 +++++++ .../dynamo/conversion/_TRTInterpreter.py | 22 +- .../conversion/_TRTRefittingInterpreter.py | 560 ------------------ .../dynamo/conversion/__init__.py | 2 +- .../dynamo/refitting/refit_engine.py | 64 +- 6 files changed, 225 insertions(+), 747 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/_refit.py delete mode 100644 py/torch_tensorrt/dynamo/conversion/_TRTRefittingInterpreter.py diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index be9c57a50b..32b0ca65d7 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -22,7 +22,6 @@ CompilationSettings, UnsupportedOperatorException, convert_module, - get_refit_mapping, interpret_module_to_result, repair_double_inputs, ) @@ -610,126 +609,3 @@ def convert_module_to_trt_engine( engine_bytearray = engine_bytes.getvalue() return engine_bytearray - - -def refit_trt_engine_from_module( - exported_program: ExportedProgram, - inputs: Tuple[Any, ...], - engine: object, - *, - enabled_precisions: ( - Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype] - ) = _defaults.ENABLED_PRECISIONS, - debug: bool = _defaults.DEBUG, - workspace_size: int = _defaults.WORKSPACE_SIZE, - min_block_size: int = _defaults.MIN_BLOCK_SIZE, - torch_executed_ops: Optional[Set[str]] = None, - pass_through_build_failures: bool = _defaults.PASS_THROUGH_BUILD_FAILURES, - max_aux_streams: Optional[int] = _defaults.MAX_AUX_STREAMS, - version_compatible: bool = _defaults.VERSION_COMPATIBLE, - optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL, - use_python_runtime: Optional[bool] = _defaults.USE_PYTHON_RUNTIME, - truncate_double: bool = _defaults.TRUNCATE_DOUBLE, - use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER, - enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS, - device: Device = Device._current_device(), - require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION, - disable_tf32: bool = _defaults.DISABLE_TF32, - sparse_weights: bool = _defaults.SPARSE_WEIGHTS, - refit: bool = _defaults.REFIT, - engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, - num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, - dla_sram_size: int = _defaults.DLA_SRAM_SIZE, - 
dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE, - dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE, - **kwargs: Any, -) -> None: - - if "truncate_long_and_double" in kwargs.keys(): - if truncate_double is not _defaults.TRUNCATE_DOUBLE: - raise ValueError( - 'Provided configuration for "truncate_double" and deprecated API "truncate_long_and_double", please only use "truncate_double"' - ) - else: - truncate_double = kwargs["truncate_long_and_double"] - warnings.warn( - 'Compiler option "truncate_long_and_double" is deprecated in favor of "truncate_double" as int64 is now natively supported, this option will be removed in the next version', - DeprecationWarning, - stacklevel=2, - ) - - input_list = list(inputs) if inputs is not None else [] - torch_executed_ops = torch_executed_ops if torch_executed_ops is not None else set() - # Prepare torch_trt inputs - input_list = prepare_inputs(input_list) - device = to_torch_tensorrt_device(device) - - enabled_precisions = {dtype._from(e) for e in enabled_precisions} - - compilation_options = { - "enabled_precisions": enabled_precisions, - "debug": debug, - "workspace_size": workspace_size, - "min_block_size": min_block_size, - "torch_executed_ops": torch_executed_ops, - "pass_through_build_failures": pass_through_build_failures, - "max_aux_streams": max_aux_streams, - "version_compatible": version_compatible, - "optimization_level": optimization_level, - "use_python_runtime": use_python_runtime, - "truncate_double": truncate_double, - "use_fast_partitioner": use_fast_partitioner, - "enable_experimental_decompositions": enable_experimental_decompositions, - "device": device, - "require_full_compilation": require_full_compilation, - "disable_tf32": disable_tf32, - "sparse_weights": sparse_weights, - "refit": refit, - "engine_capability": engine_capability, - "num_avg_timing_iters": num_avg_timing_iters, - "dla_sram_size": dla_sram_size, - "dla_local_dram_size": dla_local_dram_size, - "dla_global_dram_size": dla_global_dram_size, - } - - # Decompose the exported program - exported_program = exported_program.run_decompositions( - get_decompositions(enable_experimental_decompositions) - ) - gm = exported_program.module() - logger.debug("Input graph: " + str(gm.graph)) - - # Apply lowering on the graph module - torch_inputs = get_torch_inputs(input_list, device) - gm = apply_lowering_passes(gm, torch_inputs) - logger.debug("Lowered Input graph: " + str(gm.graph)) - - settings = CompilationSettings(**compilation_options) - logger.info("Compilation Settings: %s\n", settings) - - # Get the refitting mapping - import tensorrt as trt - - mapping = get_refit_mapping(gm, input_list, settings) - - TRT_LOGGER = trt.Logger(trt.Logger.ERROR) - trt_wt_location = trt.TensorLocation.HOST - - refitter = trt.Refitter(engine, TRT_LOGGER) - - weight_list = refitter.get_all_weights() - - for layer_name in weight_list: - if layer_name not in mapping: - print(f"{layer_name} is not found in weight mapping") - - # Use Numpy to create weights - weight = mapping[layer_name] - trt_wt_tensor = trt.Weights(trt.DataType.FLOAT, weight.ctypes.data, weight.size) - refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) - - if not refitter.refit_cuda_engine(): - print("Error: failed to refit new weights.") - exit(0) - - print("Refit Successful") diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py new file mode 100644 index 0000000000..511e696804 --- /dev/null +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -0,0 +1,200 @@ +from 
__future__ import annotations + +import logging +import warnings +from typing import Any, Optional, Sequence, Set, Tuple + +import numpy as np +import tensorrt as trt +import torch +from torch.export import ExportedProgram +from torch_tensorrt._Device import Device +from torch_tensorrt._enums import EngineCapability, dtype +from torch_tensorrt._Input import Input +from torch_tensorrt.dynamo import _defaults +from torch_tensorrt.dynamo.conversion import CompilationSettings +from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes +from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter +from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions +from torch_tensorrt.dynamo.utils import ( + get_torch_inputs, + prepare_inputs, + to_torch_tensorrt_device, +) +from torch_tensorrt.logging import TRT_LOGGER + +logger = logging.getLogger(__name__) + + +def get_refit_mapping( + module: torch.fx.GraphModule, + inputs: Sequence[Input], + settings: CompilationSettings = CompilationSettings(), +) -> dict[str, np.ndarray]: + """Interpret an FX module and map refittable TRT engine weight names to their values + Args: + module: FX GraphModule to interpret + inputs: Sequence of Tensors representing inputs to the module + settings: Compilation settings + Returns: + A dict mapping TRT layer weight names to numpy weight arrays + """ + output_dtypes = infer_module_output_dtypes( + module, + inputs, + settings.device, + truncate_double=settings.truncate_double, + ) + + # TRT layer types with refittable weights, and the attributes that hold them + module_map = { + "SCALE": (trt.IScaleLayer, [("scale", "SCALE"), ("shift", "SHIFT")]), + "CONVOLUTION": ( + trt.IConvolutionLayer, + [("kernel", "KERNEL"), ("bias", "BIAS")], + ), + "CONSTANT": (trt.IConstantLayer, [("weights", "CONSTANT")]), + } + weight_map = {} + interpreter = TRTInterpreter( + module, + inputs, + logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING), + output_dtypes=output_dtypes, + compilation_settings=settings, + ) + interpreter._construct_trt_network_def() + net = interpreter.ctx.net + for i in range(net.num_layers): + layer = net[i] + layer_type: str = layer.type.name + if layer_type in module_map: + layer.__class__ = module_map[layer_type][0] + for weight_type, weight_name in module_map[layer_type][1]: + weight_map[f"{layer.name} {weight_name}"] = layer.__getattribute__( + weight_type + ).copy() + + else: + warnings.warn(f"{layer_type} is not supported yet") + + return weight_map + + +def refit_trt_engine_from_module( + exported_program: ExportedProgram, + inputs: Tuple[Any, ...], + engine: object, + *, + enabled_precisions: ( + Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype] + ) = _defaults.ENABLED_PRECISIONS, + debug: bool = _defaults.DEBUG, + workspace_size: int = _defaults.WORKSPACE_SIZE, + min_block_size: int = _defaults.MIN_BLOCK_SIZE, + torch_executed_ops: Optional[Set[str]] = None, + pass_through_build_failures: bool = _defaults.PASS_THROUGH_BUILD_FAILURES, + max_aux_streams: Optional[int] = _defaults.MAX_AUX_STREAMS, + version_compatible: bool = _defaults.VERSION_COMPATIBLE, + optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL, + use_python_runtime: Optional[bool] = _defaults.USE_PYTHON_RUNTIME, + truncate_double: bool = _defaults.TRUNCATE_DOUBLE, + use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER, + enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS, + device: Device = Device._current_device(), + require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION, + disable_tf32: bool =
_defaults.DISABLE_TF32, + sparse_weights: bool = _defaults.SPARSE_WEIGHTS, + refit: bool = _defaults.REFIT, + engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, + num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, + dla_sram_size: int = _defaults.DLA_SRAM_SIZE, + dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE, + dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE, + **kwargs: Any, +) -> None: + + if "truncate_long_and_double" in kwargs.keys(): + if truncate_double is not _defaults.TRUNCATE_DOUBLE: + raise ValueError( + 'Provided configuration for "truncate_double" and deprecated API "truncate_long_and_double", please only use "truncate_double"' + ) + else: + truncate_double = kwargs["truncate_long_and_double"] + warnings.warn( + 'Compiler option "truncate_long_and_double" is deprecated in favor of "truncate_double" as int64 is now natively supported, this option will be removed in the next version', + DeprecationWarning, + stacklevel=2, + ) + + input_list = list(inputs) if inputs is not None else [] + torch_executed_ops = torch_executed_ops if torch_executed_ops is not None else set() + # Prepare torch_trt inputs + input_list = prepare_inputs(input_list) + device = to_torch_tensorrt_device(device) + + enabled_precisions = {dtype._from(e) for e in enabled_precisions} + + # These settings should match the ones used when the engine was originally built + compilation_options = { + "enabled_precisions": enabled_precisions, + "debug": debug, + "workspace_size": workspace_size, + "min_block_size": min_block_size, + "torch_executed_ops": torch_executed_ops, + "pass_through_build_failures": pass_through_build_failures, + "max_aux_streams": max_aux_streams, + "version_compatible": version_compatible, + "optimization_level": optimization_level, + "use_python_runtime": use_python_runtime, + "truncate_double": truncate_double, + "use_fast_partitioner": use_fast_partitioner, + "enable_experimental_decompositions": enable_experimental_decompositions, + "device": device, + "require_full_compilation": require_full_compilation, + "disable_tf32": disable_tf32, + "sparse_weights": sparse_weights, + "refit": refit, + "engine_capability": engine_capability, + "num_avg_timing_iters": num_avg_timing_iters, + "dla_sram_size": dla_sram_size, + "dla_local_dram_size": dla_local_dram_size, + "dla_global_dram_size": dla_global_dram_size, + } + + # Decompose the exported program + exported_program = exported_program.run_decompositions( + get_decompositions(enable_experimental_decompositions) + ) + gm = exported_program.module() + logger.debug("Input graph: " + str(gm.graph)) + + # Apply lowering on the graph module + torch_inputs = get_torch_inputs(input_list, device) + gm = apply_lowering_passes(gm, torch_inputs) + logger.debug("Lowered Input graph: " + str(gm.graph)) + + settings = CompilationSettings(**compilation_options) + logger.info("Compilation Settings: %s\n", settings) + + # Get the refitting mapping + mapping = get_refit_mapping(gm, input_list, settings) + + trt_wt_location = trt.TensorLocation.HOST + refitter = trt.Refitter(engine, TRT_LOGGER) + weight_list = refitter.get_all_weights() + + for layer_name in weight_list: + if layer_name not in mapping: + print(f"{layer_name} is not found in weight mapping") + # Skip unmapped weights so the engine keeps its old values for them + continue + + # Use Numpy to create weights + weight = mapping[layer_name] + trt_wt_tensor = trt.Weights(trt.DataType.FLOAT, weight.ctypes.data, weight.size) + refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) + + if not refitter.refit_cuda_engine(): + print("Error: failed to refit new
weights.") + exit(0) + + print("Refit Successful") diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 422842f644..d3a2c00e13 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -285,6 +285,19 @@ def _create_timing_cache( builder_config.set_timing_cache(cache, False) return cache + def _construct_trt_network_def(self) -> None: + """ + Run the interpreter on each node to get TRT INetwork + """ + TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module) + + self.input_specs_iter = 0 + run_module_start_time = datetime.now() + super().run() + _LOGGER.info( + f"TRT INetwork construction elapsed time: {datetime.now() - run_module_start_time}" + ) + def run( self, strict_type_constraints: bool = False, @@ -301,14 +314,7 @@ def run( Return: TRTInterpreterResult """ - TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module) - - self.input_specs_iter = 0 - run_module_start_time = datetime.now() - super().run() - _LOGGER.info( - f"TRT INetwork construction elapsed time: {datetime.now() - run_module_start_time}" - ) + self._construct_trt_network_def() build_engine_start_time = datetime.now() builder_config = self._populate_trt_builder_config( diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTRefittingInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTRefittingInterpreter.py deleted file mode 100644 index 624ce25903..0000000000 --- a/py/torch_tensorrt/dynamo/conversion/_TRTRefittingInterpreter.py +++ /dev/null @@ -1,560 +0,0 @@ -import logging -import warnings -from datetime import datetime -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set - -import numpy as np -import tensorrt as trt -import torch -import torch.fx -from torch.fx.node import _get_qualified_name -from torch.fx.passes.shape_prop import TensorMetadata -from torch.utils._python_dispatch import _disable_current_modes -from torch_tensorrt._enums import dtype -from torch_tensorrt._Input import Input -from torch_tensorrt.dynamo import _defaults -from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext -from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( - DYNAMO_CONVERTERS as CONVERTERS, -) -from torch_tensorrt.dynamo.conversion._ConverterRegistry import CallingConvention -from torch_tensorrt.dynamo.conversion.converter_utils import ( - get_node_name, - get_trt_tensor, -) -from torch_tensorrt.fx.observer import Observer -from torch_tensorrt.logging import TRT_LOGGER - -from packaging import version - -_LOGGER: logging.Logger = logging.getLogger(__name__) - -TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = ( - Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER") -) - - -class UnsupportedOperatorException(RuntimeError): - pass - - -class TRTInterpreterResult(NamedTuple): - engine: Any - input_names: Sequence[str] - output_names: Sequence[str] - serialized_cache: bytearray - - -class TRTInterpreter(torch.fx.Interpreter): # type: ignore[misc] - def __init__( - self, - module: torch.fx.GraphModule, - input_specs: Sequence[Input], - logger_level: trt.ILogger.Severity = trt.ILogger.Severity.WARNING, - output_dtypes: Optional[Sequence[dtype]] = None, - compilation_settings: CompilationSettings = CompilationSettings(), - ): - super().__init__(module) - - self.logger = TRT_LOGGER - self.builder = trt.Builder(self.logger) - flag = 
0 - - # It is deprecated to not use this flag - EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) - flag |= EXPLICIT_BATCH - - self.ctx = ConversionContext( - self.builder.create_network(flag), compilation_settings - ) - - assert TRTInterpreter._all_precisions_supported( - compilation_settings.enabled_precisions - ), f"Attempted to enable kernel precisions that are not supported (got: {compilation_settings.enabled_precisions}, support: {_defaults.SUPPORTED_KERNEL_PRECISIONS})" - missing_ops = self.validate_conversion() - if missing_ops: - warnings.warn( - "Interpretation will fail due to missing operations \n" - + "\n".join(f"{i}" for i in missing_ops) - ) - - self.optimization_profiles: Optional[List[trt.IOptimizationProfile]] = ( - [self.builder.create_optimization_profile()] - if any( - input_spec.shape_mode == Input._ShapeMode.DYNAMIC - for input_spec in input_specs - ) - else None - ) - - self.input_specs = input_specs - self.input_specs_iter = 0 - self._cur_node_name: Optional[str] = None - self._cur_node: Optional[torch.fx.Node] = None - self._input_names: List[str] = [] - self._output_names: List[str] = [] - self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = ( - dict() - ) - self.compilation_settings = compilation_settings - - # Data types for TRT Module output Tensors - self.output_dtypes = ( - [dtype._from(o) for o in output_dtypes] if output_dtypes else None - ) - - _LOGGER.debug(f"Graph to be compiled to TensorRT: {self.module.graph}") - - def validate_conversion(self) -> Set[str]: - missing_converters: Set[str] = set() - - for node in self.module.graph.nodes: - if node.op == "call_function" and CONVERTERS.get(node) is None: - missing_converters.add(f"{node.op} {_get_qualified_name(node.target)}") - elif node.op == "call_method" and CONVERTERS.get(node) is None: - missing_converters.add(f"{node.op} torch.Tensor.{node.target}") - elif node.op == "call_module": - submod = self.fetch_attr(node.target) - submod_type = getattr(submod, "_base_class_origin", type(submod)) - if CONVERTERS.get(node) is None: - missing_converters.add(f"{node.op} {torch.typename(submod_type)}") - - return missing_converters - - @staticmethod - def _args_str(args: List[Any]) -> str: - def clean_repr(x: Any, depth: int = 0) -> Any: - if isinstance(x, trt.ITensor): - return f"{x.name} <tensorrt.ITensor [shape={x.shape}, dtype={x.dtype}]>" - elif isinstance(x, torch.Tensor): - return f"<torch.Tensor [shape={x.shape}, dtype={x.dtype}]>" - elif isinstance(x, np.ndarray): - return ( - f"<torch.Tensor as np.ndarray [shape={x.shape}, dtype={x.dtype}]>" - ) - elif isinstance(x, Sequence) and not isinstance(x, str): - if depth < 3: - return type(x)([clean_repr(i, depth=depth + 1) for i in x]) # type: ignore[call-arg] - else: - return "(...)" - else: - return x - - str_args = [clean_repr(a) for a in args] - return repr(tuple(str_args)) - - @staticmethod - def _all_precisions_supported(enabled_precisions: Set[dtype]) -> bool: - return enabled_precisions.issubset(_defaults.SUPPORTED_KERNEL_PRECISIONS) - - def validate_compile_settings(self) -> None: - if ( - dtype.i8 in self.compilation_settings.enabled_precisions - and not self.builder.platform_has_fast_int8 - ): - raise RuntimeError("Current platform doesn't support fast native int8!") - - if ( - dtype.f16 in self.compilation_settings.enabled_precisions - and not self.builder.platform_has_fast_fp16 - ): - warnings.warn("Current platform doesn't support fast native fp16!") - - def _populate_trt_builder_config( - self, - strict_type_constraints: bool = False, - algorithm_selector: Optional[trt.IAlgorithmSelector] = None, - tactic_sources: Optional[int] = None, - ) ->
trt.IBuilderConfig: - - builder_config = self.builder.create_builder_config() - if self.compilation_settings.workspace_size != 0: - builder_config.set_memory_pool_limit( - trt.MemoryPoolType.WORKSPACE, self.compilation_settings.workspace_size - ) - - if version.parse(trt.__version__) >= version.parse("8.2"): - builder_config.profiling_verbosity = ( - trt.ProfilingVerbosity.DETAILED - if self.compilation_settings.debug - else trt.ProfilingVerbosity.LAYER_NAMES_ONLY - ) - - if version.parse(trt.__version__) >= version.parse("8.6"): - if self.compilation_settings.max_aux_streams is not None: - _LOGGER.info( - f"Setting max aux streams to {self.compilation_settings.max_aux_streams}" - ) - builder_config.max_aux_streams = ( - self.compilation_settings.max_aux_streams - ) - if self.compilation_settings.version_compatible: - _LOGGER.info("Using version compatible") - builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE) - builder_config.set_flag(trt.BuilderFlag.EXCLUDE_LEAN_RUNTIME) - if self.compilation_settings.hardware_compatible: - _LOGGER.info("Using hardware compatible") - builder_config.hardware_compatibility_level = ( - trt.HardwareCompatibilityLevel.AMPERE_PLUS - ) - if self.compilation_settings.optimization_level is not None: - _LOGGER.info( - f"Using optimization level {self.compilation_settings.optimization_level}" - ) - builder_config.builder_optimization_level = ( - self.compilation_settings.optimization_level - ) - - builder_config.engine_capability = ( - self.compilation_settings.engine_capability.to(trt.EngineCapability) - ) - builder_config.avg_timing_iterations = ( - self.compilation_settings.num_avg_timing_iters - ) - - if self.compilation_settings.device.device_type == trt.DeviceType.DLA: - device_info = torch.cuda.get_device_properties( - self.compilation_settings.device.gpu_id - ) - assert (device_info.major == 8 and device_info.minor == 7) or ( - device_info.major == 7 and device_info.minor == 2 - ), "DLA is not available on non AGX systems" - builder_config.DLA_core = self.compilation_settings.device.dla_core - _LOGGER.info(f"Using DLA core {self.compilation_settings.device.dla_core}") - builder_config.set_memory_pool_limit( - trt.MemoryPoolType.DLA_MANAGED_SRAM, - self.compilation_settings.dla_sram_size, - ) - builder_config.set_memory_pool_limit( - trt.MemoryPoolType.DLA_LOCAL_DRAM, - self.compilation_settings.dla_local_dram_size, - ) - builder_config.set_memory_pool_limit( - trt.MemoryPoolType.DLA_GLOBAL_DRAM, - self.compilation_settings.dla_global_dram_size, - ) - - if dtype.float16 in self.compilation_settings.enabled_precisions: - builder_config.set_flag(trt.BuilderFlag.FP16) - - if dtype.int8 in self.compilation_settings.enabled_precisions: - builder_config.set_flag(trt.BuilderFlag.INT8) - - if dtype.bfloat16 in self.compilation_settings.enabled_precisions: - builder_config.set_flag(trt.BuilderFlag.BF16) - - if self.compilation_settings.sparse_weights: - builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS) - - if self.compilation_settings.disable_tf32: - builder_config.clear_flag(trt.BuilderFlag.TF32) - - if self.compilation_settings.refit: - builder_config.set_flag(trt.BuilderFlag.REFIT) - - if strict_type_constraints: - builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES) - - if self.optimization_profiles is not None: - if len(self.optimization_profiles) > 0: - for optimization_profile in self.optimization_profiles: - builder_config.add_optimization_profile(optimization_profile) - - if algorithm_selector: - 
builder_config.set_flag(trt.BuilderFlag.DISABLE_TIMING_CACHE) - builder_config.algorithm_selector = algorithm_selector - - if tactic_sources is not None: - builder_config.set_tactic_sources(tactic_sources=tactic_sources) - - return builder_config - - def _create_timing_cache( - self, - builder_config: trt.IBuilderConfig, - existing_cache: Optional[trt.ITimingCache] = None, - ) -> trt.ITimingCache: - cache = None - if existing_cache: - cache_file = np.array(existing_cache) - cache = builder_config.create_timing_cache(cache_file.tobytes()) - else: - cache = builder_config.create_timing_cache(b"") - builder_config.set_timing_cache(cache, False) - return cache - - def run( - self, - strict_type_constraints: bool = False, - algorithm_selector: Optional[trt.IAlgorithmSelector] = None, - existing_cache: Optional[trt.ITimingCache] = None, - tactic_sources: Optional[int] = None, - ) -> TRTInterpreterResult: - """ - Build TensorRT engine with some configs. - Args: - strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons. - algorithm_selector: set up algorithm selection for certain layer - existing_cache: enable timing cache for TensorRT - Return: - TRTInterpreterResult - """ - TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module) - - self.input_specs_iter = 0 - run_module_start_time = datetime.now() - super().run() - _LOGGER.info( - f"TRT INetwork construction elapsed time: {datetime.now() - run_module_start_time}" - ) - build_engine_start_time = datetime.now() - - builder_config = self._populate_trt_builder_config( - strict_type_constraints, algorithm_selector, tactic_sources - ) - timing_cache = self._create_timing_cache(builder_config, existing_cache) - - serialized_engine = self.builder.build_serialized_network( - self.ctx.net, builder_config - ) - assert serialized_engine - - serialized_cache = ( - bytearray(timing_cache.serialize()) - if builder_config.get_timing_cache() - else bytearray() - ) - _LOGGER.info( - f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}" - ) - _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory") - - return TRTInterpreterResult( - serialized_engine, self._input_names, self._output_names, serialized_cache - ) - - def get_network_to_refit( - self, - ) -> trt.INetworkDefinition: - """ - Build INetworkDefinition. - Args: - strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons. 
- algorithm_selector: set up algorithm selection for certain layer - existing_cache: enable timing cache for TensorRT - Return: - TRTInterpreterResult - """ - - super().run() - return self.ctx.net - - def run_node(self, n: torch.fx.Node) -> torch.fx.Node: - self._cur_node_name = get_node_name(n) - self._cur_node = n - # add "_itensor_to_tensor_meta" - kwargs = dict(n.kwargs) - kwargs["_itensor_to_tensor_meta"] = self._itensor_to_tensor_meta - n.kwargs = kwargs - - # run the node - trt_node: torch.fx.Node = super().run_node(n) - - # remove "_itensor_to_tensor_meta" - kwargs = dict(n.kwargs) - del kwargs["_itensor_to_tensor_meta"] - n.kwargs = kwargs - - if isinstance(trt_node, trt.ITensor): - self._itensor_to_tensor_meta[trt_node] = n.meta.get("tensor_meta") - - return trt_node - - def placeholder(self, target: str, args: Any, kwargs: Any) -> trt.ITensor: - self._input_names.append(target) - current_input = self.input_specs[self.input_specs_iter] - self.input_specs_iter += 1 - # Set optimization profile for dynamic input shape - shape = None - if current_input.shape_mode == Input._ShapeMode.DYNAMIC: - assert isinstance(current_input.shape, dict) - shape = [] - min_shape = current_input.shape["min_shape"] - opt_shape = current_input.shape["opt_shape"] - max_shape = current_input.shape["max_shape"] - # TODO: Does not support disjoint optimization profiles? - assert self.optimization_profiles is not None - self.optimization_profiles[0].set_shape( - target, min_shape, opt_shape, max_shape - ) - - assert len(min_shape) == len(opt_shape) == len(max_shape) - for i in range(len(min_shape)): - if min_shape[i] == opt_shape[i] == max_shape[i]: - shape.append(min_shape[i]) - else: - # -1 to represent the dynamic dimension - shape.append(-1) - elif current_input.shape_mode == Input._ShapeMode.STATIC: - assert isinstance(current_input.shape, tuple) - shape = list(current_input.shape) - else: - raise RuntimeError( - f"Unable to access shape spec for input: {target} (got: {current_input})" - ) - - trt_input_dtype = current_input.dtype.to(trt.DataType, use_default=True) - _LOGGER.debug( - f"Adding input to in-progress INetwork: {target} [shape={shape}, dtype={trt_input_dtype}]" - ) - return self.ctx.net.add_input( - name=target, - shape=tuple(shape), - dtype=trt_input_dtype, - ) - - def call_module( - self, target: str, args: Any, kwargs: Any - ) -> Any: # Probably should be Tuple[trt.ITensor]? Case for Any? - assert isinstance(target, str) - submod = self.fetch_attr(target) - submod_type = getattr(submod, "_base_class_origin", type(submod)) - converter_packet = CONVERTERS.get(self._cur_node) - - if converter_packet is None: - raise UnsupportedOperatorException( - f"Conversion of module of type {submod_type} not currently supported!" - ) - - converter, calling_convention = converter_packet - - assert self._cur_node_name is not None - _LOGGER.debug( - f"Converting node {self._cur_node_name} (kind: {target}, args: {TRTInterpreter._args_str(args)})" - ) - if calling_convention is CallingConvention.LEGACY: - return converter(self.ctx.net, submod, args, kwargs, self._cur_node_name) - else: - return converter(self.ctx, submod, args, kwargs, self._cur_node_name) - - def call_function(self, target: str, args: Any, kwargs: Any) -> Any: - # TODO: Why is this stateful? We should be able to take in the inputs - converter_packet = CONVERTERS.get(self._cur_node) - if converter_packet is None: - raise UnsupportedOperatorException( - f"Conversion of function {torch.typename(target)} not currently supported!" 
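The placeholder() handler above collapses any dimension that differs across min/opt/max to -1 and records the full range in an optimization profile. Sketched with the raw TensorRT API (shapes and the input name are illustrative, not taken from the patch):

```python
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)

# A batch dimension ranging over 1..16 becomes -1 in the network input...
x = network.add_input(name="x", shape=(-1, 3, 224, 224), dtype=trt.float32)

# ...while the concrete min/opt/max shapes live in the optimization profile.
profile = builder.create_optimization_profile()
profile.set_shape("x", (1, 3, 224, 224), (8, 3, 224, 224), (16, 3, 224, 224))
config = builder.create_builder_config()
config.add_optimization_profile(profile)
```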
- ) - - converter, calling_convention = converter_packet - - assert self._cur_node_name is not None - _LOGGER.debug( - f"Converting node {self._cur_node_name} (kind: {target}, args: {TRTInterpreter._args_str(args)})" - ) - if calling_convention is CallingConvention.LEGACY: - return converter(self.ctx.net, target, args, kwargs, self._cur_node_name) - else: - return converter(self.ctx, target, args, kwargs, self._cur_node_name) - - def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray: - with _disable_current_modes(): - from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy - - frozen_attr = self.fetch_attr(target) - - if isinstance(frozen_attr, torch.nn.Parameter): - constant_tensor = frozen_attr.data - else: - constant_tensor = frozen_attr - - network_constant = to_numpy(constant_tensor) - - return network_constant - - def call_method(self, target: str, args: Any, kwargs: Any) -> Any: - assert isinstance(target, str) - converter_packet = CONVERTERS.get(self._cur_node) - - if converter_packet is None: - raise UnsupportedOperatorException( - f"Conversion of method {target} not currently supported!" - ) - converter, calling_convention = converter_packet - - assert self._cur_node_name is not None - _LOGGER.debug( - f"Converting node {self._cur_node_name} (kind: {target}, args: {TRTInterpreter._args_str(args)})" - ) - if calling_convention is CallingConvention.LEGACY: - return converter(self.ctx.net, target, args, kwargs, self._cur_node_name) - else: - return converter(self.ctx, target, args, kwargs, self._cur_node_name) - - def output(self, target: str, args: Any, kwargs: Any) -> List[Any]: - assert len(args) == 1 - if isinstance(args[0], tuple): - outputs = args[0] - elif isinstance(args[0], list): - outputs = tuple(args[0]) - else: - outputs = (args[0],) - - for output_idx in range(len(outputs)): - output = outputs[output_idx] - - if not isinstance(output, trt.ITensor): - new_output = get_trt_tensor(self.ctx, output, target) - outputs = ( - outputs[:output_idx] + (new_output,) + outputs[output_idx + 1 :] - ) - - if not all(isinstance(output, trt.ITensor) for output in outputs): - raise RuntimeError("TensorRT requires all outputs to be Tensor!") - - if self.output_dtypes is not None and len(self.output_dtypes) != len(outputs): - raise RuntimeError( - f"Specified output dtypes ({len(self.output_dtypes)}) differ from number of outputs ({len(outputs)})" - ) - - for i, output in enumerate(outputs): - name = f"output{i}" - - output_dtype = dtype.unknown - if any( - op_name in output.name.split("_") - for op_name in ( - "eq", - "gt", - "lt", - "or", - "xor", - "and", - "not", - "ne", - "isinf", - "isnan", - "any", - ) - ): - output_dtype = dtype.b - elif self.output_dtypes is not None: - if self.output_dtypes[i] == dtype.i64: - output = self.ctx.net.add_cast( - output, dtype.i64.to(trt.DataType) - ).get_output(0) - output_dtype = dtype.i64 - else: - output_dtype = self.output_dtypes[i] - - self.ctx.net.mark_output(output) - if output_dtype is not dtype.unknown: - output.dtype = output_dtype.to(trt.DataType, use_default=True) - output.name = name - - self._output_names.append(name) - _LOGGER.debug( - f"Marking output {name} [shape={output.shape}, dtype={output.dtype}]" - ) - - return list(outputs) diff --git a/py/torch_tensorrt/dynamo/conversion/__init__.py b/py/torch_tensorrt/dynamo/conversion/__init__.py index c9e4d14486..5351f02bb6 100644 --- a/py/torch_tensorrt/dynamo/conversion/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/__init__.py @@ -1,5 +1,5 @@ from 
. import aten_ops_converters, ops_evaluators, prims_ops_converters -from ._conversion import convert_module, get_refit_mapping, interpret_module_to_result +from ._conversion import convert_module, interpret_module_to_result from ._ConversionContext import ConversionContext from ._ConverterRegistry import * # noqa: F403 from ._TRTInterpreter import * # noqa: F403 diff --git a/py/torch_tensorrt/dynamo/refitting/refit_engine.py b/py/torch_tensorrt/dynamo/refitting/refit_engine.py index 576f65d3bf..1df82e7e28 100644 --- a/py/torch_tensorrt/dynamo/refitting/refit_engine.py +++ b/py/torch_tensorrt/dynamo/refitting/refit_engine.py @@ -2,6 +2,8 @@ import torch import torchvision.models as models from torch_tensorrt._Device import Device +from torch_tensorrt.dynamo._compiler import convert_module_to_trt_engine +from torch_tensorrt.dynamo._refit import refit_trt_engine_from_module from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import ( PythonTorchTensorRTModule, ) @@ -10,7 +12,7 @@ torch.manual_seed(0) -inputs = torch.rand((1, 3, 224, 224)).to("cuda") +inputs = [torch.rand((1, 3, 224, 224)).to("cuda")] # class net(nn.Module): @@ -28,8 +30,9 @@ # model = net().eval().to("cuda") -np.random.seed(1) +# np.random.seed(1) # model2 = net().eval().to("cuda") + model = models.resnet18(pretrained=False).eval().to("cuda") model2 = models.resnet18(pretrained=True).eval().to("cuda") enabled_precisions = {torch.float} @@ -41,14 +44,10 @@ exp_program = torch.export.export(model, tuple(inputs)) exp_program2 = torch.export.export(model2, tuple(inputs)) -from torch_tensorrt.dynamo._compiler import ( - convert_module_to_trt_engine, - refit_trt_engine_from_module, -) serialized_engine = convert_module_to_trt_engine( exported_program=exp_program, - inputs=inputs, + inputs=tuple(inputs), enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, @@ -69,16 +68,7 @@ print(model(*inputs)[0].sum().cpu().item()) # ----------------------Refitting------------------------------------ -weights_to_be_fitted = model2.state_dict() -# refit_dict = { -# '[CONVOLUTION]-[aten_ops.convolution.default]-[/conv2/convolution_1] BIAS': weights_to_be_fitted['conv2.bias'] -# , -# '[CONVOLUTION]-[aten_ops.convolution.default]-[/conv2/convolution_1] KERNEL': weights_to_be_fitted['conv2.weight'] -# , -# '[CONVOLUTION]-[aten_ops.convolution.default]-[/conv1/convolution] BIAS': weights_to_be_fitted['conv1.bias'] -# , -# '[CONVOLUTION]-[aten_ops.convolution.default]-[/conv1/convolution] KERNEL': weights_to_be_fitted['conv1.weight'] -# } +# weights_to_be_fitted = model2.state_dict() # refit_dict = { @@ -91,45 +81,11 @@ # '[CONVOLUTION]-[aten_ops.convolution.default]-[/conv1/convolution] KERNEL': weights_to_be_fitted['conv1.weight'] # } -# refit_dict = { -# '[CONVOLUTION]-[aten_ops.convolution.default]-[/conv1/convolution] BIAS': exp_program2.module().conv1.state_dict()['bias'] -# , -# '[SCALE]-[aten_ops._native_batch_norm_legit_no_training.default]-[/bn/_native_batch_norm_legit_no_training] SCALE': exp_program2.module().bn.state_dict()['weight'] -# , -# '[SCALE]-[aten_ops._native_batch_norm_legit_no_training.default]-[/bn/_native_batch_norm_legit_no_training] SHIFT': exp_program2.module().bn.state_dict()['bias'] -# , -# '[CONVOLUTION]-[aten_ops.convolution.default]-[/conv1/convolution] KERNEL': exp_program2.module().conv1.state_dict()['weight'] -# } - - -# trt_wt_location = trt.TensorLocation.DEVICE -# TRT_LOGGER = trt.Logger(trt.Logger.ERROR) - -# refitter = trt.Refitter(engine, TRT_LOGGER) - - -# for layer_name 
in refitter.get_all_weights(): -# v = refit_dict[layer_name] -# trt_wt_tensor = trt.Weights(trt.DataType.FLOAT, v.data_ptr(), torch.numel(v)) -# refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) - - -# if not refitter.refit_cuda_engine(): -# print("Error: failed to refit new weights.") -# exit(0) - - -# output = trt_module.forward(*inputs) -# print(output[0].sum().cpu().item()) -# engine = trt_module.engine -# print(model2(*inputs)[0].sum().cpu().item()) -# print() - refit_trt_engine_from_module( - exported_program=exp_program2, - inputs=inputs, - engine=engine, + exported_program=exp_program2, # New + inputs=tuple(inputs), + engine=engine, # Old enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, From c47bef34263ddd942d52f339f061c176af1ca30f Mon Sep 17 00:00:00 2001 From: cehongwang Date: Tue, 4 Jun 2024 13:38:53 -0700 Subject: [PATCH 03/70] Renamed function --- py/torch_tensorrt/dynamo/_refit.py | 6 +++--- py/torch_tensorrt/dynamo/refitting/refit_engine.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 511e696804..bfa70621e2 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) -def get_refit_mapping( +def construct_refit_mapping( module: torch.fx.GraphModule, inputs: Sequence[Input], settings: CompilationSettings = CompilationSettings(), @@ -81,7 +81,7 @@ def get_refit_mapping( return weight_map -def refit_trt_engine_from_module( +def refit_single_trt_engine_with_ep( exported_program: ExportedProgram, inputs: Tuple[Any, ...], engine: object, @@ -178,7 +178,7 @@ def refit_trt_engine_from_module( logger.info("Compilation Settings: %s\n", settings) # Get the refitting mapping - mapping = get_refit_mapping(gm, input_list, settings) + mapping = construct_refit_mapping(gm, input_list, settings) trt_wt_location = trt.TensorLocation.HOST refitter = trt.Refitter(engine, TRT_LOGGER) diff --git a/py/torch_tensorrt/dynamo/refitting/refit_engine.py b/py/torch_tensorrt/dynamo/refitting/refit_engine.py index 1df82e7e28..b3b42b6e5b 100644 --- a/py/torch_tensorrt/dynamo/refitting/refit_engine.py +++ b/py/torch_tensorrt/dynamo/refitting/refit_engine.py @@ -3,7 +3,7 @@ import torchvision.models as models from torch_tensorrt._Device import Device from torch_tensorrt.dynamo._compiler import convert_module_to_trt_engine -from torch_tensorrt.dynamo._refit import refit_trt_engine_from_module +from torch_tensorrt.dynamo._refit import refit_single_trt_engine_with_ep from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import ( PythonTorchTensorRTModule, ) @@ -82,7 +82,7 @@ # } -refit_trt_engine_from_module( +refit_single_trt_engine_with_ep( exported_program=exp_program2, # New inputs=tuple(inputs), engine=engine, # Old From 869aaad93bdc491cee51e67cff92d7cb2f352733 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Tue, 4 Jun 2024 17:02:02 -0700 Subject: [PATCH 04/70] Supported multi-engine --- py/torch_tensorrt/dynamo/_refit.py | 251 +++++++++++++++++++++-------- 1 file changed, 184 insertions(+), 67 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index bfa70621e2..3da068f256 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -1,24 +1,32 @@ from __future__ import annotations +import collections.abc import logging import warnings -from typing import Any, Optional, Sequence, Set, 
Tuple +from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union import numpy as np import tensorrt as trt import torch from torch.export import ExportedProgram +from torch.fx.node import Target from torch_tensorrt._Device import Device from torch_tensorrt._enums import EngineCapability, dtype from torch_tensorrt._Input import Input -from torch_tensorrt.dynamo import _defaults +from torch_tensorrt.dynamo import _defaults, partitioning from torch_tensorrt.dynamo.conversion import CompilationSettings from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( + DYNAMO_CONVERTERS as CONVERTERS, +) from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter +from torch_tensorrt.dynamo.conversion.truncate_double import repair_double_inputs from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions from torch_tensorrt.dynamo.utils import ( get_torch_inputs, prepare_inputs, + set_log_level, + to_torch_device, to_torch_tensorrt_device, ) from torch_tensorrt.logging import TRT_LOGGER @@ -39,6 +47,15 @@ def construct_refit_mapping( Returns: TRTInterpreterResult """ + MODULE_MAP = { + "SCALE": (trt.IScaleLayer, [("scale", "SCALE"), ("shift", "SHIFT")]), + "CONVOLUTION": ( + trt.IConvolutionLayer, + [("kernel", "KERNEL"), ("bias", "BIAS")], + ), + "CONSTANT": (trt.IConstantLayer, [("weights", "CONSTANT")]), + } + output_dtypes = infer_module_output_dtypes( module, inputs, @@ -47,14 +64,6 @@ def construct_refit_mapping( ) # Use Interpreter - module_map = { - "SCALE": (trt.IScaleLayer, [("scale", "SCALE"), ("shift", "SHIFT")]), - "CONVOLUTION": ( - trt.IConvolutionLayer, - [("kernel", "KERNEL"), ("bias", "BIAS")], - ), - "CONSTANT": (trt.IConstantLayer, [("weights", "CONSTANT")]), - } weight_map = {} interpreter = TRTInterpreter( module, @@ -68,9 +77,9 @@ def construct_refit_mapping( for i in range(net.num_layers): layer = net[i] layer_type: str = layer.type.name - if layer_type in module_map: - layer.__class__ = module_map[layer_type][0] - for weight_type, weight_name in module_map[layer_type][1]: + if layer_type in MODULE_MAP: + layer.__class__ = MODULE_MAP[layer_type][0] + for weight_type, weight_name in MODULE_MAP[layer_type][1]: weight_map[f"{layer.name} {weight_name}"] = layer.__getattribute__( weight_type ).copy() @@ -81,38 +90,87 @@ def construct_refit_mapping( return weight_map -def refit_single_trt_engine_with_ep( - exported_program: ExportedProgram, +def refit_single_trt_engine_with_gm( + new_gm: torch.fx.GraphModule, + old_engine: trt.ICudaEngine, + input_list: Tuple[Any, ...], + settings: CompilationSettings = CompilationSettings(), +) -> None: + + # Get the refitting mapping + mapping = construct_refit_mapping(new_gm, input_list, settings) + + trt_wt_location = trt.TensorLocation.HOST + refitter = trt.Refitter(old_engine, TRT_LOGGER) + weight_list = refitter.get_all_weights() + + for layer_name in weight_list: + if layer_name not in mapping: + print(f"{layer_name} is not found in weight mapping") + + # Use Numpy to create weights + weight = mapping[layer_name] + trt_wt_tensor = trt.Weights(trt.DataType.FLOAT, weight.ctypes.data, weight.size) + refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) + + if not refitter.refit_cuda_engine(): + print("Error: failed to refit new weights.") + exit(0) + + print("Refit Successful") + + +# def refit_module_weights( +# compiled_module: ExportedProgram, +# new_weight_module: 
ExportedProgram +# ) -> torch.fx.GraphModule: +# pass + + +def refit_module_weights( + compiled_module: torch.fx.GraphModule | ExportedProgram, + new_weight_module: ExportedProgram, inputs: Tuple[Any, ...], - engine: object, *, + device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE, + disable_tf32: bool = _defaults.DISABLE_TF32, + sparse_weights: bool = _defaults.SPARSE_WEIGHTS, enabled_precisions: ( Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype] ) = _defaults.ENABLED_PRECISIONS, + engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, + refit: bool = _defaults.REFIT, debug: bool = _defaults.DEBUG, + num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, workspace_size: int = _defaults.WORKSPACE_SIZE, + dla_sram_size: int = _defaults.DLA_SRAM_SIZE, + dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE, + dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE, + truncate_double: bool = _defaults.TRUNCATE_DOUBLE, + require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION, min_block_size: int = _defaults.MIN_BLOCK_SIZE, - torch_executed_ops: Optional[Set[str]] = None, + torch_executed_ops: Optional[Collection[Target]] = None, + torch_executed_modules: Optional[List[str]] = None, pass_through_build_failures: bool = _defaults.PASS_THROUGH_BUILD_FAILURES, max_aux_streams: Optional[int] = _defaults.MAX_AUX_STREAMS, version_compatible: bool = _defaults.VERSION_COMPATIBLE, optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL, - use_python_runtime: Optional[bool] = _defaults.USE_PYTHON_RUNTIME, - truncate_double: bool = _defaults.TRUNCATE_DOUBLE, + use_python_runtime: bool = _defaults.USE_PYTHON_RUNTIME, use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER, enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS, - device: Device = Device._current_device(), - require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION, - disable_tf32: bool = _defaults.DISABLE_TF32, - sparse_weights: bool = _defaults.SPARSE_WEIGHTS, - refit: bool = _defaults.REFIT, - engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, - num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, - dla_sram_size: int = _defaults.DLA_SRAM_SIZE, - dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE, - dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE, + dryrun: bool = _defaults.DRYRUN, + hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, **kwargs: Any, -) -> None: +) -> torch.fx.GraphModule: + """ + Refit a compiled graph module with ExportedProgram + """ + + if debug: + set_log_level(logger.parent, logging.DEBUG) + + if type(compiled_module) == ExportedProgram: + compiled_module = compiled_module.module() if "truncate_long_and_double" in kwargs.keys(): if truncate_double is not _defaults.TRUNCATE_DOUBLE: @@ -127,21 +185,48 @@ def refit_single_trt_engine_with_ep( stacklevel=2, ) - input_list = list(inputs) if inputs is not None else [] - torch_executed_ops = torch_executed_ops if torch_executed_ops is not None else set() + engine_capability = EngineCapability._from(engine_capability) + + if torch_executed_modules is not None and torch_executed_modules: + logger.warning( + f"Detected torch_executed_modules was non-empty: {torch_executed_modules}" + "\nThis feature is unimplemented in Torch-TRT Dynamo currently." 
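Stripped of the mapping construction, the refit step this patch adds is the standard trt.Refitter loop. A self-contained sketch under the assumption that every refittable weight has a matching host-resident float32 array keyed by its TensorRT weight name:

```python
import numpy as np
import tensorrt as trt

def refit_from_arrays(engine: trt.ICudaEngine, arrays: dict) -> None:
    refitter = trt.Refitter(engine, trt.Logger(trt.Logger.ERROR))
    keepalive = []  # trt.Weights holds raw pointers; keep arrays alive until refit
    for name in refitter.get_all_weights():
        w = np.ascontiguousarray(arrays[name], dtype=np.float32)
        keepalive.append(w)
        refitter.set_named_weights(
            name, trt.Weights(trt.DataType.FLOAT, w.ctypes.data, w.size)
        )
    # One call validates and applies all pending weight updates.
    if not refitter.refit_cuda_engine():
        raise RuntimeError("TensorRT refit failed")
```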
+ ) + + if not isinstance(inputs, collections.abc.Sequence): + inputs = [inputs] + # Prepare torch_trt inputs - input_list = prepare_inputs(input_list) + inputs = prepare_inputs(inputs) device = to_torch_tensorrt_device(device) + enabled_precisions = {dtype._from(p) for p in enabled_precisions} - enabled_precisions = {dtype._from(e) for e in enabled_precisions} + if not isinstance(new_weight_module, ExportedProgram): + raise AssertionError( + f"Input graph should be an ExportedProgram but got type {type(new_weight_module)}" + ) + new_weight_module = new_weight_module.run_decompositions( + get_decompositions(enable_experimental_decompositions) + ) + gm = new_weight_module.module() + logger.debug("Input graph: " + str(gm.graph)) + # Apply lowering on the graph module + torch_inputs = get_torch_inputs(inputs, device) + gm = apply_lowering_passes(gm, torch_inputs) + + logger.debug("Lowered Input graph: " + str(gm.graph)) - # Try to use the old setting if available compilation_options = { - "enabled_precisions": enabled_precisions, + "enabled_precisions": ( + enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS + ), "debug": debug, + "device": device, "workspace_size": workspace_size, "min_block_size": min_block_size, - "torch_executed_ops": torch_executed_ops, + "torch_executed_ops": ( + torch_executed_ops if torch_executed_ops is not None else set() + ), "pass_through_build_failures": pass_through_build_failures, "max_aux_streams": max_aux_streams, "version_compatible": version_compatible, @@ -149,52 +234,84 @@ def refit_single_trt_engine_with_ep( "use_python_runtime": use_python_runtime, "truncate_double": truncate_double, "use_fast_partitioner": use_fast_partitioner, + "num_avg_timing_iters": num_avg_timing_iters, "enable_experimental_decompositions": enable_experimental_decompositions, - "device": device, "require_full_compilation": require_full_compilation, "disable_tf32": disable_tf32, "sparse_weights": sparse_weights, "refit": refit, "engine_capability": engine_capability, - "num_avg_timing_iters": num_avg_timing_iters, "dla_sram_size": dla_sram_size, "dla_local_dram_size": dla_local_dram_size, "dla_global_dram_size": dla_global_dram_size, + "dryrun": dryrun, + "hardware_compatible": hardware_compatible, } - # Decompose the exported program - exported_program = exported_program.run_decompositions( - get_decompositions(enable_experimental_decompositions) - ) - gm = exported_program.module() - logger.debug("Input graph: " + str(gm.graph)) - - # Apply lowering on the graph module - torch_inputs = get_torch_inputs(input_list, device) - gm = apply_lowering_passes(gm, torch_inputs) - logger.debug("Lowered Input graph: " + str(gm.graph)) - settings = CompilationSettings(**compilation_options) logger.info("Compilation Settings: %s\n", settings) - # Get the refitting mapping - mapping = construct_refit_mapping(gm, input_list, settings) + # Set torch-executed ops + CONVERTERS.set_disallowed_targets(settings.torch_executed_ops) - trt_wt_location = trt.TensorLocation.HOST - refitter = trt.Refitter(engine, TRT_LOGGER) - weight_list = refitter.get_all_weights() + # If specified, try using the fast partitioner and fall back to the global one on failure + if settings.use_fast_partitioner: + try: + partitioned_module, supported_ops = partitioning.fast_partition( + gm, + verbose=settings.debug, + min_block_size=settings.min_block_size, + torch_executed_ops=settings.torch_executed_ops, + ) + except torch.fx.passes.splitter_base.FxNetSplitterInternalError: + logger.error( + 
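Before any engine can be refitted, the new weights must pass through the same decomposition and lowering pipeline that produced the original partitions, otherwise submodule boundaries and layer names will not line up. A compressed sketch of that sequence with a toy model and default decompositions (the model and inputs here are placeholders):

```python
import torch
from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions

model = torch.nn.Linear(4, 4).eval().cuda()
sample_inputs = [torch.randn(2, 4).cuda()]

exported = torch.export.export(model, tuple(sample_inputs))
# Same decompositions as compilation, so both graphs partition identically.
exported = exported.run_decompositions(get_decompositions(False))
gm = apply_lowering_passes(exported.module(), sample_inputs)
```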
"Partitioning failed on the subgraph with fast partition. See trace above. " + + "Retrying with global partition.", + exc_info=True, + ) - for layer_name in weight_list: - if layer_name not in mapping: - print(f"{layer_name} is not found in weight mapping") + fast_partitioner_failed = True + settings.use_fast_partitioner = False - # Use Numpy to create weights - weight = mapping[layer_name] - trt_wt_tensor = trt.Weights(trt.DataType.FLOAT, weight.ctypes.data, weight.size) - refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) + if not settings.use_fast_partitioner: + partitioned_module, supported_ops = partitioning.global_partition( + gm, + verbose=settings.debug, + min_block_size=settings.min_block_size, + torch_executed_ops=settings.torch_executed_ops, + ) - if not refitter.refit_cuda_engine(): - print("Error: failed to refit new weights.") - exit(0) + # Iterate over all components that can be accelerated + # Generate the corresponding TRT Module for those + for name, _ in partitioned_module.named_children(): + new_submodule = getattr(partitioned_module, name) + compiled_submodule = getattr(compiled_module, name) + engine = compiled_submodule.engine - print("Refit Successful") + # Get the submodule inputs for min, opt, max shapes of the graph inputs + submodule_inputs = partitioning.construct_submodule_inputs(new_submodule) + + logger.debug( + "Refitting Submodule name: %s\n", + str(name), + ) + + assert submodule_inputs is not None + # Handle long/double inputs if requested by the user + if settings.truncate_double: + submodule_inputs = repair_double_inputs( + partitioned_module, + new_submodule, + submodule_inputs, + to_torch_device(settings.device), + name, + ) + + # Refit TRT engines from submodule in place + # TODO: Change it to return a new object + refit_single_trt_engine_with_gm( + new_gm=new_submodule, + old_engine=engine, + input_list=submodule_inputs, + settings=settings, + ) From e4cb6697b56c0c49759c64fb39cbac91b43f1acb Mon Sep 17 00:00:00 2001 From: cehongwang Date: Wed, 5 Jun 2024 10:59:49 -0700 Subject: [PATCH 05/70] Support both TRTModules and return a new copy --- py/torch_tensorrt/dynamo/_refit.py | 87 +++++++++++++++++++++++++++--- 1 file changed, 79 insertions(+), 8 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 3da068f256..015dcfc326 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -1,6 +1,7 @@ from __future__ import annotations import collections.abc +import copy import logging import warnings from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union @@ -22,6 +23,10 @@ from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter from torch_tensorrt.dynamo.conversion.truncate_double import repair_double_inputs from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions +from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import ( + PythonTorchTensorRTModule, +) +from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule from torch_tensorrt.dynamo.utils import ( get_torch_inputs, prepare_inputs, @@ -90,13 +95,15 @@ def construct_refit_mapping( return weight_map -def refit_single_trt_engine_with_gm( +def _refit_single_trt_engine_with_gm( new_gm: torch.fx.GraphModule, old_engine: trt.ICudaEngine, input_list: Tuple[Any, ...], settings: CompilationSettings = CompilationSettings(), ) -> None: - + """ + Refit a TensorRT Engine in place + """ # Get the refitting 
mapping mapping = construct_refit_mapping(new_gm, input_list, settings) @@ -169,6 +176,7 @@ def refit_module_weights( if debug: set_log_level(logger.parent, logging.DEBUG) + # TODO: Copy the submodule and return a new one if type(compiled_module) == ExportedProgram: compiled_module = compiled_module.module() @@ -280,14 +288,21 @@ def refit_module_weights( min_block_size=settings.min_block_size, torch_executed_ops=settings.torch_executed_ops, ) - + # PytorchTensorRTModule does not support deepcopy + # Create a shallow copy. Replace the TRTModule after + compiled_module = copy.copy(compiled_module) # Iterate over all components that can be accelerated # Generate the corresponding TRT Module for those for name, _ in partitioned_module.named_children(): new_submodule = getattr(partitioned_module, name) + # TODO: Copy the submodule and return a new one compiled_submodule = getattr(compiled_module, name) - engine = compiled_submodule.engine - + if isinstance(compiled_submodule, PythonTorchTensorRTModule): + engine = copy_cuda_engine(compiled_submodule.engine) + elif isinstance(compiled_submodule, TorchTensorRTModule): + engine = get_engine_from_TorchTensorRTModule(compiled_submodule) + else: + raise AssertionError("The type of graph module is not supported.") # Get the submodule inputs for min, opt, max shapes of the graph inputs submodule_inputs = partitioning.construct_submodule_inputs(new_submodule) @@ -307,11 +322,67 @@ def refit_module_weights( name, ) - # Refit TRT engines from submodule in place - # TODO: Change it to return a new object - refit_single_trt_engine_with_gm( + _refit_single_trt_engine_with_gm( new_gm=new_submodule, old_engine=engine, input_list=submodule_inputs, settings=settings, ) + + # In TorchTensorRTModule, the original module is intact. 
Create a new module and assign to the fx.Graph + if isinstance(compiled_submodule, TorchTensorRTModule): + refitteded_submodule = create_new_TorchTensorRTModule( + compiled_submodule, engine=engine, settings=settings + ) + else: + refitteded_submodule = create_new_PythonTorchTensorRTModule( + compiled_submodule, engine=engine, settings=settings + ) + setattr(compiled_module, name, refitteded_submodule) + return compiled_module + + +# Util functions ----------- +import base64 + + +def get_engine_from_TorchTensorRTModule(module: TorchTensorRTModule) -> trt.ICudaEngine: + engine_state = module.get_extra_state() + serialized_engine = base64.b64decode(engine_state[1][0][3]) + runtime = trt.Runtime(TRT_LOGGER) + engine = runtime.deserialize_cuda_engine(serialized_engine) + return engine + + +def create_new_TorchTensorRTModule( + module: TorchTensorRTModule, engine: trt.ICudaEngine, settings: object +) -> TorchTensorRTModule: + serialized_engine = engine.serialize() + return TorchTensorRTModule( + serialized_engine=bytes(serialized_engine), + name=module.name, + input_binding_names=module.input_binding_names, + output_binding_names=module.output_binding_names, + target_device=settings.device, + hardware_compatible=module.hardware_compatible, + ) + + +def create_new_PythonTorchTensorRTModule( + module: PythonTorchTensorRTModule, engine: trt.ICudaEngine, settings: object +) -> PythonTorchTensorRTModule: + serialized_engine = engine.serialize() + return PythonTorchTensorRTModule( + engine=bytes(serialized_engine), + input_names=module.input_names, + output_names=module.output_names, + target_device=settings.device, + profiling_enabled=module.profiling_enabled, + ) + + +def copy_cuda_engine(engine: trt.ICudaEngine) -> trt.ICudaEngine: + runtime = trt.Runtime(TRT_LOGGER) + serialized_engine = engine.serialize() + engine = runtime.deserialize_cuda_engine(serialized_engine) + return engine From 388dadcb828254fed2c8ac476e04d684b0a9f788 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Wed, 5 Jun 2024 16:01:31 -0700 Subject: [PATCH 06/70] Enabled module saving with settings --- py/torch_tensorrt/_compile.py | 11 ++-- py/torch_tensorrt/dynamo/_compiler.py | 26 +++++++- py/torch_tensorrt/dynamo/_defaults.py | 3 + py/torch_tensorrt/dynamo/_settings.py | 2 + .../dynamo/conversion/_conversion.py | 59 +------------------ 5 files changed, 37 insertions(+), 64 deletions(-) diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 8f3d6269f7..80da0a6d08 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -3,7 +3,7 @@ import collections.abc import logging from enum import Enum -from typing import Any, Callable, List, Optional, Sequence, Set +from typing import Any, Callable, Dict, List, Optional, Sequence, Set import torch import torch.fx @@ -351,9 +351,9 @@ def convert_method_to_trt_engine( torchtrt_inputs = prepare_inputs(inputs) exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs) - return dynamo_convert_module_to_trt_engine( # type: ignore[no-any-return] + return dynamo_convert_module_to_trt_engine( exp_program, - inputs=inputs, + inputs=tuple(inputs), enabled_precisions=enabled_precisions_set, **kwargs, ) @@ -398,6 +398,7 @@ def load(file_path: str = "") -> Any: def save( module: Any, file_path: str = "", + extra_files: Optional[Dict[str, Any]] = None, *, output_format: str = "exported_program", inputs: Optional[Sequence[torch.Tensor]] = None, @@ -459,7 +460,7 @@ def save( from torch_tensorrt.dynamo._exporter import export exp_program = 
export(module, inputs) - torch.export.save(exp_program, file_path) + torch.export.save(exp_program, file_path, extra_files=extra_files) else: from torch._higher_order_ops.torchbind import enable_torchbind_tracing @@ -467,4 +468,4 @@ def save( exp_program = torch.export.export( module, tuple(inputs), strict=False ) - torch.export.save(exp_program, file_path) + torch.export.save(exp_program, file_path, extra_files=extra_files) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 32b0ca65d7..db76aa43a8 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -1,11 +1,15 @@ from __future__ import annotations +import base64 import collections.abc import logging +import os +import pickle import warnings from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union import torch +import torch_tensorrt from torch.export import ExportedProgram from torch.fx.node import Target from torch_tensorrt._Device import Device @@ -73,6 +77,7 @@ def compile( enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS, dryrun: bool = _defaults.DRYRUN, hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, + module_save_path: str = _defaults.MODULE_SAVE_PATH, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -162,7 +167,7 @@ def compile( if not isinstance(inputs, collections.abc.Sequence): inputs = [inputs] - + raw_inputs = inputs # Prepare torch_trt inputs inputs = prepare_inputs(inputs) device = to_torch_tensorrt_device(device) @@ -213,11 +218,30 @@ def compile( "dla_global_dram_size": dla_global_dram_size, "dryrun": dryrun, "hardware_compatible": hardware_compatible, + "module_save_path": module_save_path, } settings = CompilationSettings(**compilation_options) logger.info("Compilation Settings: %s\n", settings) trt_gm = compile_module(gm, inputs, settings) + + if settings.refit and not settings.use_python_runtime: + logger.info( + f"Compiled graph module and setting files will be save to {settings.module_save_path}" + ) + save_path = ( + os.path.join(settings.module_save_path, "trt_compiled_module.ep") + if settings.module_save_path[-3:] != ".ep" + else settings.module_save_path + ) + + dumped = pickle.dumps(settings) + dumped_string = base64.b64encode(dumped).decode("utf-8") + extra_files = {"settings": dumped_string} + torch_tensorrt.save( + trt_gm, save_path, inputs=raw_inputs, extra_files=extra_files, retrace=False + ) + return trt_gm diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 7931dc865c..4fdbef29f6 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -1,3 +1,5 @@ +import os + import torch from torch_tensorrt._Device import Device from torch_tensorrt._enums import EngineCapability, dtype @@ -27,6 +29,7 @@ DRYRUN = False HARDWARE_COMPATIBLE = False SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.i8, dtype.bf16} +MODULE_SAVE_PATH = os.path.abspath(os.getcwd()) def default_device() -> Device: diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 9592bc1fd5..7efe899b7b 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -17,6 +17,7 @@ HARDWARE_COMPATIBLE, MAX_AUX_STREAMS, MIN_BLOCK_SIZE, + MODULE_SAVE_PATH, NUM_AVG_TIMING_ITERS, OPTIMIZATION_LEVEL, PASS_THROUGH_BUILD_FAILURES, @@ -98,3 +99,4 @@ class 
CompilationSettings: dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE dryrun: Union[bool, str] = DRYRUN hardware_compatible: bool = HARDWARE_COMPATIBLE + module_save_path: str = MODULE_SAVE_PATH diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index b12228e09d..5646a63108 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -2,10 +2,8 @@ import io import logging -import warnings from typing import List, Sequence -import numpy as np import tensorrt as trt import torch from torch_tensorrt._Device import Device @@ -13,7 +11,7 @@ from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt._Input import Input from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.conversion._TRTRefittingInterpreter import ( +from torch_tensorrt.dynamo.conversion._TRTInterpreter import ( TRTInterpreter, TRTInterpreterResult, ) @@ -89,61 +87,6 @@ def interpret_module_to_result( return interpreter_result -def get_refit_mapping( - module: torch.fx.GraphModule, - inputs: Sequence[Input], - settings: CompilationSettings = CompilationSettings(), -) -> dict[str, np.ndarray]: - """Interpret an FX module to a TRTInterpreterResult - Args: - module: FX GraphModule to interpret - inputs: Sequence of Tensors representing inputs to the module - settings: Compilation settings - Returns: - TRTInterpreterResult - """ - output_dtypes = infer_module_output_dtypes( - module, - inputs, - settings.device, - truncate_double=settings.truncate_double, - ) - - # Use Interpreter - module_map = { - "SCALE": (trt.IScaleLayer, [("scale", "SCALE"), ("shift", "SHIFT")]), - "CONVOLUTION": ( - trt.IConvolutionLayer, - [("kernel", "KERNEL"), ("bias", "BIAS")], - ), - "CONSTANT": (trt.IConstantLayer, [("weights", "CONSTANT")]), - } - weight_map = {} - interpreter = TRTInterpreter( - module, - inputs, - logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING), - output_dtypes=output_dtypes, - compilation_settings=settings, - ) - - net = interpreter.get_network_to_refit() - for i in range(net.num_layers): - layer = net[i] - layer_type: str = layer.type.name - if layer_type in module_map: - layer.__class__ = module_map[layer_type][0] - for weight_type, weight_name in module_map[layer_type][1]: - weight_map[f"{layer.name} {weight_name}"] = layer.__getattribute__( - weight_type - ).copy() - - else: - warnings.warn(f"{layer_type} is not supported yet") - - return weight_map - - def convert_module( module: torch.fx.GraphModule, inputs: Sequence[Input], From f822b2810e8f333967fd2af6e8e17bac9110b4b7 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Wed, 5 Jun 2024 16:10:02 -0700 Subject: [PATCH 07/70] Enabled three types of runtime. Build an interface for user to easy refit. 
Support setting loading --- py/torch_tensorrt/dynamo/_refit.py | 215 +++++++----------- .../refitting/refit_engine_multi_subgraph.py | 112 +++++++++ 2 files changed, 190 insertions(+), 137 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/refitting/refit_engine_multi_subgraph.py diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 015dcfc326..9f79c54bcb 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -3,18 +3,16 @@ import collections.abc import copy import logging +import pickle import warnings -from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union +from typing import Any, Sequence, Tuple import numpy as np import tensorrt as trt import torch from torch.export import ExportedProgram -from torch.fx.node import Target -from torch_tensorrt._Device import Device -from torch_tensorrt._enums import EngineCapability, dtype from torch_tensorrt._Input import Input -from torch_tensorrt.dynamo import _defaults, partitioning +from torch_tensorrt.dynamo import partitioning from torch_tensorrt.dynamo.conversion import CompilationSettings from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( @@ -28,6 +26,7 @@ ) from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule from torch_tensorrt.dynamo.utils import ( + copy_cuda_engine, get_torch_inputs, prepare_inputs, set_log_level, @@ -127,94 +126,59 @@ def _refit_single_trt_engine_with_gm( print("Refit Successful") -# def refit_module_weights( -# compiled_module: ExportedProgram, -# new_weight_module: ExportedProgram -# ) -> torch.fx.GraphModule: -# pass +def refit_module_weights( + compiled_module_file_path: str, new_weight_module: ExportedProgram, inputs: Any +) -> torch.fx.GraphModule: + """ + Return a copy of compiled_module with refitted weight + """ + settings_wrapper = {"settings": None} + compiled_exp_program = torch.export.load( + compiled_module_file_path, extra_files=settings_wrapper + ) -def refit_module_weights( + decoded_settings = base64.b64decode(settings_wrapper["settings"].encode("utf-8")) + restored_settings = pickle.loads(decoded_settings) + + new_trt_gm = _refit_module_weights( + compiled_module=compiled_exp_program, + new_weight_module=new_weight_module, + inputs=inputs, + settings=restored_settings, + ) + return new_trt_gm + + +def _refit_module_weights( compiled_module: torch.fx.GraphModule | ExportedProgram, new_weight_module: ExportedProgram, inputs: Tuple[Any, ...], - *, - device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE, - disable_tf32: bool = _defaults.DISABLE_TF32, - sparse_weights: bool = _defaults.SPARSE_WEIGHTS, - enabled_precisions: ( - Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype] - ) = _defaults.ENABLED_PRECISIONS, - engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, - refit: bool = _defaults.REFIT, - debug: bool = _defaults.DEBUG, - num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, - workspace_size: int = _defaults.WORKSPACE_SIZE, - dla_sram_size: int = _defaults.DLA_SRAM_SIZE, - dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE, - dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE, - truncate_double: bool = _defaults.TRUNCATE_DOUBLE, - require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION, - min_block_size: int = _defaults.MIN_BLOCK_SIZE, - torch_executed_ops: Optional[Collection[Target]] = None, - 
torch_executed_modules: Optional[List[str]] = None, - pass_through_build_failures: bool = _defaults.PASS_THROUGH_BUILD_FAILURES, - max_aux_streams: Optional[int] = _defaults.MAX_AUX_STREAMS, - version_compatible: bool = _defaults.VERSION_COMPATIBLE, - optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL, - use_python_runtime: bool = _defaults.USE_PYTHON_RUNTIME, - use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER, - enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS, - dryrun: bool = _defaults.DRYRUN, - hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, - **kwargs: Any, + settings: Any, ) -> torch.fx.GraphModule: """ Refit a compiled graph module with ExportedProgram """ - if debug: + if settings.debug: set_log_level(logger.parent, logging.DEBUG) - # TODO: Copy the submodule and return a new one - if type(compiled_module) == ExportedProgram: - compiled_module = compiled_module.module() - - if "truncate_long_and_double" in kwargs.keys(): - if truncate_double is not _defaults.TRUNCATE_DOUBLE: - raise ValueError( - 'Provided configuration for "truncate_double" and deprecated API "truncate_long_and_double", please only use "truncate_double"' - ) - else: - truncate_double = kwargs["truncate_long_and_double"] - warnings.warn( - 'Compiler option "truncate_long_and_double" is deprecated in favor of "truncate_double" as int64 is now natively supported, this option will be removed in the next version', - DeprecationWarning, - stacklevel=2, - ) - - engine_capability = EngineCapability._from(engine_capability) - - if torch_executed_modules is not None and torch_executed_modules: - logger.warning( - f"Detected torch_executed_modules was non-empty: {torch_executed_modules}" - "\nThis feature is unimplemented in Torch-TRT Dynamo currently." 
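For comparison, the user-facing call this patch is working toward (mirroring the demo script later in the series) needs only the saved module path, an ExportedProgram holding the new weights, and sample inputs, since the compilation settings travel inside the saved file; the path below is a placeholder:

```python
import torch
import torchvision.models as models
from torch_tensorrt.dynamo._refit import refit_module_weights

inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]
model2 = models.resnet18(pretrained=True).eval().to("cuda")
exp_program2 = torch.export.export(model2, tuple(inputs))

# No compile options recur here; they are recovered from the saved .ep file.
new_trt_gm = refit_module_weights(
    "trt_compiled_module.ep",  # placeholder path to the saved compiled module
    new_weight_module=exp_program2,
    inputs=inputs,
)
```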
- ) - if not isinstance(inputs, collections.abc.Sequence): inputs = [inputs] + if isinstance(compiled_module, ExportedProgram): + compiled_module = compiled_module.module() + # Prepare torch_trt inputs inputs = prepare_inputs(inputs) - device = to_torch_tensorrt_device(device) - enabled_precisions = {dtype._from(p) for p in enabled_precisions} + device = to_torch_tensorrt_device(settings.device) if not isinstance(new_weight_module, ExportedProgram): raise AssertionError( f"Input graph should be an ExportedProgram but got type {type(new_weight_module)}" ) new_weight_module = new_weight_module.run_decompositions( - get_decompositions(enable_experimental_decompositions) + get_decompositions(settings.enable_experimental_decompositions) ) gm = new_weight_module.module() logger.debug("Input graph: " + str(gm.graph)) @@ -222,41 +186,6 @@ def refit_module_weights( torch_inputs = get_torch_inputs(inputs, device) gm = apply_lowering_passes(gm, torch_inputs) - logger.debug("Lowered Input graph: " + str(gm.graph)) - - compilation_options = { - "enabled_precisions": ( - enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS - ), - "debug": debug, - "device": device, - "workspace_size": workspace_size, - "min_block_size": min_block_size, - "torch_executed_ops": ( - torch_executed_ops if torch_executed_ops is not None else set() - ), - "pass_through_build_failures": pass_through_build_failures, - "max_aux_streams": max_aux_streams, - "version_compatible": version_compatible, - "optimization_level": optimization_level, - "use_python_runtime": use_python_runtime, - "truncate_double": truncate_double, - "use_fast_partitioner": use_fast_partitioner, - "num_avg_timing_iters": num_avg_timing_iters, - "enable_experimental_decompositions": enable_experimental_decompositions, - "require_full_compilation": require_full_compilation, - "disable_tf32": disable_tf32, - "sparse_weights": sparse_weights, - "refit": refit, - "engine_capability": engine_capability, - "dla_sram_size": dla_sram_size, - "dla_local_dram_size": dla_local_dram_size, - "dla_global_dram_size": dla_global_dram_size, - "dryrun": dryrun, - "hardware_compatible": hardware_compatible, - } - - settings = CompilationSettings(**compilation_options) logger.info("Compilation Settings: %s\n", settings) # Set torch-executed ops @@ -296,13 +225,23 @@ def refit_module_weights( for name, _ in partitioned_module.named_children(): new_submodule = getattr(partitioned_module, name) # TODO: Copy the submodule and return a new one - compiled_submodule = getattr(compiled_module, name) - if isinstance(compiled_submodule, PythonTorchTensorRTModule): - engine = copy_cuda_engine(compiled_submodule.engine) - elif isinstance(compiled_submodule, TorchTensorRTModule): - engine = get_engine_from_TorchTensorRTModule(compiled_submodule) - else: - raise AssertionError("The type of graph module is not supported.") + inline_module = False + try: + compiled_submodule = getattr(compiled_module, name) + if isinstance(compiled_submodule, PythonTorchTensorRTModule): + engine = copy_cuda_engine(compiled_submodule.engine) + elif isinstance(compiled_submodule, TorchTensorRTModule): + engine_state = compiled_submodule.get_extra_state() + encoded_engine = engine_state[1][0][3] + engine = get_engine_from_encoded_engine(encoded_engine) + else: + raise AssertionError("The type of graph module is not supported.") + except AttributeError: + inline_module = True + inline_engine = getattr(compiled_module, f"{name}_engine") + engine_info = inline_engine.__getstate__()[0] + engine 
= get_engine_from_encoded_engine(engine_info[3]) + # Get the submodule inputs for min, opt, max shapes of the graph inputs submodule_inputs = partitioning.construct_submodule_inputs(new_submodule) @@ -328,17 +267,29 @@ def refit_module_weights( input_list=submodule_inputs, settings=settings, ) - - # In TorchTensorRTModule, the original module is intact. Create a new module and assign to the fx.Graph - if isinstance(compiled_submodule, TorchTensorRTModule): - refitteded_submodule = create_new_TorchTensorRTModule( - compiled_submodule, engine=engine, settings=settings + serialized_engine = bytes(engine.serialize()) + if inline_module: + new_engine_info = list(engine_info) + new_engine_info[3] = serialized_engine + refitted_inline_engine = torch.classes.tensorrt.Engine( + tuple(new_engine_info) ) + setattr(compiled_module, f"{name}_engine", refitted_inline_engine) else: - refitteded_submodule = create_new_PythonTorchTensorRTModule( - compiled_submodule, engine=engine, settings=settings - ) - setattr(compiled_module, name, refitteded_submodule) + # In TorchTensorRTModule, the original module is intact. Create a new module and assign to the fx.Graph + if isinstance(compiled_submodule, TorchTensorRTModule): + refitteded_submodule = create_new_TorchTensorRTModule( + compiled_submodule, + serialized_engine=serialized_engine, + settings=settings, + ) + else: + refitteded_submodule = create_new_PythonTorchTensorRTModule( + compiled_submodule, + serialized_engine=serialized_engine, + settings=settings, + ) + setattr(compiled_module, name, refitteded_submodule) return compiled_module @@ -346,20 +297,18 @@ def refit_module_weights( import base64 -def get_engine_from_TorchTensorRTModule(module: TorchTensorRTModule) -> trt.ICudaEngine: - engine_state = module.get_extra_state() - serialized_engine = base64.b64decode(engine_state[1][0][3]) +def get_engine_from_encoded_engine(encoded_engine: bytes) -> trt.ICudaEngine: + serialized_engine = base64.b64decode(encoded_engine) runtime = trt.Runtime(TRT_LOGGER) engine = runtime.deserialize_cuda_engine(serialized_engine) return engine def create_new_TorchTensorRTModule( - module: TorchTensorRTModule, engine: trt.ICudaEngine, settings: object + module: TorchTensorRTModule, serialized_engine: trt.ICudaEngine, settings: object ) -> TorchTensorRTModule: - serialized_engine = engine.serialize() return TorchTensorRTModule( - serialized_engine=bytes(serialized_engine), + serialized_engine=serialized_engine, name=module.name, input_binding_names=module.input_binding_names, output_binding_names=module.output_binding_names, @@ -369,20 +318,12 @@ def create_new_TorchTensorRTModule( def create_new_PythonTorchTensorRTModule( - module: PythonTorchTensorRTModule, engine: trt.ICudaEngine, settings: object + module: PythonTorchTensorRTModule, serialized_engine: bytes, settings: object ) -> PythonTorchTensorRTModule: - serialized_engine = engine.serialize() return PythonTorchTensorRTModule( - engine=bytes(serialized_engine), + engine=serialized_engine, input_names=module.input_names, output_names=module.output_names, target_device=settings.device, profiling_enabled=module.profiling_enabled, ) - - -def copy_cuda_engine(engine: trt.ICudaEngine) -> trt.ICudaEngine: - runtime = trt.Runtime(TRT_LOGGER) - serialized_engine = engine.serialize() - engine = runtime.deserialize_cuda_engine(serialized_engine) - return engine diff --git a/py/torch_tensorrt/dynamo/refitting/refit_engine_multi_subgraph.py b/py/torch_tensorrt/dynamo/refitting/refit_engine_multi_subgraph.py new file mode 100644 
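The demo script whose diff begins just below verifies a refit by printing and eyeballing output sums. A slightly more explicit check of the same idea, not part of the patch; the tolerance is arbitrary and the TRT module is assumed to return a single tensor:

```python
import torch

def assert_refit_matches(
    trt_module: torch.nn.Module,
    reference: torch.nn.Module,
    inputs: list,
    atol: float = 1e-3,
) -> None:
    with torch.no_grad():
        trt_out = trt_module(*inputs)
        ref_out = reference(*inputs)
    # After refitting, the engine should track the model that donated weights.
    assert torch.allclose(trt_out, ref_out, atol=atol), "refit outputs diverge"
```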
index 0000000000..3b2febb961 --- /dev/null +++ b/py/torch_tensorrt/dynamo/refitting/refit_engine_multi_subgraph.py @@ -0,0 +1,112 @@ +import numpy as np +import torch +import torch_tensorrt as trt +import torchvision.models as models + +# from torch import nn +from torch_tensorrt.dynamo._refit import refit_module_weights + +np.random.seed(0) +torch.manual_seed(0) + + +inputs = [torch.rand((1, 3, 224, 224)).to("cuda")] + +# Small Toy Model --------------------------------------- + + +# class net(nn.Module): +# def __init__(self): +# super().__init__() +# self.conv1 = nn.Conv2d(3, 12, 3, padding=1) +# self.bn = nn.BatchNorm2d(12) +# self.relu = nn.ReLU() + +# def forward(self, x): +# x = self.conv1(x) +# x = self.bn(x) +# x = self.relu(x) +# return x + + +# model = net().eval().to("cuda") +# np.random.seed(1) +# model2 = net().eval().to("cuda") + +# Resnet 18 -------------------------------------------- + +model = models.resnet18(pretrained=False).eval().to("cuda") +model2 = models.resnet18(pretrained=True).eval().to("cuda") + +# Bert ----------------------------------------------- + +# from transformers import BertModel +# inputs = [ +# torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"), +# ] +# model = BertModel.from_pretrained("bert-base-uncased").eval().to("cuda") + + +enabled_precisions = {torch.float} +debug = True +workspace_size = 20 << 30 +min_block_size = 0 + + +exp_program = torch.export.export(model, tuple(inputs)) +exp_program2 = torch.export.export(model2, tuple(inputs)) + +use_python_runtime = False + +trt_gm = trt.dynamo.compile( + exp_program, + tuple(inputs), + use_python_runtime=use_python_runtime, + enabled_precisions=enabled_precisions, + debug=debug, + min_block_size=min_block_size, + refit=True, + engine_save_path="/home/cehongw/Desktop/torch-trt/TensorRT/py/torch_tensorrt/dynamo/refitting/", +) # Output is a torch.fx.GraphModule + + +output = trt_gm.forward(*inputs) +print(output[0].sum().cpu().item()) +extra_files_loaded = {"settings": None} +path = "/home/cehongw/Desktop/torch-trt/TensorRT/py/torch_tensorrt/dynamo/refitting/trt_compiled_module.ep" + +new_trt_gm = refit_module_weights( + path, + new_weight_module=exp_program2, + inputs=inputs, +) + + +# compiled_exp_program = torch.export.load("/home/cehongw/Desktop/torch-trt/TensorRT/py/torch_tensorrt/dynamo/refitting/trt_compiled_module.ep" +# , extra_files=extra_files_loaded) + + +# decoded = base64.b64decode(extra_files_loaded["settings"].encode('utf-8')) +# restored_settings = pickle.loads(decoded) + +# new_trt_gm = refit_module_weights( +# compiled_module=compiled_exp_program, +# new_weight_module=exp_program2, +# inputs=inputs, +# settings=restored_settings, +# ) + +output_refit = new_trt_gm.forward(*inputs) +print(output_refit[0].sum().cpu().item()) + +pytorch_output = model2(*inputs)[0] +print(pytorch_output.sum().cpu().item()) + +print((output_refit - pytorch_output).mean()) +print() + + +# Iterate over all layers and print weights +# for layer_name in refitter.get_all_weights(): +# # Print kernel weights +# print_layer_weights(refitter, layer_name) From 56eb54934fe103818f85cf5f5c1d983b6b2a64c6 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Wed, 5 Jun 2024 16:41:50 -0700 Subject: [PATCH 08/70] Reorganized the code --- py/torch_tensorrt/dynamo/_refit.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 9f79c54bcb..f3c1be0ae9 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ 
b/py/torch_tensorrt/dynamo/_refit.py @@ -26,7 +26,6 @@ ) from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule from torch_tensorrt.dynamo.utils import ( - copy_cuda_engine, get_torch_inputs, prepare_inputs, set_log_level, @@ -172,7 +171,7 @@ def _refit_module_weights( # Prepare torch_trt inputs inputs = prepare_inputs(inputs) device = to_torch_tensorrt_device(settings.device) - + runtime = trt.Runtime(TRT_LOGGER) if not isinstance(new_weight_module, ExportedProgram): raise AssertionError( f"Input graph should be an ExportedProgram but got type {type(new_weight_module)}" @@ -229,18 +228,18 @@ def _refit_module_weights( try: compiled_submodule = getattr(compiled_module, name) if isinstance(compiled_submodule, PythonTorchTensorRTModule): - engine = copy_cuda_engine(compiled_submodule.engine) + engine = copy_cuda_engine(compiled_submodule.engine, runtime) elif isinstance(compiled_submodule, TorchTensorRTModule): engine_state = compiled_submodule.get_extra_state() encoded_engine = engine_state[1][0][3] - engine = get_engine_from_encoded_engine(encoded_engine) + engine = get_engine_from_encoded_engine(encoded_engine, runtime) else: raise AssertionError("The type of graph module is not supported.") except AttributeError: inline_module = True inline_engine = getattr(compiled_module, f"{name}_engine") engine_info = inline_engine.__getstate__()[0] - engine = get_engine_from_encoded_engine(engine_info[3]) + engine = get_engine_from_encoded_engine(engine_info[3], runtime) # Get the submodule inputs for min, opt, max shapes of the graph inputs submodule_inputs = partitioning.construct_submodule_inputs(new_submodule) @@ -297,9 +296,10 @@ def _refit_module_weights( import base64 -def get_engine_from_encoded_engine(encoded_engine: bytes) -> trt.ICudaEngine: +def get_engine_from_encoded_engine( + encoded_engine: bytes, runtime: trt.Runtime +) -> trt.ICudaEngine: serialized_engine = base64.b64decode(encoded_engine) - runtime = trt.Runtime(TRT_LOGGER) engine = runtime.deserialize_cuda_engine(serialized_engine) return engine @@ -327,3 +327,9 @@ def create_new_PythonTorchTensorRTModule( target_device=settings.device, profiling_enabled=module.profiling_enabled, ) + + +def copy_cuda_engine(engine: trt.ICudaEngine, runtime: trt.Runtime) -> trt.ICudaEngine: + serialized_engine = engine.serialize() + engine = runtime.deserialize_cuda_engine(serialized_engine) + return engine From 94483a87ac54837f477e4205ad757ca0f92f4625 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Thu, 6 Jun 2024 12:53:56 -0700 Subject: [PATCH 09/70] Added weight type check and number check --- py/torch_tensorrt/dynamo/_refit.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index f3c1be0ae9..4159916e79 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -83,9 +83,11 @@ def construct_refit_mapping( if layer_type in MODULE_MAP: layer.__class__ = MODULE_MAP[layer_type][0] for weight_type, weight_name in MODULE_MAP[layer_type][1]: - weight_map[f"{layer.name} {weight_name}"] = layer.__getattribute__( - weight_type - ).copy() + weight = layer.__getattribute__(weight_type).copy() + weight_map[f"{layer.name} {weight_name}"] = ( + weight, + layer.get_output_type(0), + ) else: warnings.warn(f"{layer_type} is not supported yet") @@ -104,6 +106,7 @@ def _refit_single_trt_engine_with_gm( """ # Get the refitting mapping mapping = construct_refit_mapping(new_gm, 
input_list, settings) + refitted = set() trt_wt_location = trt.TensorLocation.HOST refitter = trt.Refitter(old_engine, TRT_LOGGER) @@ -111,12 +114,17 @@ for layer_name in weight_list: if layer_name not in mapping: - print(f"{layer_name} is not found in weight mapping") - + raise AssertionError(f"{layer_name} is not found in weight mapping") # Use Numpy to create weights - weight = mapping[layer_name] - trt_wt_tensor = trt.Weights(trt.DataType.FLOAT, weight.ctypes.data, weight.size) + weight, datatype = mapping[layer_name] + trt_wt_tensor = trt.Weights( + datatype, weight.ctypes.data, weight.size + ) # TODO: Support different types of dtype refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) + refitted.add(layer_name) + + if len(refitted) != len(weight_list): + raise AssertionError("Not all weights have been refitted") if not refitter.refit_cuda_engine(): print("Error: failed to refit new weights.") @@ -158,6 +166,7 @@ def _refit_module_weights( """ Refit a compiled graph module with ExportedProgram """ + # Check the setting to be uniform if settings.debug: set_log_level(logger.parent, logging.DEBUG) @@ -216,8 +225,13 @@ def _refit_module_weights( min_block_size=settings.min_block_size, torch_executed_ops=settings.torch_executed_ops, ) + + # TODO: Check whether two modules have the same subcomponents + # 1. Check the number of partitions and name + # 2. (future) Check the hash of source fx.Graph and new fx.Graph + # PytorchTensorRTModule does not support deepcopy - # Create a shallow copy. Replace the TRTModule after + # Create a shallow copy. Replace the TRTModule after. TODO: Rethink the copy compiled_module = copy.copy(compiled_module) # Iterate over all components that can be accelerated # Generate the corresponding TRT Module for those From 4ba84b7fe52261e6d848718b9aaeaabc0d47b32c Mon Sep 17 00:00:00 2001 From: cehongwang Date: Thu, 6 Jun 2024 15:29:13 -0700 Subject: [PATCH 10/70] Deleted the save in compilation --- py/torch_tensorrt/_compile.py | 7 +- py/torch_tensorrt/dynamo/_compiler.py | 150 +++++++++++++++++++++----- py/torch_tensorrt/dynamo/_defaults.py | 3 - 3 files changed, 128 insertions(+), 32 deletions(-) diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 80da0a6d08..ce966a2609 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -3,7 +3,7 @@ import collections.abc import logging from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Sequence, Set +from typing import Any, Callable, List, Optional, Sequence, Set import torch import torch.fx @@ -398,7 +398,6 @@ def load(file_path: str = "") -> Any: def save( module: Any, file_path: str = "", - extra_files: Optional[Dict[str, Any]] = None, *, output_format: str = "exported_program", inputs: Optional[Sequence[torch.Tensor]] = None, @@ -460,7 +459,7 @@ def save( from torch_tensorrt.dynamo._exporter import export exp_program = export(module, inputs) - torch.export.save(exp_program, file_path, extra_files=extra_files) + torch.export.save(exp_program, file_path) else: from torch._higher_order_ops.torchbind import enable_torchbind_tracing with enable_torchbind_tracing(): exp_program = torch.export.export( module, tuple(inputs), strict=False ) - torch.export.save(exp_program, file_path, extra_files=extra_files) + torch.export.save(exp_program, file_path) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index db76aa43a8..be9c57a50b 100644 ---
a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -1,15 +1,11 @@ from __future__ import annotations -import base64 import collections.abc import logging -import os -import pickle import warnings from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union import torch -import torch_tensorrt from torch.export import ExportedProgram from torch.fx.node import Target from torch_tensorrt._Device import Device @@ -26,6 +22,7 @@ CompilationSettings, UnsupportedOperatorException, convert_module, + get_refit_mapping, interpret_module_to_result, repair_double_inputs, ) @@ -77,7 +74,6 @@ def compile( enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS, dryrun: bool = _defaults.DRYRUN, hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, - module_save_path: str = _defaults.MODULE_SAVE_PATH, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -167,7 +163,7 @@ def compile( if not isinstance(inputs, collections.abc.Sequence): inputs = [inputs] - raw_inputs = inputs + # Prepare torch_trt inputs inputs = prepare_inputs(inputs) device = to_torch_tensorrt_device(device) @@ -218,30 +214,11 @@ def compile( "dla_global_dram_size": dla_global_dram_size, "dryrun": dryrun, "hardware_compatible": hardware_compatible, - "module_save_path": module_save_path, } settings = CompilationSettings(**compilation_options) logger.info("Compilation Settings: %s\n", settings) trt_gm = compile_module(gm, inputs, settings) - - if settings.refit and not settings.use_python_runtime: - logger.info( - f"Compiled graph module and setting files will be save to {settings.module_save_path}" - ) - save_path = ( - os.path.join(settings.module_save_path, "trt_compiled_module.ep") - if settings.module_save_path[-3:] != ".ep" - else settings.module_save_path - ) - - dumped = pickle.dumps(settings) - dumped_string = base64.b64encode(dumped).decode("utf-8") - extra_files = {"settings": dumped_string} - torch_tensorrt.save( - trt_gm, save_path, inputs=raw_inputs, extra_files=extra_files, retrace=False - ) - return trt_gm @@ -633,3 +610,126 @@ def convert_module_to_trt_engine( engine_bytearray = engine_bytes.getvalue() return engine_bytearray + + +def refit_trt_engine_from_module( + exported_program: ExportedProgram, + inputs: Tuple[Any, ...], + engine: object, + *, + enabled_precisions: ( + Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype] + ) = _defaults.ENABLED_PRECISIONS, + debug: bool = _defaults.DEBUG, + workspace_size: int = _defaults.WORKSPACE_SIZE, + min_block_size: int = _defaults.MIN_BLOCK_SIZE, + torch_executed_ops: Optional[Set[str]] = None, + pass_through_build_failures: bool = _defaults.PASS_THROUGH_BUILD_FAILURES, + max_aux_streams: Optional[int] = _defaults.MAX_AUX_STREAMS, + version_compatible: bool = _defaults.VERSION_COMPATIBLE, + optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL, + use_python_runtime: Optional[bool] = _defaults.USE_PYTHON_RUNTIME, + truncate_double: bool = _defaults.TRUNCATE_DOUBLE, + use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER, + enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS, + device: Device = Device._current_device(), + require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION, + disable_tf32: bool = _defaults.DISABLE_TF32, + sparse_weights: bool = _defaults.SPARSE_WEIGHTS, + refit: bool = _defaults.REFIT, + engine_capability: EngineCapability = 
_defaults.ENGINE_CAPABILITY, + num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, + dla_sram_size: int = _defaults.DLA_SRAM_SIZE, + dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE, + dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE, + **kwargs: Any, +) -> None: + + if "truncate_long_and_double" in kwargs.keys(): + if truncate_double is not _defaults.TRUNCATE_DOUBLE: + raise ValueError( + 'Provided configuration for "truncate_double" and deprecated API "truncate_long_and_double", please only use "truncate_double"' + ) + else: + truncate_double = kwargs["truncate_long_and_double"] + warnings.warn( + 'Compiler option "truncate_long_and_double" is deprecated in favor of "truncate_double" as int64 is now natively supported, this option will be removed in the next version', + DeprecationWarning, + stacklevel=2, + ) + + input_list = list(inputs) if inputs is not None else [] + torch_executed_ops = torch_executed_ops if torch_executed_ops is not None else set() + # Prepare torch_trt inputs + input_list = prepare_inputs(input_list) + device = to_torch_tensorrt_device(device) + + enabled_precisions = {dtype._from(e) for e in enabled_precisions} + + compilation_options = { + "enabled_precisions": enabled_precisions, + "debug": debug, + "workspace_size": workspace_size, + "min_block_size": min_block_size, + "torch_executed_ops": torch_executed_ops, + "pass_through_build_failures": pass_through_build_failures, + "max_aux_streams": max_aux_streams, + "version_compatible": version_compatible, + "optimization_level": optimization_level, + "use_python_runtime": use_python_runtime, + "truncate_double": truncate_double, + "use_fast_partitioner": use_fast_partitioner, + "enable_experimental_decompositions": enable_experimental_decompositions, + "device": device, + "require_full_compilation": require_full_compilation, + "disable_tf32": disable_tf32, + "sparse_weights": sparse_weights, + "refit": refit, + "engine_capability": engine_capability, + "num_avg_timing_iters": num_avg_timing_iters, + "dla_sram_size": dla_sram_size, + "dla_local_dram_size": dla_local_dram_size, + "dla_global_dram_size": dla_global_dram_size, + } + + # Decompose the exported program + exported_program = exported_program.run_decompositions( + get_decompositions(enable_experimental_decompositions) + ) + gm = exported_program.module() + logger.debug("Input graph: " + str(gm.graph)) + + # Apply lowering on the graph module + torch_inputs = get_torch_inputs(input_list, device) + gm = apply_lowering_passes(gm, torch_inputs) + logger.debug("Lowered Input graph: " + str(gm.graph)) + + settings = CompilationSettings(**compilation_options) + logger.info("Compilation Settings: %s\n", settings) + + # Get the refitting mapping + import tensorrt as trt + + mapping = get_refit_mapping(gm, input_list, settings) + + TRT_LOGGER = trt.Logger(trt.Logger.ERROR) + trt_wt_location = trt.TensorLocation.HOST + + refitter = trt.Refitter(engine, TRT_LOGGER) + + weight_list = refitter.get_all_weights() + + for layer_name in weight_list: + if layer_name not in mapping: + print(f"{layer_name} is not found in weight mapping") + + # Use Numpy to create weights + weight = mapping[layer_name] + trt_wt_tensor = trt.Weights(trt.DataType.FLOAT, weight.ctypes.data, weight.size) + refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) + + if not refitter.refit_cuda_engine(): + print("Error: failed to refit new weights.") + exit(0) + + print("Refit Successful") diff --git a/py/torch_tensorrt/dynamo/_defaults.py 
b/py/torch_tensorrt/dynamo/_defaults.py index 4fdbef29f6..7931dc865c 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -1,5 +1,3 @@ -import os - import torch from torch_tensorrt._Device import Device from torch_tensorrt._enums import EngineCapability, dtype @@ -29,7 +27,6 @@ DRYRUN = False HARDWARE_COMPATIBLE = False SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.i8, dtype.bf16} -MODULE_SAVE_PATH = os.path.abspath(os.getcwd()) def default_device() -> Device: From ee6f12328f81f9cd28785702bf1dc7402f0814e9 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Thu, 6 Jun 2024 15:33:28 -0700 Subject: [PATCH 11/70] deleted more compilation save --- py/torch_tensorrt/dynamo/_compiler.py | 124 -------------------------- py/torch_tensorrt/dynamo/_settings.py | 2 - 2 files changed, 126 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index be9c57a50b..32b0ca65d7 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -22,7 +22,6 @@ CompilationSettings, UnsupportedOperatorException, convert_module, - get_refit_mapping, interpret_module_to_result, repair_double_inputs, ) @@ -610,126 +609,3 @@ def convert_module_to_trt_engine( engine_bytearray = engine_bytes.getvalue() return engine_bytearray - - -def refit_trt_engine_from_module( - exported_program: ExportedProgram, - inputs: Tuple[Any, ...], - engine: object, - *, - enabled_precisions: ( - Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype] - ) = _defaults.ENABLED_PRECISIONS, - debug: bool = _defaults.DEBUG, - workspace_size: int = _defaults.WORKSPACE_SIZE, - min_block_size: int = _defaults.MIN_BLOCK_SIZE, - torch_executed_ops: Optional[Set[str]] = None, - pass_through_build_failures: bool = _defaults.PASS_THROUGH_BUILD_FAILURES, - max_aux_streams: Optional[int] = _defaults.MAX_AUX_STREAMS, - version_compatible: bool = _defaults.VERSION_COMPATIBLE, - optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL, - use_python_runtime: Optional[bool] = _defaults.USE_PYTHON_RUNTIME, - truncate_double: bool = _defaults.TRUNCATE_DOUBLE, - use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER, - enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS, - device: Device = Device._current_device(), - require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION, - disable_tf32: bool = _defaults.DISABLE_TF32, - sparse_weights: bool = _defaults.SPARSE_WEIGHTS, - refit: bool = _defaults.REFIT, - engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, - num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, - dla_sram_size: int = _defaults.DLA_SRAM_SIZE, - dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE, - dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE, - **kwargs: Any, -) -> None: - - if "truncate_long_and_double" in kwargs.keys(): - if truncate_double is not _defaults.TRUNCATE_DOUBLE: - raise ValueError( - 'Provided configuration for "truncate_double" and deprecated API "truncate_long_and_double", please only use "truncate_double"' - ) - else: - truncate_double = kwargs["truncate_long_and_double"] - warnings.warn( - 'Compiler option "truncate_long_and_double" is deprecated in favor of "truncate_double" as int64 is now natively supported, this option will be removed in the next version', - DeprecationWarning, - stacklevel=2, - ) - - input_list = list(inputs) if inputs is not None else [] - torch_executed_ops = torch_executed_ops 
if torch_executed_ops is not None else set() - # Prepare torch_trt inputs - input_list = prepare_inputs(input_list) - device = to_torch_tensorrt_device(device) - - enabled_precisions = {dtype._from(e) for e in enabled_precisions} - - compilation_options = { - "enabled_precisions": enabled_precisions, - "debug": debug, - "workspace_size": workspace_size, - "min_block_size": min_block_size, - "torch_executed_ops": torch_executed_ops, - "pass_through_build_failures": pass_through_build_failures, - "max_aux_streams": max_aux_streams, - "version_compatible": version_compatible, - "optimization_level": optimization_level, - "use_python_runtime": use_python_runtime, - "truncate_double": truncate_double, - "use_fast_partitioner": use_fast_partitioner, - "enable_experimental_decompositions": enable_experimental_decompositions, - "device": device, - "require_full_compilation": require_full_compilation, - "disable_tf32": disable_tf32, - "sparse_weights": sparse_weights, - "refit": refit, - "engine_capability": engine_capability, - "num_avg_timing_iters": num_avg_timing_iters, - "dla_sram_size": dla_sram_size, - "dla_local_dram_size": dla_local_dram_size, - "dla_global_dram_size": dla_global_dram_size, - } - - # Decompose the exported program - exported_program = exported_program.run_decompositions( - get_decompositions(enable_experimental_decompositions) - ) - gm = exported_program.module() - logger.debug("Input graph: " + str(gm.graph)) - - # Apply lowering on the graph module - torch_inputs = get_torch_inputs(input_list, device) - gm = apply_lowering_passes(gm, torch_inputs) - logger.debug("Lowered Input graph: " + str(gm.graph)) - - settings = CompilationSettings(**compilation_options) - logger.info("Compilation Settings: %s\n", settings) - - # Get the refitting mapping - import tensorrt as trt - - mapping = get_refit_mapping(gm, input_list, settings) - - TRT_LOGGER = trt.Logger(trt.Logger.ERROR) - trt_wt_location = trt.TensorLocation.HOST - - refitter = trt.Refitter(engine, TRT_LOGGER) - - weight_list = refitter.get_all_weights() - - for layer_name in weight_list: - if layer_name not in mapping: - print(f"{layer_name} is not found in weight mapping") - - # Use Numpy to create weights - weight = mapping[layer_name] - trt_wt_tensor = trt.Weights(trt.DataType.FLOAT, weight.ctypes.data, weight.size) - refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) - - if not refitter.refit_cuda_engine(): - print("Error: failed to refit new weights.") - exit(0) - - print("Refit Successful") diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 7efe899b7b..9592bc1fd5 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -17,7 +17,6 @@ HARDWARE_COMPATIBLE, MAX_AUX_STREAMS, MIN_BLOCK_SIZE, - MODULE_SAVE_PATH, NUM_AVG_TIMING_ITERS, OPTIMIZATION_LEVEL, PASS_THROUGH_BUILD_FAILURES, @@ -99,4 +98,3 @@ class CompilationSettings: dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE dryrun: Union[bool, str] = DRYRUN hardware_compatible: bool = HARDWARE_COMPATIBLE - module_save_path: str = MODULE_SAVE_PATH From bc23ddb82aa632ef80d4d385597523fe29648c6a Mon Sep 17 00:00:00 2001 From: cehongwang Date: Thu, 6 Jun 2024 18:16:40 -0700 Subject: [PATCH 12/70] Supported different dtypes. Support all possible layers. Support deep copy. 
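Deepcopy support works by routing copy.deepcopy through the pickle protocol, since the live ICudaEngine handle held by PythonTorchTensorRTModule cannot be copied directly. A minimal sketch of the pattern (EngineWrapper is an illustrative stand-in, not the real module, which serializes the engine in its own __getstate__):

import copy
from typing import Any, Dict


class EngineWrapper:
    def __init__(self, serialized_engine: bytes):
        self.serialized_engine = serialized_engine
        self._initialize()

    def _initialize(self) -> None:
        # Stand-in for runtime.deserialize_cuda_engine(self.serialized_engine)
        self.engine = object()

    def __getstate__(self) -> Dict[str, Any]:
        state = self.__dict__.copy()
        state.pop("engine", None)  # drop the non-picklable handle
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        self.__dict__.update(state)
        self._initialize()  # rebuild the handle from the serialized bytes

    def __deepcopy__(self, memo: Any) -> "EngineWrapper":
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        # Round-trip through the pickle state, as the patch below does
        result.__setstate__(self.__getstate__())
        return result


copied = copy.deepcopy(EngineWrapper(b"engine-bytes"))
assert copied.serialized_engine == b"engine-bytes"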
--- py/torch_tensorrt/dynamo/_refit.py | 126 +++++++----------- .../dynamo/conversion/_conversion.py | 2 + .../runtime/_PythonTorchTensorRTModule.py | 9 ++ .../dynamo/runtime/_TorchTensorRTModule.py | 3 +- 4 files changed, 63 insertions(+), 77 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 4159916e79..cbffc511e9 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -56,6 +56,10 @@ def construct_refit_mapping( trt.IConvolutionLayer, [("kernel", "KERNEL"), ("bias", "BIAS")], ), + "DECONVOLUTION": ( + trt.IDeconvolutionLayer, + [("kernel", "KERNEL"), ("bias", "BIAS")], + ), "CONSTANT": (trt.IConstantLayer, [("weights", "CONSTANT")]), } @@ -117,9 +121,7 @@ def _refit_single_trt_engine_with_gm( raise AssertionError(f"{layer_name} is not found in weight mapping") # Use Numpy to create weights weight, datatype = mapping[layer_name] - trt_wt_tensor = trt.Weights( - datatype, weight.ctypes.data, weight.size - ) # TODO: Support different types of dtype + trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size) refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) refitted.add(layer_name) @@ -161,21 +163,28 @@ def _refit_module_weights( compiled_module: torch.fx.GraphModule | ExportedProgram, new_weight_module: ExportedProgram, inputs: Tuple[Any, ...], - settings: Any, + settings: Any = None, ) -> torch.fx.GraphModule: """ Refit a compiled graph module with ExportedProgram """ - # Check the setting to be uniform + if isinstance(compiled_module, ExportedProgram): + compiled_module = compiled_module.module() - if settings.debug: - set_log_level(logger.parent, logging.DEBUG) + compiled_module = copy.deepcopy(compiled_module) - if not isinstance(inputs, collections.abc.Sequence): - inputs = [inputs] + # Get the settings and check the setting to be uniform + if not settings: + for name, submodule in compiled_module.named_children(): + if settings is not None: + assert settings == submodule.settings + settings = submodule.settings - if isinstance(compiled_module, ExportedProgram): - compiled_module = compiled_module.module() + if settings.debug: + set_log_level(logger.parent, logging.DEBUG) + + if not isinstance(inputs, collections.abc.Sequence): + inputs = [inputs] # Prepare torch_trt inputs inputs = prepare_inputs(inputs) @@ -226,43 +235,45 @@ def _refit_module_weights( torch_executed_ops=settings.torch_executed_ops, ) - # TODO: Check whether two modules have the same subcomponents + # Check whether two modules have the same subcomponents # 1. Check the number of partitions and name - # 2. (future) Check the hash of source fx.Graph and new fx.Graph + assert [sm[0] for sm in partitioned_module.named_children()] == [ + sm[0] for sm in compiled_module.named_children() + ] + # 2. TODO: Check the hash of source fx.Graph and new fx.Graph - # PytorchTensorRTModule does not support deepcopy - # Create a shallow copy. Replace the TRTModule after. 
TODO: Rethink the copy - compiled_module = copy.copy(compiled_module) # Iterate over all components that can be accelerated # Generate the corresponding TRT Module for those for name, _ in partitioned_module.named_children(): new_submodule = getattr(partitioned_module, name) - # TODO: Copy the submodule and return a new one + # Extract engine from the submodule inline_module = False try: compiled_submodule = getattr(compiled_module, name) if isinstance(compiled_submodule, PythonTorchTensorRTModule): - engine = copy_cuda_engine(compiled_submodule.engine, runtime) + engine = compiled_submodule.engine elif isinstance(compiled_submodule, TorchTensorRTModule): - engine_state = compiled_submodule.get_extra_state() - encoded_engine = engine_state[1][0][3] - engine = get_engine_from_encoded_engine(encoded_engine, runtime) - else: - raise AssertionError("The type of graph module is not supported.") + engine_info = compiled_submodule.engine.__getstate__()[0] + engine = get_engine_from_encoded_engine(engine_info[3], runtime) + except AttributeError: - inline_module = True - inline_engine = getattr(compiled_module, f"{name}_engine") - engine_info = inline_engine.__getstate__()[0] - engine = get_engine_from_encoded_engine(engine_info[3], runtime) + try: + inline_module = True + inline_engine = getattr(compiled_module, f"{name}_engine") + engine_info = inline_engine.__getstate__()[0] + engine = get_engine_from_encoded_engine(engine_info[3], runtime) + + except AttributeError: + raise AssertionError( + "The type of graph module is not supported for refitting or two compiled modules do not match." + ) # Get the submodule inputs for min, opt, max shapes of the graph inputs submodule_inputs = partitioning.construct_submodule_inputs(new_submodule) - logger.debug( "Refitting Submodule name: %s\n", str(name), ) - assert submodule_inputs is not None # Handle long/double inputs if requested by the user if settings.truncate_double: @@ -280,29 +291,23 @@ def _refit_module_weights( input_list=submodule_inputs, settings=settings, ) - serialized_engine = bytes(engine.serialize()) + if inline_module: + serialized_engine = bytes(engine.serialize()) new_engine_info = list(engine_info) new_engine_info[3] = serialized_engine refitted_inline_engine = torch.classes.tensorrt.Engine( tuple(new_engine_info) ) setattr(compiled_module, f"{name}_engine", refitted_inline_engine) - else: - # In TorchTensorRTModule, the original module is intact.
Create a new module and assign to the fx.Graph - if isinstance(compiled_submodule, TorchTensorRTModule): - refitteded_submodule = create_new_TorchTensorRTModule( - compiled_submodule, - serialized_engine=serialized_engine, - settings=settings, - ) - else: - refitteded_submodule = create_new_PythonTorchTensorRTModule( - compiled_submodule, - serialized_engine=serialized_engine, - settings=settings, - ) - setattr(compiled_module, name, refitteded_submodule) + + elif isinstance(compiled_submodule, TorchTensorRTModule): + serialized_engine = bytes(engine.serialize()) + new_engine_info = list(engine_info) + new_engine_info[3] = serialized_engine + refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info)) + compiled_submodule.engine = refitted_engine + return compiled_module @@ -316,34 +321,3 @@ def get_engine_from_encoded_engine( serialized_engine = base64.b64decode(encoded_engine) engine = runtime.deserialize_cuda_engine(serialized_engine) return engine - - -def create_new_TorchTensorRTModule( - module: TorchTensorRTModule, serialized_engine: trt.ICudaEngine, settings: object -) -> TorchTensorRTModule: - return TorchTensorRTModule( - serialized_engine=serialized_engine, - name=module.name, - input_binding_names=module.input_binding_names, - output_binding_names=module.output_binding_names, - target_device=settings.device, - hardware_compatible=module.hardware_compatible, - ) - - -def create_new_PythonTorchTensorRTModule( - module: PythonTorchTensorRTModule, serialized_engine: bytes, settings: object -) -> PythonTorchTensorRTModule: - return PythonTorchTensorRTModule( - engine=serialized_engine, - input_names=module.input_names, - output_names=module.output_names, - target_device=settings.device, - profiling_enabled=module.profiling_enabled, - ) - - -def copy_cuda_engine(engine: trt.ICudaEngine, runtime: trt.Runtime) -> trt.ICudaEngine: - serialized_engine = engine.serialize() - engine = runtime.deserialize_cuda_engine(serialized_engine) - return engine diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 5646a63108..dfcfbbf48a 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -115,6 +115,7 @@ def convert_module( output_names=list(interpreter_result.output_names), target_device=settings.device, profiling_enabled=settings.debug, + settings=settings, ) else: @@ -131,4 +132,5 @@ def convert_module( output_binding_names=list(interpreter_result.output_names), target_device=settings.device, hardware_compatible=settings.hardware_compatible, + settings=settings, ) diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 1fcb765b47..e315e5fe6e 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -34,6 +34,7 @@ def __init__( output_names: Optional[List[str]] = None, target_device: Device = Device._current_device(), profiling_enabled: Optional[bool] = None, + settings: Any = None, ): super(PythonTorchTensorRTModule, self).__init__() self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict) @@ -52,6 +53,7 @@ def __init__( self.profiling_enabled = ( profiling_enabled if profiling_enabled is not None else False ) + self.settings = settings self._initialize() def _initialize(self) -> None: @@ -126,6 +128,13 @@ def __setstate__(self, state: Dict[str, Any]) -> 
None: if self.engine: self.context = self.engine.create_execution_context() + def __deepcopy__(self, memo: Any) -> PythonTorchTensorRTModule: + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + result.__setstate__(self.__getstate__()) + return result + def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]: with ( torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward") diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 709c10b36e..1302f47e3f 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -44,6 +44,7 @@ def __init__( output_binding_names: Optional[List[str]] = None, target_device: Device = Device._current_device(), hardware_compatible: bool = False, + settings: Any = None, ): """__init__ method for torch_tensorrt.dynamo.runtime._TorchTensorRTModule.TorchTensorRTModule @@ -91,7 +92,7 @@ def __init__( ) self.name = name self.hardware_compatible = hardware_compatible - + self.settings = settings if serialized_engine is not None: self.engine = torch.classes.tensorrt.Engine( [ From 501e5d9e96aa421c3a2da0245984c08dfcdd791d Mon Sep 17 00:00:00 2001 From: cehongwang Date: Thu, 6 Jun 2024 18:17:24 -0700 Subject: [PATCH 13/70] Delete the outdated file --- .../refitting/refit_engine_multi_subgraph.py | 112 ------------------ 1 file changed, 112 deletions(-) delete mode 100644 py/torch_tensorrt/dynamo/refitting/refit_engine_multi_subgraph.py diff --git a/py/torch_tensorrt/dynamo/refitting/refit_engine_multi_subgraph.py b/py/torch_tensorrt/dynamo/refitting/refit_engine_multi_subgraph.py deleted file mode 100644 index 3b2febb961..0000000000 --- a/py/torch_tensorrt/dynamo/refitting/refit_engine_multi_subgraph.py +++ /dev/null @@ -1,112 +0,0 @@ -import numpy as np -import torch -import torch_tensorrt as trt -import torchvision.models as models - -# from torch import nn -from torch_tensorrt.dynamo._refit import refit_module_weights - -np.random.seed(0) -torch.manual_seed(0) - - -inputs = [torch.rand((1, 3, 224, 224)).to("cuda")] - -# Small Toy Model --------------------------------------- - - -# class net(nn.Module): -# def __init__(self): -# super().__init__() -# self.conv1 = nn.Conv2d(3, 12, 3, padding=1) -# self.bn = nn.BatchNorm2d(12) -# self.relu = nn.ReLU() - -# def forward(self, x): -# x = self.conv1(x) -# x = self.bn(x) -# x = self.relu(x) -# return x - - -# model = net().eval().to("cuda") -# np.random.seed(1) -# model2 = net().eval().to("cuda") - -# Resnet 18 -------------------------------------------- - -model = models.resnet18(pretrained=False).eval().to("cuda") -model2 = models.resnet18(pretrained=True).eval().to("cuda") - -# Bert ----------------------------------------------- - -# from transformers import BertModel -# inputs = [ -# torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"), -# ] -# model = BertModel.from_pretrained("bert-base-uncased").eval().to("cuda") - - -enabled_precisions = {torch.float} -debug = True -workspace_size = 20 << 30 -min_block_size = 0 - - -exp_program = torch.export.export(model, tuple(inputs)) -exp_program2 = torch.export.export(model2, tuple(inputs)) - -use_python_runtime = False - -trt_gm = trt.dynamo.compile( - exp_program, - tuple(inputs), - use_python_runtime=use_python_runtime, - enabled_precisions=enabled_precisions, - debug=debug, - min_block_size=min_block_size, - refit=True, - 
engine_save_path="/home/cehongw/Desktop/torch-trt/TensorRT/py/torch_tensorrt/dynamo/refitting/", -) # Output is a torch.fx.GraphModule - - -output = trt_gm.forward(*inputs) -print(output[0].sum().cpu().item()) -extra_files_loaded = {"settings": None} -path = "/home/cehongw/Desktop/torch-trt/TensorRT/py/torch_tensorrt/dynamo/refitting/trt_compiled_module.ep" - -new_trt_gm = refit_module_weights( - path, - new_weight_module=exp_program2, - inputs=inputs, -) - - -# compiled_exp_program = torch.export.load("/home/cehongw/Desktop/torch-trt/TensorRT/py/torch_tensorrt/dynamo/refitting/trt_compiled_module.ep" -# , extra_files=extra_files_loaded) - - -# decoded = base64.b64decode(extra_files_loaded["settings"].encode('utf-8')) -# restored_settings = pickle.loads(decoded) - -# new_trt_gm = refit_module_weights( -# compiled_module=compiled_exp_program, -# new_weight_module=exp_program2, -# inputs=inputs, -# settings=restored_settings, -# ) - -output_refit = new_trt_gm.forward(*inputs) -print(output_refit[0].sum().cpu().item()) - -pytorch_output = model2(*inputs)[0] -print(pytorch_output.sum().cpu().item()) - -print((output_refit - pytorch_output).mean()) -print() - - -# Iterate over all layers and print weights -# for layer_name in refitter.get_all_weights(): -# # Print kernel weights -# print_layer_weights(refitter, layer_name) From bd5fb55534fbef870a3fd1a24f4f9352b0cf8437 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Thu, 6 Jun 2024 18:20:11 -0700 Subject: [PATCH 14/70] Deleted setting loading --- py/torch_tensorrt/dynamo/_refit.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index cbffc511e9..2797735a7c 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -3,7 +3,6 @@ import collections.abc import copy import logging -import pickle import warnings from typing import Any, Sequence, Tuple @@ -136,30 +135,6 @@ def _refit_single_trt_engine_with_gm( def refit_module_weights( - compiled_module_file_path: str, new_weight_module: ExportedProgram, inputs: Any -) -> torch.fx.GraphModule: - """ - Return a copy of compiled_module with refitted weight - """ - settings_wrapper = {"settings": None} - - compiled_exp_program = torch.export.load( - compiled_module_file_path, extra_files=settings_wrapper - ) - - decoded_settings = base64.b64decode(settings_wrapper["settings"].encode("utf-8")) - restored_settings = pickle.loads(decoded_settings) - - new_trt_gm = _refit_module_weights( - compiled_module=compiled_exp_program, - new_weight_module=new_weight_module, - inputs=inputs, - settings=restored_settings, - ) - return new_trt_gm - - -def _refit_module_weights( compiled_module: torch.fx.GraphModule | ExportedProgram, new_weight_module: ExportedProgram, inputs: Tuple[Any, ...], From 578927c783e32409d714defcb506c2253daaebe8 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Fri, 7 Jun 2024 13:05:17 -0700 Subject: [PATCH 15/70] Fixed bugs when handling multiple engines. 
Tested with custom module and Bert --- py/torch_tensorrt/dynamo/_refit.py | 55 ++++++++++++------------------ 1 file changed, 22 insertions(+), 33 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 2797735a7c..536f7cc014 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -3,7 +3,6 @@ import collections.abc import copy import logging -import warnings from typing import Any, Sequence, Tuple import numpy as np @@ -92,9 +91,6 @@ def construct_refit_mapping( layer.get_output_type(0), ) - else: - warnings.warn(f"{layer_type} is not supported yet") - return weight_map @@ -149,17 +145,21 @@ def refit_module_weights( compiled_module = copy.deepcopy(compiled_module) # Get the settings and check the setting to be uniform - if not settings: + if settings is None: for name, submodule in compiled_module.named_children(): + if not isinstance( + submodule, (PythonTorchTensorRTModule, TorchTensorRTModule) + ): + continue if settings is not None: assert settings == submodule.settings settings = submodule.settings - if settings.debug: - set_log_level(logger.parent, logging.DEBUG) + if settings.debug: + set_log_level(logger.parent, logging.DEBUG) - if not isinstance(inputs, collections.abc.Sequence): - inputs = [inputs] + if not isinstance(inputs, collections.abc.Sequence): + inputs = [inputs] # Prepare torch_trt inputs inputs = prepare_inputs(inputs) @@ -219,10 +219,9 @@ def refit_module_weights( # Iterate over all components that can be accelerated # Generate the corresponding TRT Module for those - for name, _ in partitioned_module.named_children(): - new_submodule = getattr(partitioned_module, name) + for name, new_submodule in partitioned_module.named_children(): + # Extract engine from the submodule - inline_module = False try: compiled_submodule = getattr(compiled_module, name) if isinstance(compiled_submodule, PythonTorchTensorRTModule): @@ -230,18 +229,17 @@ def refit_module_weights( elif isinstance(compiled_submodule, TorchTensorRTModule): engine_info = compiled_submodule.engine.__getstate__()[0] engine = get_engine_from_encoded_engine(engine_info[3], runtime) - - except AttributeError: - try: - inline_module = True - inline_engine = getattr(compiled_module, f"{name}_engine") - engine_info = inline_engine.__getstate__()[0] - engine = get_engine_from_encoded_engine(engine_info[3], runtime) - - except AttributeError: + elif isinstance(compiled_submodule, torch.fx.graph_module.GraphModule): + # This is graph break resulted by unsupported ops + continue + else: raise AssertionError( - "The type of graph module is not supported for refitting or two compiled modules do not match." + "The type of graph module is not supported for refitting." ) + except AttributeError: + raise AssertionError( + "The type of graph module is not supported for refitting or two compiled modules do not match." 
+ ) # Get the submodule inputs for min, opt, max shapes of the graph inputs submodule_inputs = partitioning.construct_submodule_inputs(new_submodule) @@ -267,23 +265,14 @@ def refit_module_weights( settings=settings, ) - if inline_module: - serialized_engine = bytes(engine.serialize()) - new_engine_info = list(engine_info) - new_engine_info[3] = serialized_engine - refitted_inline_engine = torch.classes.tensorrt.Engine( - tuple(new_engine_info) - ) - setattr(compiled_module, f"{name}_engine", refitted_inline_engine) - - elif isinstance(compiled_submodule, TorchTensorRTModule): + if isinstance(compiled_submodule, TorchTensorRTModule): serialized_engine = bytes(engine.serialize()) new_engine_info = list(engine_info) new_engine_info[3] = serialized_engine refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info)) compiled_submodule.engine = refitted_engine - return compiled_module + return compiled_module # Util functions ----------- From 82cd2529cb0fa891151a66e7a02ded6ebfdd3130 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Mon, 10 Jun 2024 15:44:54 -0700 Subject: [PATCH 16/70] Fixed dtype bugs --- py/torch_tensorrt/dynamo/_refit.py | 32 +++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 536f7cc014..6138d89d3c 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -40,13 +40,13 @@ def construct_refit_mapping( inputs: Sequence[Input], settings: CompilationSettings = CompilationSettings(), ) -> dict[str, np.ndarray]: - """Interpret an FX module to a TRTInterpreterResult + """Find out the weight mapping between weight in exported program and TensorRT engine Args: module: FX GraphModule to interpret inputs: Sequence of Tensors representing inputs to the module settings: Compilation settings Returns: - TRTInterpreterResult + Mapping from weight name in TensorRT to actual weight value in np.ndarray """ MODULE_MAP = { "SCALE": (trt.IScaleLayer, [("scale", "SCALE"), ("shift", "SHIFT")]), @@ -86,9 +86,14 @@ def construct_refit_mapping( layer.__class__ = MODULE_MAP[layer_type][0] for weight_type, weight_name in MODULE_MAP[layer_type][1]: weight = layer.__getattribute__(weight_type).copy() + weight_dtype = ( + layer.precision + if layer.precision_is_set + else convert_numpy_to_tensorrt_dtype(weight.dtype) + ) weight_map[f"{layer.name} {weight_name}"] = ( weight, - layer.get_output_type(0), + weight_dtype, ) return weight_map @@ -285,3 +290,24 @@ def get_engine_from_encoded_engine( serialized_engine = base64.b64decode(encoded_engine) engine = runtime.deserialize_cuda_engine(serialized_engine) return engine + + +def convert_numpy_to_tensorrt_dtype(np_dtype: np.dtypes) -> trt.DataType: + # Define a mapping from numpy dtype to TensorRT dtype + numpy_to_tensorrt_dtype = { + np.dtype("float32"): trt.DataType.FLOAT, + np.float32: trt.DataType.FLOAT, + np.dtype("float16"): trt.DataType.HALF, + np.float16: trt.DataType.HALF, + np.dtype("int32"): trt.DataType.INT32, + np.int32: trt.DataType.INT32, + np.dtype("int64"): trt.DataType.INT64, + np.int64: trt.DataType.INT64, + np.dtype("int8"): trt.DataType.INT8, + np.int8: trt.DataType.INT8, + } + + if np_dtype in numpy_to_tensorrt_dtype: + return numpy_to_tensorrt_dtype[np_dtype] + else: + raise TypeError(f"Unsupported NumPy data type: {np_dtype}") From e3cf823e34bd1eae2b13ad7f4b2b70b6b899006f Mon Sep 17 00:00:00 2001 From: cehongwang Date: Mon, 10 Jun 2024 15:45:36 -0700 Subject: [PATCH 17/70] 
Made a note to add INormalization Layer --- py/torch_tensorrt/dynamo/_refit.py | 1 + 1 file changed, 1 insertion(+) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 6138d89d3c..4fd7debce1 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -59,6 +59,7 @@ def construct_refit_mapping( [("kernel", "KERNEL"), ("bias", "BIAS")], ), "CONSTANT": (trt.IConstantLayer, [("weights", "CONSTANT")]), + # TODO: Add INormalizationLayer } output_dtypes = infer_module_output_dtypes( From c3b0862de7356f41c2be4f58fc4708995cc09308 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Tue, 11 Jun 2024 11:08:20 -0700 Subject: [PATCH 18/70] Update the unsupported torch module weights --- py/torch_tensorrt/dynamo/_refit.py | 1 + 1 file changed, 1 insertion(+) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 4fd7debce1..a75d3c3bc6 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -237,6 +237,7 @@ def refit_module_weights( engine = get_engine_from_encoded_engine(engine_info[3], runtime) elif isinstance(compiled_submodule, torch.fx.graph_module.GraphModule): # This is graph break resulted by unsupported ops + compiled_submodule.load_state_dict(new_submodule.state_dict()) continue else: raise AssertionError( From 400bcacf1857cedabda91bf9df6ce0617981f84f Mon Sep 17 00:00:00 2001 From: cehongwang Date: Wed, 12 Jun 2024 12:17:58 -0700 Subject: [PATCH 19/70] Cleaned up the code. Added refitting outcome check --- py/torch_tensorrt/dynamo/_refit.py | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index a75d3c3bc6..e0be0c65b8 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -59,7 +59,6 @@ def construct_refit_mapping( [("kernel", "KERNEL"), ("bias", "BIAS")], ), "CONSTANT": (trt.IConstantLayer, [("weights", "CONSTANT")]), - # TODO: Add INormalizationLayer } output_dtypes = infer_module_output_dtypes( @@ -127,14 +126,12 @@ def _refit_single_trt_engine_with_gm( refitted.add(layer_name) if len(refitted) != len(weight_list): - raise AssertionError("Not all weights have been refitted") + logger.warning("Not all weights have been refitted!!!") if not refitter.refit_cuda_engine(): - print("Error: failed to refit new weights.") + logger.error("Error: failed to refit new weights.") exit(0) - print("Refit Successful") - def refit_module_weights( compiled_module: torch.fx.GraphModule | ExportedProgram, @@ -145,6 +142,7 @@ def refit_module_weights( """ Refit a compiled graph module with ExportedProgram """ + raw_inputs = copy.deepcopy(inputs) if isinstance(compiled_module, ExportedProgram): compiled_module = compiled_module.module() @@ -227,6 +225,8 @@ def refit_module_weights( # Generate the corresponding TRT Module for those for name, new_submodule in partitioned_module.named_children(): + # Refit each submodule + # Extract engine from the submodule try: compiled_submodule = getattr(compiled_module, name) @@ -279,9 +279,28 @@ def refit_module_weights( refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info)) compiled_submodule.engine = refitted_engine + check_output( + new_submodule=new_submodule, + compiled_submodule=compiled_submodule, + inputs=raw_inputs, + ) + logger.info("Refit Successful!") return compiled_module +def check_output( + new_submodule: torch.fx.GraphModule, + compiled_submodule: 
torch.fx.GraphModule, + inputs: Tuple[Any, ...], +) -> None: + # inputs = [t.contiguous() for t in inputs] + old_outputs, new_outputs = compiled_submodule(*inputs), new_submodule(*inputs) + for old_output, new_output in zip(old_outputs, new_outputs): + assert torch.allclose( + old_output, new_output, 1e-2, 1e-2 + ), "Refit Result is not correct. Refit failed" + + # Util functions ----------- import base64 From 6f086644aabd02ce69025ca7f718afb332542311 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Wed, 12 Jun 2024 13:55:32 -0700 Subject: [PATCH 20/70] Use enums to change dtype from np to trt --- py/torch_tensorrt/dynamo/_refit.py | 24 ++---------------------- 1 file changed, 2 insertions(+), 22 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index e0be0c65b8..d3007f52d0 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -9,6 +9,7 @@ import tensorrt as trt import torch from torch.export import ExportedProgram +from torch_tensorrt._enums import dtype from torch_tensorrt._Input import Input from torch_tensorrt.dynamo import partitioning from torch_tensorrt.dynamo.conversion import CompilationSettings @@ -89,7 +90,7 @@ def construct_refit_mapping( weight_dtype = ( layer.precision if layer.precision_is_set - else convert_numpy_to_tensorrt_dtype(weight.dtype) + else dtype.try_from(weight.dtype).to(trt.DataType) ) weight_map[f"{layer.name} {weight_name}"] = ( weight, @@ -311,24 +312,3 @@ def get_engine_from_encoded_engine( serialized_engine = base64.b64decode(encoded_engine) engine = runtime.deserialize_cuda_engine(serialized_engine) return engine - - -def convert_numpy_to_tensorrt_dtype(np_dtype: np.dtypes) -> trt.DataType: - # Define a mapping from numpy dtype to TensorRT dtype - numpy_to_tensorrt_dtype = { - np.dtype("float32"): trt.DataType.FLOAT, - np.float32: trt.DataType.FLOAT, - np.dtype("float16"): trt.DataType.HALF, - np.float16: trt.DataType.HALF, - np.dtype("int32"): trt.DataType.INT32, - np.int32: trt.DataType.INT32, - np.dtype("int64"): trt.DataType.INT64, - np.int64: trt.DataType.INT64, - np.dtype("int8"): trt.DataType.INT8, - np.int8: trt.DataType.INT8, - } - - if np_dtype in numpy_to_tensorrt_dtype: - return numpy_to_tensorrt_dtype[np_dtype] - else: - raise TypeError(f"Unsupported NumPy data type: {np_dtype}") From 2250239d2ce53b964f4569b2c915db35dbf8668b Mon Sep 17 00:00:00 2001 From: cehongwang Date: Wed, 12 Jun 2024 16:32:59 -0700 Subject: [PATCH 21/70] Moved check output to util and added a gated flag --- py/torch_tensorrt/dynamo/_refit.py | 47 ++++++++++++------------------ py/torch_tensorrt/dynamo/utils.py | 13 +++++++++ 2 files changed, 31 insertions(+), 29 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index d3007f52d0..3126363e60 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -25,6 +25,7 @@ ) from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule from torch_tensorrt.dynamo.utils import ( + check_output, get_torch_inputs, prepare_inputs, set_log_level, @@ -138,7 +139,7 @@ def refit_module_weights( compiled_module: torch.fx.GraphModule | ExportedProgram, new_weight_module: ExportedProgram, inputs: Tuple[Any, ...], - settings: Any = None, + verify_output: bool = True, ) -> torch.fx.GraphModule: """ Refit a compiled graph module with ExportedProgram @@ -150,15 +151,13 @@ def refit_module_weights( compiled_module = copy.deepcopy(compiled_module) # Get the settings 
and check the setting to be uniform - if settings is None: - for name, submodule in compiled_module.named_children(): - if not isinstance( - submodule, (PythonTorchTensorRTModule, TorchTensorRTModule) - ): - continue - if settings is not None: - assert settings == submodule.settings - settings = submodule.settings + settings: Any = None + for name, submodule in compiled_module.named_children(): + if not isinstance(submodule, (PythonTorchTensorRTModule, TorchTensorRTModule)): + continue + if settings is not None: + assert settings == submodule.settings + settings = submodule.settings if settings.debug: set_log_level(logger.parent, logging.DEBUG) @@ -280,28 +279,18 @@ def refit_module_weights( refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info)) compiled_submodule.engine = refitted_engine - check_output( - new_submodule=new_submodule, - compiled_submodule=compiled_submodule, - inputs=raw_inputs, - ) - logger.info("Refit Successful!") + if verify_output: + check_output( + new_submodule=new_submodule, + compiled_submodule=compiled_submodule, + inputs=raw_inputs, + ) + logger.info("Refit Successful!") + else: + logger.info("Refit Completed! Output verification skipped.") return compiled_module -def check_output( - new_submodule: torch.fx.GraphModule, - compiled_submodule: torch.fx.GraphModule, - inputs: Tuple[Any, ...], -) -> None: - # inputs = [t.contiguous() for t in inputs] - old_outputs, new_outputs = compiled_submodule(*inputs), new_submodule(*inputs) - for old_output, new_output in zip(old_outputs, new_outputs): - assert torch.allclose( - old_output, new_output, 1e-2, 1e-2 - ), "Refit Result is not correct. Refit failed" - - # Util functions ----------- import base64 diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 6ea9503b84..adfb3902e4 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -312,3 +312,16 @@ def function_wrapper(*args: Any, **kwargs: Any) -> Any: return function_wrapper return nested_decorator + + +def check_output( + new_submodule: torch.fx.GraphModule, + compiled_submodule: torch.fx.GraphModule, + inputs: tuple[Any, ...], +) -> None: + # inputs = [t.contiguous() for t in inputs] + old_outputs, new_outputs = compiled_submodule(*inputs), new_submodule(*inputs) + for old_output, new_output in zip(old_outputs, new_outputs): + assert torch.allclose( + old_output, new_output, 1e-2, 1e-2 + ), "Refit Result is not correct. Refit failed" From c906d0e2321a2e11231c2e67277d997d06c869d3 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Thu, 13 Jun 2024 15:31:08 -0700 Subject: [PATCH 22/70] fixed a bug in check_output function. 
Changed to only check once after all refitting --- py/torch_tensorrt/dynamo/utils.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index adfb3902e4..89caaf55d0 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -315,13 +315,16 @@ def function_wrapper(*args: Any, **kwargs: Any) -> Any: def check_output( - new_submodule: torch.fx.GraphModule, - compiled_submodule: torch.fx.GraphModule, + new_module: torch.fx.GraphModule, + refitted_module: torch.fx.GraphModule, inputs: tuple[Any, ...], ) -> None: # inputs = [t.contiguous() for t in inputs] - old_outputs, new_outputs = compiled_submodule(*inputs), new_submodule(*inputs) + old_outputs, new_outputs = refitted_module(*inputs), new_module(*inputs) for old_output, new_output in zip(old_outputs, new_outputs): - assert torch.allclose( - old_output, new_output, 1e-2, 1e-2 - ), "Refit Result is not correct. Refit failed" + if isinstance(old_output, torch.Tensor) and isinstance( + new_output, torch.Tensor + ): + assert torch.allclose( + old_output, new_output, 1e-2, 1e-2 + ), "Refit Result is not correct. Refit failed" From 51cba6f4e12b95f249507c00d86f3b3fb8d31ee9 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Thu, 13 Jun 2024 15:37:27 -0700 Subject: [PATCH 23/70] revert the main function --- .../dynamo/refitting/refit_engine.py | 106 ------------------ 1 file changed, 106 deletions(-) delete mode 100644 py/torch_tensorrt/dynamo/refitting/refit_engine.py diff --git a/py/torch_tensorrt/dynamo/refitting/refit_engine.py b/py/torch_tensorrt/dynamo/refitting/refit_engine.py deleted file mode 100644 index b3b42b6e5b..0000000000 --- a/py/torch_tensorrt/dynamo/refitting/refit_engine.py +++ /dev/null @@ -1,106 +0,0 @@ -import numpy as np -import torch -import torchvision.models as models -from torch_tensorrt._Device import Device -from torch_tensorrt.dynamo._compiler import convert_module_to_trt_engine -from torch_tensorrt.dynamo._refit import refit_single_trt_engine_with_ep -from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import ( - PythonTorchTensorRTModule, -) -
-np.random.seed(0) -torch.manual_seed(0) - - -inputs = [torch.rand((1, 3, 224, 224)).to("cuda")] - - -# class net(nn.Module): -# def __init__(self): -# super().__init__() -# self.conv1 = nn.Conv2d(3, 12, 3, padding=1) -# self.bn = nn.BatchNorm2d(12) -# self.relu = nn.ReLU() - -# def forward(self, x): -# x = self.conv1(x) -# x = self.bn(x) -# x = self.relu(x) -# return x - - -# model = net().eval().to("cuda") -# np.random.seed(1) -# model2 = net().eval().to("cuda") - -model = models.resnet18(pretrained=False).eval().to("cuda") -model2 = models.resnet18(pretrained=True).eval().to("cuda") -enabled_precisions = {torch.float} -debug = True -workspace_size = 20 << 30 -min_block_size = 1 - - -exp_program = torch.export.export(model, tuple(inputs)) -exp_program2 = torch.export.export(model2, tuple(inputs)) - - -serialized_engine = convert_module_to_trt_engine( - exported_program=exp_program, - inputs=tuple(inputs), - enabled_precisions=enabled_precisions, - debug=debug, - min_block_size=min_block_size, - refit=True, -) - -trt_module = PythonTorchTensorRTModule( - engine=serialized_engine, - input_names=["x"], - output_names=["output0"], - target_device=Device._current_device(), - profiling_enabled=False, -) - -output = trt_module.forward(*inputs) -print(output[0].sum().cpu().item()) -engine = trt_module.engine
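# The standalone flow in this deleted demo (convert_module_to_trt_engine +
# refit_single_trt_engine_with_ep) is superseded by refit_module_weights. A
# minimal sketch of the module-level workflow that replaces it (assumes a
# CUDA-enabled build; verify_output is the gated check added in PATCH 21):

import torch
import torch_tensorrt
import torchvision.models as models
from torch_tensorrt.dynamo._refit import refit_module_weights

inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]
exp_program = torch.export.export(
    models.resnet18(pretrained=False).eval().to("cuda"), tuple(inputs)
)
exp_program2 = torch.export.export(
    models.resnet18(pretrained=True).eval().to("cuda"), tuple(inputs)
)

# refit=True asks TensorRT to build a refittable engine up front
trt_gm = torch_tensorrt.dynamo.compile(
    exp_program, tuple(inputs), min_block_size=1, refit=True
)

# Swap in the pretrained weights without rebuilding the engine; outputs are
# checked against the new module unless verify_output=False
new_trt_gm = refit_module_weights(
    compiled_module=trt_gm,
    new_weight_module=exp_program2,
    inputs=tuple(inputs),
    verify_output=True,
)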
-print(model(*inputs)[0].sum().cpu().item()) - -# ----------------------Refitting------------------------------------ -# weights_to_be_fitted = model2.state_dict() - - -# refit_dict = { -# '[CONVOLUTION]-[aten_ops.convolution.default]-[/conv1/convolution] BIAS': weights_to_be_fitted['conv1.bias'] -# , -# '[SCALE]-[aten_ops._native_batch_norm_legit_no_training.default]-[/bn/_native_batch_norm_legit_no_training] SCALE': weights_to_be_fitted['bn.weight'] -# , -# '[SCALE]-[aten_ops._native_batch_norm_legit_no_training.default]-[/bn/_native_batch_norm_legit_no_training] SHIFT': weights_to_be_fitted['bn.bias'] -# , -# '[CONVOLUTION]-[aten_ops.convolution.default]-[/conv1/convolution] KERNEL': weights_to_be_fitted['conv1.weight'] -# } - - -refit_single_trt_engine_with_ep( - exported_program=exp_program2, # New - inputs=tuple(inputs), - engine=engine, # Old - enabled_precisions=enabled_precisions, - debug=debug, - min_block_size=min_block_size, -) - -output = trt_module.forward(*inputs) -print(output[0].sum().cpu().item()) -engine = trt_module.engine -pytorch_output = model2(*inputs)[0] -print(pytorch_output.sum().cpu().item()) -print((output - pytorch_output).mean()) -print() - - -# Iterate over all layers and print weights -# for layer_name in refitter.get_all_weights(): -# # Print kernel weights -# print_layer_weights(refitter, layer_name) From e3576fa001bcdc50915717d2073ece9d723b024f Mon Sep 17 00:00:00 2001 From: cehongwang Date: Thu, 13 Jun 2024 16:07:59 -0700 Subject: [PATCH 24/70] Added support to inline module w/ or w/o graph break --- py/torch_tensorrt/dynamo/_refit.py | 131 ++++++++++++++++++++--------- 1 file changed, 90 insertions(+), 41 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 3126363e60..361aa8563b 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -3,6 +3,7 @@ import collections.abc import copy import logging +import pickle from typing import Any, Sequence, Tuple import numpy as np @@ -12,6 +13,7 @@ from torch_tensorrt._enums import dtype from torch_tensorrt._Input import Input from torch_tensorrt.dynamo import partitioning +from torch_tensorrt.dynamo._exporter import inline_torch_modules from torch_tensorrt.dynamo.conversion import CompilationSettings from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( @@ -144,20 +146,40 @@ def refit_module_weights( """ Refit a compiled graph module with ExportedProgram """ + inline_module = False raw_inputs = copy.deepcopy(inputs) if isinstance(compiled_module, ExportedProgram): + inline_module = True compiled_module = compiled_module.module() compiled_module = copy.deepcopy(compiled_module) # Get the settings and check the setting to be uniform settings: Any = None - for name, submodule in compiled_module.named_children(): - if not isinstance(submodule, (PythonTorchTensorRTModule, TorchTensorRTModule)): - continue - if settings is not None: - assert settings == submodule.settings - settings = submodule.settings + if inline_module: + + # Obtain the settings + compiled_submodules = [ + (name.replace("_engine", ""), engine) + for name, engine in compiled_module.__dict__.items() + if "engine" in name + ] + encoded_settings = compiled_submodules[0][1].__getstate__()[0][7] + settings = get_settings(encoded_settings) + # Handle torch modules + compiled_submodules_map = dict(compiled_submodules) + for name, submodule in compiled_module.named_children(): + 
compiled_submodules_map[name] = submodule + + else: + for name, submodule in compiled_module.named_children(): + if not isinstance( + submodule, (PythonTorchTensorRTModule, TorchTensorRTModule) + ): + continue + if settings is not None: + assert settings == submodule.settings + settings = submodule.settings if settings.debug: set_log_level(logger.parent, logging.DEBUG) @@ -176,11 +198,11 @@ def refit_module_weights( new_weight_module = new_weight_module.run_decompositions( get_decompositions(settings.enable_experimental_decompositions) ) - gm = new_weight_module.module() - logger.debug("Input graph: " + str(gm.graph)) + new_gm = new_weight_module.module() + logger.debug("Input graph: " + str(new_gm.graph)) # Apply lowering on the graph module torch_inputs = get_torch_inputs(inputs, device) - gm = apply_lowering_passes(gm, torch_inputs) + new_gm = apply_lowering_passes(new_gm, torch_inputs) logger.info("Compilation Settings: %s\n", settings) @@ -190,8 +212,8 @@ def refit_module_weights( # If specified, try using the fast partitioner and fall back to the global one on failure if settings.use_fast_partitioner: try: - partitioned_module, supported_ops = partitioning.fast_partition( - gm, + new_partitioned_module, supported_ops = partitioning.fast_partition( + new_gm, verbose=settings.debug, min_block_size=settings.min_block_size, torch_executed_ops=settings.torch_executed_ops, @@ -207,42 +229,62 @@ def refit_module_weights( settings.use_fast_partitioner = False if not settings.use_fast_partitioner: - partitioned_module, supported_ops = partitioning.global_partition( - gm, + new_partitioned_module, supported_ops = partitioning.global_partition( + new_gm, verbose=settings.debug, min_block_size=settings.min_block_size, torch_executed_ops=settings.torch_executed_ops, ) + if inline_module: + # Preprocess the partitioned module to be in the same format as the inline module + inline_torch_modules(new_partitioned_module) + new_partitioned_module.delete_all_unused_submodules() + # Check whether two modules have the same subcomponents # 1. Check the number of partitions and name - assert [sm[0] for sm in partitioned_module.named_children()] == [ - sm[0] for sm in compiled_module.named_children() - ] + if inline_module: + assert {sm[0] for sm in new_partitioned_module.named_children()} == set( + compiled_submodules_map.keys() + ), "The compiled module is incompatible with the new module!" + else: + assert {sm[0] for sm in new_partitioned_module.named_children()} == { + sm[0] for sm in compiled_module.named_children() + }, "The compiled module is incompatible with the new module!" # 2. 
TODO: Check the hash of source fx.Graph and new fx.Graph # Iterate over all components that can be accelerated # Generate the corresponding TRT Module for those - for name, new_submodule in partitioned_module.named_children(): - # Refit each submodule + for name, new_submodule in new_partitioned_module.named_children(): + # Refit each submodule # Extract engine from the submodule try: - compiled_submodule = getattr(compiled_module, name) - if isinstance(compiled_submodule, PythonTorchTensorRTModule): - engine = compiled_submodule.engine - elif isinstance(compiled_submodule, TorchTensorRTModule): - engine_info = compiled_submodule.engine.__getstate__()[0] - engine = get_engine_from_encoded_engine(engine_info[3], runtime) - elif isinstance(compiled_submodule, torch.fx.graph_module.GraphModule): - # This is graph break resulted by unsupported ops - compiled_submodule.load_state_dict(new_submodule.state_dict()) - continue + if inline_module: + compiled_submodule = compiled_submodules_map[name] + # If this is a torch module, load the old state_dict + if "_run_on_acc" not in name: + compiled_submodule.load_state_dict(new_submodule.state_dict()) + continue + else: + engine_info = compiled_submodule.__getstate__()[0] + engine = get_engine_from_encoded_engine(engine_info[3], runtime) else: - raise AssertionError( - "The type of graph module is not supported for refitting." - ) + compiled_submodule = getattr(compiled_module, name) + if isinstance(compiled_submodule, PythonTorchTensorRTModule): + engine = compiled_submodule.engine + elif isinstance(compiled_submodule, TorchTensorRTModule): + engine_info = compiled_submodule.engine.__getstate__()[0] + engine = get_engine_from_encoded_engine(engine_info[3], runtime) + elif isinstance(compiled_submodule, torch.fx.graph_module.GraphModule): + # This is graph break resulted by unsupported ops + compiled_submodule.load_state_dict(new_submodule.state_dict()) + continue + else: + raise AssertionError( + "The type of graph module is not supported for refitting." + ) except AttributeError: raise AssertionError( "The type of graph module is not supported for refitting or two compiled modules do not match." @@ -258,7 +300,7 @@ def refit_module_weights( # Handle long/double inputs if requested by the user if settings.truncate_double: submodule_inputs = repair_double_inputs( - partitioned_module, + new_partitioned_module, new_submodule, submodule_inputs, to_torch_device(settings.device), @@ -279,15 +321,16 @@ def refit_module_weights( refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info)) compiled_submodule.engine = refitted_engine - if verify_output: - check_output( - new_submodule=new_submodule, - compiled_submodule=compiled_submodule, - inputs=raw_inputs, - ) - logger.info("Refit Successful!") - else: - logger.info("Refit Completed! Output verification skipped.") + if verify_output: + check_output( + new_module=new_gm, + refitted_module=compiled_module, + inputs=raw_inputs, + ) + logger.info("Refit Successful!") + else: + logger.info("Refit Completed! 
Output verification skipped.") + return compiled_module @@ -301,3 +344,9 @@ def get_engine_from_encoded_engine( serialized_engine = base64.b64decode(encoded_engine) engine = runtime.deserialize_cuda_engine(serialized_engine) return engine + + +def get_settings(encoded_settings: bytes) -> Any: + dumped_settings = base64.b64decode(encoded_settings.encode("utf-8")) + settings = pickle.loads(dumped_settings) + return settings From cde8fe9b6b0263915e074a723950374b660f05c1 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Thu, 13 Jun 2024 16:08:34 -0700 Subject: [PATCH 25/70] Added an extra attribute to TRT engine in cpp --- core/runtime/TRTEngine.cpp | 6 +++++- core/runtime/TRTEngine.h | 4 +++- core/runtime/register_jit_hooks.cpp | 2 +- core/runtime/runtime.h | 1 + 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 023f54c113..0065585a72 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -33,6 +33,7 @@ TRTEngine::TRTEngine( const RTDevice& cuda_device, const std::vector& _in_binding_names, const std::vector& _out_binding_names, + const std::string& serialized_settings, bool hardware_compatible) : TRTEngine( "deserialized_trt", @@ -40,6 +41,7 @@ TRTEngine::TRTEngine( cuda_device, _in_binding_names, _out_binding_names, + serialized_settings, hardware_compatible) {} TRTEngine::TRTEngine(std::vector serialized_info) @@ -49,6 +51,7 @@ TRTEngine::TRTEngine(std::vector serialized_info) RTDevice(serialized_info[DEVICE_IDX]), split(serialized_info[INPUT_BINDING_NAMES_IDX], BINDING_DELIM), split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM), + serialized_info[SETTING_IDX], static_cast(std::stoi(serialized_info[HW_COMPATIBLE_IDX]))) {} TRTEngine::TRTEngine( @@ -57,9 +60,10 @@ TRTEngine::TRTEngine( const RTDevice& cuda_device, const std::vector& _in_binding_names, const std::vector& _out_binding_names, + const std::string& serialized_settings, bool hardware_compatible) { this->hardware_compatible = hardware_compatible; - + this->serialized_settings = serialized_settings; auto most_compatible_device = get_most_compatible_device(cuda_device, RTDevice(), hardware_compatible); TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine"); device_info = most_compatible_device.value(); diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 7960d04b46..c8f57041bd 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -35,13 +35,14 @@ struct TRTEngine : torch::CustomClassHolder { std::vector out_binding_names = {}; // ITO: PYT IDX bool hardware_compatible = false; // Whether the engine was compiled in hardware compatible mode - + std::string serialized_settings; ~TRTEngine(); TRTEngine( const std::string& serialized_engine, const RTDevice& cuda_device, const std::vector& in_binding_names, const std::vector& out_binding_names, + const std::string& serialized_settings, bool hardware_compatible = false); TRTEngine(std::vector serialized_info); TRTEngine( @@ -50,6 +51,7 @@ struct TRTEngine : torch::CustomClassHolder { const RTDevice& cuda_device, const std::vector& in_binding_names, const std::vector& out_binding_names, + const std::string& serialized_settings, bool hardware_compatible = false); TRTEngine& operator=(const TRTEngine& other); std::string to_str() const; diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 901923ce20..81277ee0a6 100644 --- a/core/runtime/register_jit_hooks.cpp +++ 
b/core/runtime/register_jit_hooks.cpp @@ -102,7 +102,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = serialize_info[INPUT_BINDING_NAMES_IDX] = serialize_bindings(self->in_binding_names); serialize_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings(self->out_binding_names); serialize_info[HW_COMPATIBLE_IDX] = self->hardware_compatible ? "1" : "0"; - + serialize_info[SETTING_IDX] = self->serialized_settings; LOG_DEBUG("Serialized Hardware Compatibility: " << (self->hardware_compatible ? "Enabled" : "Disabled")); return serialize_info; diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h index 8c9b33328a..20b2c7335e 100644 --- a/core/runtime/runtime.h +++ b/core/runtime/runtime.h @@ -25,6 +25,7 @@ typedef enum { INPUT_BINDING_NAMES_IDX, OUTPUT_BINDING_NAMES_IDX, HW_COMPATIBLE_IDX, + SETTING_IDX, SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO } SerializedInfoIndex; From b575105d4480d8c8fb2c7e7716260211caeec39a Mon Sep 17 00:00:00 2001 From: cehongwang Date: Thu, 13 Jun 2024 16:11:31 -0700 Subject: [PATCH 26/70] Added an attribute in TorchTRTModule in python. --- py/torch_tensorrt/dynamo/_refit.py | 4 ++-- .../dynamo/runtime/_TorchTensorRTModule.py | 12 ++++++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 361aa8563b..e8fa335b6b 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -339,14 +339,14 @@ def refit_module_weights( def get_engine_from_encoded_engine( - encoded_engine: bytes, runtime: trt.Runtime + encoded_engine: str, runtime: trt.Runtime ) -> trt.ICudaEngine: serialized_engine = base64.b64decode(encoded_engine) engine = runtime.deserialize_cuda_engine(serialized_engine) return engine -def get_settings(encoded_settings: bytes) -> Any: +def get_settings(encoded_settings: str) -> Any: dumped_settings = base64.b64decode(encoded_settings.encode("utf-8")) settings = pickle.loads(dumped_settings) return settings diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 1302f47e3f..0a67e12cc7 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -1,6 +1,8 @@ from __future__ import annotations +import base64 import logging +import pickle from typing import Any, List, Optional, Tuple import torch @@ -9,7 +11,7 @@ logger = logging.getLogger(__name__) SerializedTensorRTEngineFmt = Tuple[ - str, str, str, bytes, str, str, str + str, str, str, bytes, str, str, bytes, str ] # Defined in //core/runtime/register_jit_hooks.cpp SerializedTorchTensorRTModuleFmt = Tuple[ str, Optional[SerializedTensorRTEngineFmt], List[str], List[str] @@ -103,11 +105,17 @@ def __init__( TorchTensorRTModule._pack_binding_names(self.input_binding_names), TorchTensorRTModule._pack_binding_names(self.output_binding_names), str(int(hardware_compatible)), + self.encode_settings(settings), ] ) else: self.engine = None + def encode_settings(self, settings: Any) -> str: + dumped_settings = pickle.dumps(settings) + encoded_settings = base64.b64encode(dumped_settings).decode("utf-8") + return encoded_settings + def get_extra_state(self) -> SerializedTorchTensorRTModuleFmt: return ( self.name, @@ -120,7 +128,6 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None: self.name = state[0] if state[1] is not None: serialized_engine_info: 
SerializedTensorRTEngineFmt = state[1]
-            import base64
 
             serialized_engine = base64.b64decode(serialized_engine_info[3])
             self.engine = torch.classes.tensorrt.Engine(
@@ -132,6 +139,7 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None:
                     serialized_engine_info[4],
                     serialized_engine_info[5],
                     serialized_engine_info[6],
+                    serialized_engine_info[7],
                 ]
             )
         else:

From 9923125ce1948b72c552d7dfb5f42eb0efabd16a Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Thu, 13 Jun 2024 16:13:31 -0700
Subject: [PATCH 27/70] Fixed a typo

---
 py/torch_tensorrt/dynamo/utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index 89caaf55d0..36d208c52f 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -322,8 +322,8 @@ def check_output(
     # inputs = [t.contiguous() for t in inputs]
     old_outputs, new_outputs = refitted_module(*inputs), new_module(*inputs)
     for old_output, new_output in zip(old_outputs, new_outputs):
-        if isinstance(old_output, torch.tensor) and isinstance(
-            new_outputs, torch.tensor
+        if isinstance(old_output, torch.Tensor) and isinstance(
+            new_output, torch.Tensor
         ):
             assert torch.allclose(
                 old_output, new_output, 1e-2, 1e-2

From e25941e6cc2c7a910d73e99f4356e17f41a6568c Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Thu, 13 Jun 2024 17:32:22 -0700
Subject: [PATCH 28/70] Fixed a bug for inline_module refit

---
 py/torch_tensorrt/dynamo/_refit.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py
index e8fa335b6b..f9c5390347 100644
--- a/py/torch_tensorrt/dynamo/_refit.py
+++ b/py/torch_tensorrt/dynamo/_refit.py
@@ -321,13 +321,20 @@ def refit_module_weights(
         refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info))
         compiled_submodule.engine = refitted_engine
 
+    elif inline_module:
+        serialized_engine = bytes(engine.serialize())
+        new_engine_info = list(engine_info)
+        new_engine_info[3] = serialized_engine
+        refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info))
+        setattr(compiled_module, f"{name}_engine", refitted_engine)
+
     if verify_output:
         check_output(
             new_module=new_gm,
             refitted_module=compiled_module,
             inputs=raw_inputs,
         )
         logger.info("Refit Successful!")
     else:
         logger.info("Refit Completed! Output verification skipped.")

From 646da9e5607d4bea8faea657e81706b8eca2717d Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Thu, 13 Jun 2024 17:41:34 -0700
Subject: [PATCH 29/70] Added refit example documentation

---
 examples/dynamo/refit_engine_example.py | 136 ++++++++++++++++++++++++
 1 file changed, 136 insertions(+)
 create mode 100644 examples/dynamo/refit_engine_example.py

diff --git a/examples/dynamo/refit_engine_example.py b/examples/dynamo/refit_engine_example.py
new file mode 100644
index 0000000000..a6b41f895c
--- /dev/null
+++ b/examples/dynamo/refit_engine_example.py
@@ -0,0 +1,136 @@
+"""
+.. _refit_engine_example:
+
+Refit TensorRT Graph Module with Torch-TensorRT
+===================================================================
+
+We are going to demonstrate how a compiled TensorRT Graph Module can be refitted with updated weights.
+
+In many cases, we frequently update the weights of models, such as applying various LoRA to Stable Diffusion or constant A/B testing of AI products.
+That poses challenges for TensorRT inference optimizations, as compiling the TensorRT engines takes significant time, making repetitive compilation highly inefficient.
+Torch-TensorRT supports refitting TensorRT graph modules without re-compiling the engine, considerably accelerating the workflow.
+
+In this tutorial, we are going to walk through
+1. Compiling a PyTorch model to a TensorRT Graph Module
+2. Saving and loading a graph module
+3. Refitting the graph module with updated weights
+"""
+
+# %%
+# Standard Workflow
+# -----------------------------
+
+# %%
+# Imports and model definition
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch_tensorrt as trt
+import torchvision.models as models
+from torch import nn
+from torch_tensorrt.dynamo._refit import refit_module_weights
+
+np.random.seed(0)
+torch.manual_seed(0)
+inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]
+
+
+# %%
+# Compile the module for the first time and save it.
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+model = models.resnet18(pretrained=False).eval().to("cuda")
+exp_program = torch.export.export(model, tuple(inputs))
+enabled_precisions = {torch.float}
+debug = False
+workspace_size = 20 << 30
+min_block_size = 0
+use_python_runtime = False
+torch_executed_ops = set()
+trt_gm = trt.dynamo.compile(
+    exp_program,
+    tuple(inputs),
+    use_python_runtime=use_python_runtime,
+    enabled_precisions=enabled_precisions,
+    debug=debug,
+    min_block_size=min_block_size,
+    torch_executed_ops=torch_executed_ops,
+    refit=True,
+)  # Output is a torch.fx.GraphModule
+
+# Save the graph module as an exported program
+trt.save(trt_gm, "./compiled.ep", inputs=inputs)
+
+
+# %%
+# Refit the module with updated model weights
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+# Create and compile the updated model
+model2 = models.resnet18(pretrained=True).eval().to("cuda")
+exp_program2 = torch.export.export(model2, tuple(inputs))
+
+
+compiled_trt_gm = trt.load("./compiled.ep")
+
+# This returns a new module with updated weights
+new_trt_gm = refit_module_weights(
+    compiled_module=compiled_trt_gm,
+    new_weight_module=exp_program2,
+    inputs=inputs,
+)
+
+# Check the output
+expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(*inputs)
+for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
+    assert torch.allclose(
+        expected_output, refitted_output, 1e-2, 1e-2
+    ), "Refit Result is not correct. Refit failed"
+
+print("Refit succeeded!")
+
+# %%
+# Alternative Workflow using Python Runtime
+# -----------------------------
+
+# Currently python runtime does not support engine serialization. So the refitt will be done in the same runtime.
+# This usecase is more useful when you need to switch different weights during runtime, such as using Stable Diffusion.
+ +model = models.resnet18(pretrained=False).eval().to("cuda") +exp_program = torch.export.export(model, tuple(inputs)) +enabled_precisions = {torch.float} +debug = False +workspace_size = 20 << 30 +min_block_size = 0 +use_python_runtime = True +torch_executed_ops = {} +trt_gm = trt.dynamo.compile( + exp_program, + tuple(inputs), + use_python_runtime=use_python_runtime, + enabled_precisions=enabled_precisions, + debug=debug, + min_block_size=min_block_size, + torch_executed_ops=torch_executed_ops, + refit=True, +) + +model2 = models.resnet18(pretrained=True).eval().to("cuda") +exp_program2 = torch.export.export(model2, tuple(inputs)) + +new_trt_gm = refit_module_weights( + compiled_module=compiled_trt_gm, + new_weight_module=exp_program2, + inputs=inputs, +) + +# Check the output +expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(*inputs) +for expected_output, refitted_output in zip(expected_outputs, refitted_outputs): + assert torch.allclose( + expected_output, refitted_output, 1e-2, 1e-2 + ), "Refit Result is not correct. Refit failed" + +print("Refit successfully!") From 924f4a8fd0c1ed3fa633ffdf44af508d7b755b93 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Thu, 13 Jun 2024 18:19:32 -0700 Subject: [PATCH 30/70] Added backward compatibility --- core/runtime/TRTEngine.cpp | 51 ++++++++++++++++++++++++++++++++++++-- core/runtime/TRTEngine.h | 14 +++++++++++ 2 files changed, 63 insertions(+), 2 deletions(-) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 0065585a72..a330abceb6 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -28,6 +28,21 @@ std::vector split(const std::string& str, char delim) { return strings; } +TRTEngine::TRTEngine( + const std::string& serialized_engine, + const RTDevice& cuda_device, + const std::vector& _in_binding_names, + const std::vector& _out_binding_names, + bool hardware_compatible) + : TRTEngine( + "deserialized_trt", + serialized_engine, + cuda_device, + _in_binding_names, + _out_binding_names, + "", + hardware_compatible) {} + TRTEngine::TRTEngine( const std::string& serialized_engine, const RTDevice& cuda_device, @@ -44,7 +59,7 @@ TRTEngine::TRTEngine( serialized_settings, hardware_compatible) {} -TRTEngine::TRTEngine(std::vector serialized_info) +TRTEngine::TRTEngine(std::vector serialized_info) try : TRTEngine( serialized_info[NAME_IDX], serialized_info[ENGINE_IDX], @@ -52,7 +67,27 @@ TRTEngine::TRTEngine(std::vector serialized_info) split(serialized_info[INPUT_BINDING_NAMES_IDX], BINDING_DELIM), split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM), serialized_info[SETTING_IDX], - static_cast(std::stoi(serialized_info[HW_COMPATIBLE_IDX]))) {} + static_cast(std::stoi(serialized_info[HW_COMPATIBLE_IDX]))) { +} catch (const std::exception& e) { + std::cerr << "No compilation settings is passed in" << std::endl; + handleConstructorError(serialized_info); +} + +TRTEngine::TRTEngine( + const std::string& mod_name, + const std::string& serialized_engine, + const RTDevice& cuda_device, + const std::vector& _in_binding_names, + const std::vector& _out_binding_names, + bool hardware_compatible) + : TRTEngine( + mod_name, + serialized_engine, + cuda_device, + _in_binding_names, + _out_binding_names, + "", + hardware_compatible) {} TRTEngine::TRTEngine( const std::string& mod_name, @@ -190,6 +225,18 @@ TRTEngine::TRTEngine( LOG_DEBUG(*this); } +void TRTEngine::handleConstructorError(std::vector serialized_info) { + // Delegate to the fallback constructor + *this = 
TRTEngine( + serialized_info[NAME_IDX], + serialized_info[ENGINE_IDX], + RTDevice(serialized_info[DEVICE_IDX]), + split(serialized_info[INPUT_BINDING_NAMES_IDX], BINDING_DELIM), + split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM), + "", + static_cast(std::stoi(serialized_info[HW_COMPATIBLE_IDX]))); +} + TRTEngine::~TRTEngine() { trt_engine_profiler.reset(); exec_ctx.reset(); diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index c8f57041bd..8520f621ba 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -37,6 +37,12 @@ struct TRTEngine : torch::CustomClassHolder { bool hardware_compatible = false; // Whether the engine was compiled in hardware compatible mode std::string serialized_settings; ~TRTEngine(); + TRTEngine( + const std::string& serialized_engine, + const RTDevice& cuda_device, + const std::vector& in_binding_names, + const std::vector& out_binding_names, + bool hardware_compatible = false); TRTEngine( const std::string& serialized_engine, const RTDevice& cuda_device, @@ -45,6 +51,13 @@ struct TRTEngine : torch::CustomClassHolder { const std::string& serialized_settings, bool hardware_compatible = false); TRTEngine(std::vector serialized_info); + TRTEngine( + const std::string& mod_name, + const std::string& serialized_engine, + const RTDevice& cuda_device, + const std::vector& in_binding_names, + const std::vector& out_binding_names, + bool hardware_compatible = false); TRTEngine( const std::string& mod_name, const std::string& serialized_engine, @@ -53,6 +66,7 @@ struct TRTEngine : torch::CustomClassHolder { const std::vector& out_binding_names, const std::string& serialized_settings, bool hardware_compatible = false); + void handleConstructorError(std::vector serialized_info); TRTEngine& operator=(const TRTEngine& other); std::string to_str() const; static void verify_serialization_fmt(const std::vector& serialized_info); From 3c25a3a4138fafdf303b3a8c1e71da5b6dd0992a Mon Sep 17 00:00:00 2001 From: cehongwang Date: Thu, 13 Jun 2024 18:47:15 -0700 Subject: [PATCH 31/70] Rename the setting enum --- core/runtime/TRTEngine.cpp | 2 +- core/runtime/register_jit_hooks.cpp | 2 +- core/runtime/runtime.h | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index a330abceb6..f1d5a33fe3 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -66,7 +66,7 @@ TRTEngine::TRTEngine(std::vector serialized_info) try RTDevice(serialized_info[DEVICE_IDX]), split(serialized_info[INPUT_BINDING_NAMES_IDX], BINDING_DELIM), split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM), - serialized_info[SETTING_IDX], + serialized_info[SERIALIZED_COMPILE_SETTINGS_IDX], static_cast(std::stoi(serialized_info[HW_COMPATIBLE_IDX]))) { } catch (const std::exception& e) { std::cerr << "No compilation settings is passed in" << std::endl; diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 81277ee0a6..be0b5216d0 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -102,7 +102,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = serialize_info[INPUT_BINDING_NAMES_IDX] = serialize_bindings(self->in_binding_names); serialize_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings(self->out_binding_names); serialize_info[HW_COMPATIBLE_IDX] = self->hardware_compatible ? 
"1" : "0"; - serialize_info[SETTING_IDX] = self->serialized_settings; + serialize_info[SERIALIZED_COMPILE_SETTINGS_IDX] = self->serialized_settings; LOG_DEBUG("Serialized Hardware Compatibility: " << (self->hardware_compatible ? "Enabled" : "Disabled")); return serialize_info; diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h index 20b2c7335e..4ac755151e 100644 --- a/core/runtime/runtime.h +++ b/core/runtime/runtime.h @@ -25,7 +25,7 @@ typedef enum { INPUT_BINDING_NAMES_IDX, OUTPUT_BINDING_NAMES_IDX, HW_COMPATIBLE_IDX, - SETTING_IDX, + SERIALIZED_COMPILE_SETTINGS_IDX, SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO } SerializedInfoIndex; From 0c9637dfdbd7acc87d383fc6b778c3284f3bed69 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Fri, 14 Jun 2024 10:43:59 -0700 Subject: [PATCH 32/70] Cleaned up cpp constructors --- core/runtime/TRTEngine.cpp | 65 ++++++-------------------------------- core/runtime/TRTEngine.h | 22 +++---------- 2 files changed, 13 insertions(+), 74 deletions(-) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index f1d5a33fe3..168058a927 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -33,45 +33,26 @@ TRTEngine::TRTEngine( const RTDevice& cuda_device, const std::vector& _in_binding_names, const std::vector& _out_binding_names, - bool hardware_compatible) + bool hardware_compatible, + const std::string& serialized_settings) : TRTEngine( "deserialized_trt", serialized_engine, cuda_device, _in_binding_names, _out_binding_names, - "", - hardware_compatible) {} + hardware_compatible, + serialized_settings) {} -TRTEngine::TRTEngine( - const std::string& serialized_engine, - const RTDevice& cuda_device, - const std::vector& _in_binding_names, - const std::vector& _out_binding_names, - const std::string& serialized_settings, - bool hardware_compatible) - : TRTEngine( - "deserialized_trt", - serialized_engine, - cuda_device, - _in_binding_names, - _out_binding_names, - serialized_settings, - hardware_compatible) {} - -TRTEngine::TRTEngine(std::vector serialized_info) try +TRTEngine::TRTEngine(std::vector serialized_info) : TRTEngine( serialized_info[NAME_IDX], serialized_info[ENGINE_IDX], RTDevice(serialized_info[DEVICE_IDX]), split(serialized_info[INPUT_BINDING_NAMES_IDX], BINDING_DELIM), split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM), - serialized_info[SERIALIZED_COMPILE_SETTINGS_IDX], - static_cast(std::stoi(serialized_info[HW_COMPATIBLE_IDX]))) { -} catch (const std::exception& e) { - std::cerr << "No compilation settings is passed in" << std::endl; - handleConstructorError(serialized_info); -} + static_cast(std::stoi(serialized_info[HW_COMPATIBLE_IDX])), + serialized_info[SERIALIZED_COMPILE_SETTINGS_IDX]) {} TRTEngine::TRTEngine( const std::string& mod_name, @@ -79,24 +60,8 @@ TRTEngine::TRTEngine( const RTDevice& cuda_device, const std::vector& _in_binding_names, const std::vector& _out_binding_names, - bool hardware_compatible) - : TRTEngine( - mod_name, - serialized_engine, - cuda_device, - _in_binding_names, - _out_binding_names, - "", - hardware_compatible) {} - -TRTEngine::TRTEngine( - const std::string& mod_name, - const std::string& serialized_engine, - const RTDevice& cuda_device, - const std::vector& _in_binding_names, - const std::vector& _out_binding_names, - const std::string& serialized_settings, - bool hardware_compatible) { + bool hardware_compatible, + const std::string& serialized_settings) { this->hardware_compatible = hardware_compatible; 
this->serialized_settings = serialized_settings; auto most_compatible_device = get_most_compatible_device(cuda_device, RTDevice(), hardware_compatible); @@ -225,18 +190,6 @@ TRTEngine::TRTEngine( LOG_DEBUG(*this); } -void TRTEngine::handleConstructorError(std::vector serialized_info) { - // Delegate to the fallback constructor - *this = TRTEngine( - serialized_info[NAME_IDX], - serialized_info[ENGINE_IDX], - RTDevice(serialized_info[DEVICE_IDX]), - split(serialized_info[INPUT_BINDING_NAMES_IDX], BINDING_DELIM), - split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM), - "", - static_cast(std::stoi(serialized_info[HW_COMPATIBLE_IDX]))); -} - TRTEngine::~TRTEngine() { trt_engine_profiler.reset(); exec_ctx.reset(); diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 8520f621ba..5fe734f332 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -42,14 +42,8 @@ struct TRTEngine : torch::CustomClassHolder { const RTDevice& cuda_device, const std::vector& in_binding_names, const std::vector& out_binding_names, - bool hardware_compatible = false); - TRTEngine( - const std::string& serialized_engine, - const RTDevice& cuda_device, - const std::vector& in_binding_names, - const std::vector& out_binding_names, - const std::string& serialized_settings, - bool hardware_compatible = false); + bool hardware_compatible = false, + const std::string& serialized_settings = ""); TRTEngine(std::vector serialized_info); TRTEngine( const std::string& mod_name, @@ -57,16 +51,8 @@ struct TRTEngine : torch::CustomClassHolder { const RTDevice& cuda_device, const std::vector& in_binding_names, const std::vector& out_binding_names, - bool hardware_compatible = false); - TRTEngine( - const std::string& mod_name, - const std::string& serialized_engine, - const RTDevice& cuda_device, - const std::vector& in_binding_names, - const std::vector& out_binding_names, - const std::string& serialized_settings, - bool hardware_compatible = false); - void handleConstructorError(std::vector serialized_info); + bool hardware_compatible = false, + const std::string& serialized_settings = ""); TRTEngine& operator=(const TRTEngine& other); std::string to_str() const; static void verify_serialization_fmt(const std::vector& serialized_info); From bb5fdba5a49f6c871a12b8b6b2e664c222974779 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Fri, 14 Jun 2024 10:46:06 -0700 Subject: [PATCH 33/70] Fixed a type of setting storage checking --- py/torch_tensorrt/dynamo/_refit.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index f9c5390347..f00182aa92 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -165,6 +165,9 @@ def refit_module_weights( if "engine" in name ] encoded_settings = compiled_submodules[0][1].__getstate__()[0][7] + assert ( + encoded_settings != "" + ), "Settings are not saved in the engine. Please recompile the engine." 
settings = get_settings(encoded_settings) # Handle torch modules compiled_submodules_map = dict(compiled_submodules) From e47bcb25e38a6d62b4208f5be40c45c22451a12c Mon Sep 17 00:00:00 2001 From: cehongwang Date: Fri, 14 Jun 2024 13:23:11 -0700 Subject: [PATCH 34/70] Renamed settings to metadata --- core/runtime/TRTEngine.cpp | 10 +++++----- core/runtime/TRTEngine.h | 8 +++++--- core/runtime/register_jit_hooks.cpp | 2 +- core/runtime/runtime.h | 2 +- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 168058a927..6e6080a353 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -34,7 +34,7 @@ TRTEngine::TRTEngine( const std::vector& _in_binding_names, const std::vector& _out_binding_names, bool hardware_compatible, - const std::string& serialized_settings) + const std::string& serialized_metadata) : TRTEngine( "deserialized_trt", serialized_engine, @@ -42,7 +42,7 @@ TRTEngine::TRTEngine( _in_binding_names, _out_binding_names, hardware_compatible, - serialized_settings) {} + serialized_metadata) {} TRTEngine::TRTEngine(std::vector serialized_info) : TRTEngine( @@ -52,7 +52,7 @@ TRTEngine::TRTEngine(std::vector serialized_info) split(serialized_info[INPUT_BINDING_NAMES_IDX], BINDING_DELIM), split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM), static_cast(std::stoi(serialized_info[HW_COMPATIBLE_IDX])), - serialized_info[SERIALIZED_COMPILE_SETTINGS_IDX]) {} + serialized_info[SERIALIZED_METADATA_IDX]) {} TRTEngine::TRTEngine( const std::string& mod_name, @@ -61,9 +61,9 @@ TRTEngine::TRTEngine( const std::vector& _in_binding_names, const std::vector& _out_binding_names, bool hardware_compatible, - const std::string& serialized_settings) { + const std::string& serialized_metadata) { this->hardware_compatible = hardware_compatible; - this->serialized_settings = serialized_settings; + this->serialized_metadata = serialized_metadata; auto most_compatible_device = get_most_compatible_device(cuda_device, RTDevice(), hardware_compatible); TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine"); device_info = most_compatible_device.value(); diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 5fe734f332..af6bdcec6f 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -35,7 +35,9 @@ struct TRTEngine : torch::CustomClassHolder { std::vector out_binding_names = {}; // ITO: PYT IDX bool hardware_compatible = false; // Whether the engine was compiled in hardware compatible mode - std::string serialized_settings; + std::string serialized_metadata; // This is a base64 encoded pkl object used to store metadata such as settings used + // in compilation + ~TRTEngine(); TRTEngine( const std::string& serialized_engine, @@ -43,7 +45,7 @@ struct TRTEngine : torch::CustomClassHolder { const std::vector& in_binding_names, const std::vector& out_binding_names, bool hardware_compatible = false, - const std::string& serialized_settings = ""); + const std::string& serialized_metadata = ""); TRTEngine(std::vector serialized_info); TRTEngine( const std::string& mod_name, @@ -52,7 +54,7 @@ struct TRTEngine : torch::CustomClassHolder { const std::vector& in_binding_names, const std::vector& out_binding_names, bool hardware_compatible = false, - const std::string& serialized_settings = ""); + const std::string& serialized_metadata = ""); TRTEngine& operator=(const TRTEngine& other); std::string to_str() const; static void 
verify_serialization_fmt(const std::vector& serialized_info);
diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp
index be0b5216d0..dd82839642 100644
--- a/core/runtime/register_jit_hooks.cpp
+++ b/core/runtime/register_jit_hooks.cpp
@@ -102,7 +102,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
       serialize_info[INPUT_BINDING_NAMES_IDX] = serialize_bindings(self->in_binding_names);
       serialize_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings(self->out_binding_names);
       serialize_info[HW_COMPATIBLE_IDX] = self->hardware_compatible ? "1" : "0";
-      serialize_info[SERIALIZED_COMPILE_SETTINGS_IDX] = self->serialized_settings;
+      serialize_info[SERIALIZED_METADATA_IDX] = self->serialized_metadata;
       LOG_DEBUG("Serialized Hardware Compatibility: " << (self->hardware_compatible ? "Enabled" : "Disabled"));
 
       return serialize_info;
diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h
index 4ac755151e..e48357503d 100644
--- a/core/runtime/runtime.h
+++ b/core/runtime/runtime.h
@@ -25,7 +25,7 @@ typedef enum {
   INPUT_BINDING_NAMES_IDX,
   OUTPUT_BINDING_NAMES_IDX,
   HW_COMPATIBLE_IDX,
-  SERIALIZED_COMPILE_SETTINGS_IDX,
+  SERIALIZED_METADATA_IDX,
   SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO
 } SerializedInfoIndex;

From e6e71ca802af054acbc60b2b1028d4acf5c7bb3b Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Fri, 14 Jun 2024 13:26:41 -0700
Subject: [PATCH 35/70] Added refit to __init__ of dynamo

---
 py/torch_tensorrt/dynamo/__init__.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/__init__.py
index 335faa7007..83597db0b6 100644
--- a/py/torch_tensorrt/dynamo/__init__.py
+++ b/py/torch_tensorrt/dynamo/__init__.py
@@ -9,6 +9,7 @@ if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
     from ._compiler import compile, convert_module_to_trt_engine
     from ._exporter import export
+    from ._refit import refit_module_weights
     from ._settings import CompilationSettings
     from ._SourceIR import SourceIR
     from ._tracer import trace

From cf43a79e280db871a15e7f17025a87011ad9bc6e Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Fri, 14 Jun 2024 14:01:09 -0700
Subject: [PATCH 36/70] Added docstring. Added support for dynamic shape

---
 py/torch_tensorrt/dynamo/_refit.py | 27 ++++++++++++++++++++++-----
 1 file changed, 22 insertions(+), 5 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py
index f00182aa92..300e69f036 100644
--- a/py/torch_tensorrt/dynamo/_refit.py
+++ b/py/torch_tensorrt/dynamo/_refit.py
@@ -145,9 +145,24 @@ def refit_module_weights(
 ) -> torch.fx.GraphModule:
     """
     Refit a compiled graph module with ExportedProgram
+    Args:
+        compiled_module: compiled TensorRT module that needs to be refitted.
+            This compiled_module should be compiled by torch_tensorrt.dynamo.compile
+            or loaded from disk using trt.load.
+        new_weight_module: exported program with the updated weights
+        inputs: sample inputs
+        verify_output: whether to verify output of refitted module
+    Returns:
+        A new compiled TensorRT module that has the updated weights.
""" inline_module = False - raw_inputs = copy.deepcopy(inputs) + raw_inputs = [] + if verify_output: + for inp in inputs: + if isinstance(inp, torch.Tensor): + raw_inputs.append(copy.deepcopy(inp)) + elif isinstance(inp, Input): + raw_inputs.append(copy.deepcopy(input.torch_tensor)) if isinstance(compiled_module, ExportedProgram): inline_module = True compiled_module = compiled_module.module() @@ -332,14 +347,16 @@ def refit_module_weights( setattr(compiled_module, f"{name}_engine", refitted_engine) if verify_output: - check_output( + if check_output( new_module=new_gm, refitted_module=compiled_module, inputs=raw_inputs, - ) - logger.info("Refit Successfully!") + ): + logger.info("Refitting Succeed!") + else: + logger.error("Refitting Failed! The outputs do not match.") else: - logger.info("Refit Completed! Output verification skipped.") + logger.info("Refitting Completed! Output verification skipped.") return compiled_module From cfeb6bf7a1e4e1c489f0859283ab779b6173cf70 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Fri, 14 Jun 2024 14:01:58 -0700 Subject: [PATCH 37/70] Chagned the check_output function to return a boolean --- py/torch_tensorrt/dynamo/utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 36d208c52f..fd6fcf72f2 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -318,13 +318,13 @@ def check_output( new_module: torch.fx.GraphModule, refitted_module: torch.fx.GraphModule, inputs: tuple[Any, ...], -) -> None: - # inputs = [t.contiguous() for t in inputs] +) -> bool: old_outputs, new_outputs = refitted_module(*inputs), new_module(*inputs) for old_output, new_output in zip(old_outputs, new_outputs): if isinstance(old_output, torch.Tensor) and isinstance( new_outputs, torch.Tensor ): - assert torch.allclose( - old_output, new_output, 1e-2, 1e-2 - ), "Refit Result is not correct. Refit failed" + if not torch.allclose(old_output, new_output, 1e-2, 1e-2): + return False + + return True From a092229771d478bb42e9cb09676ba57e5dec802e Mon Sep 17 00:00:00 2001 From: cehongwang Date: Fri, 14 Jun 2024 14:20:23 -0700 Subject: [PATCH 38/70] Chagned get_settings to a static method in TorchTensorRTModule --- py/torch_tensorrt/dynamo/_refit.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 300e69f036..830b8c4700 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -3,7 +3,6 @@ import collections.abc import copy import logging -import pickle from typing import Any, Sequence, Tuple import numpy as np @@ -144,12 +143,13 @@ def refit_module_weights( verify_output: bool = True, ) -> torch.fx.GraphModule: """ - Refit a compiled graph module with ExportedProgram + Refit a compiled graph module with ExportedProgram. This performs weight updates in compiled_module without recompiling the engine. + Args: compiled_module: compiled TensorRT module that needs to be refitted. This compiled_module should be compmiled by torch_tensorrt.dynamo.compile or load it from disk using trt.load. - new_weight_module: exported program with the updated weights + new_weight_module: exported program with the updated weights. This one should have the same model architecture as the compiled module. 
inputs: sample inputs verify_output: whether to verify output of refitted module Returns: @@ -183,7 +183,7 @@ def refit_module_weights( assert ( encoded_settings != "" ), "Settings are not saved in the engine. Please recompile the engine." - settings = get_settings(encoded_settings) + settings = TorchTensorRTModule.decode_metadata(encoded_settings) # Handle torch modules compiled_submodules_map = dict(compiled_submodules) for name, submodule in compiled_module.named_children(): @@ -371,9 +371,3 @@ def get_engine_from_encoded_engine( serialized_engine = base64.b64decode(encoded_engine) engine = runtime.deserialize_cuda_engine(serialized_engine) return engine - - -def get_settings(encoded_settings: str) -> Any: - dumped_settings = base64.b64decode(encoded_settings.encode("utf-8")) - settings = pickle.loads(dumped_settings) - return settings From 4819a6d69bf05784a45b112ed69f102c714ab1db Mon Sep 17 00:00:00 2001 From: cehongwang Date: Fri, 14 Jun 2024 15:24:00 -0700 Subject: [PATCH 39/70] Simplified the code --- py/torch_tensorrt/dynamo/_refit.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 830b8c4700..1046a7896f 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -156,13 +156,6 @@ def refit_module_weights( A new compiled TensorRT module that has the updated weights. """ inline_module = False - raw_inputs = [] - if verify_output: - for inp in inputs: - if isinstance(inp, torch.Tensor): - raw_inputs.append(copy.deepcopy(inp)) - elif isinstance(inp, Input): - raw_inputs.append(copy.deepcopy(input.torch_tensor)) if isinstance(compiled_module, ExportedProgram): inline_module = True compiled_module = compiled_module.module() @@ -350,7 +343,7 @@ def refit_module_weights( if check_output( new_module=new_gm, refitted_module=compiled_module, - inputs=raw_inputs, + inputs=get_torch_inputs(inputs, device), ): logger.info("Refitting Succeed!") else: From bd77f22f3358afb05d0f113ac0a6b1b56b2323b4 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Fri, 14 Jun 2024 16:45:44 -0700 Subject: [PATCH 40/70] Added three testcases --- tests/py/dynamo/models/test_model_refit.py | 179 +++++++++++++++++++++ 1 file changed, 179 insertions(+) create mode 100644 tests/py/dynamo/models/test_model_refit.py diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py new file mode 100644 index 0000000000..9dfc7207ce --- /dev/null +++ b/tests/py/dynamo/models/test_model_refit.py @@ -0,0 +1,179 @@ +import time +import unittest + +import numpy as np +import pytest +import tensorrt as trt +import torch +import torch.nn.functional as F +import torch_tensorrt as torchtrt +import torchvision.models as models +from torch import nn + +# from torch import nn +from torch_tensorrt.dynamo import refit_module_weights +from torch_tensorrt.dynamo._refit import ( + construct_refit_mapping, + get_engine_from_encoded_engine, +) +from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions +from torch_tensorrt.logging import TRT_LOGGER + +assertions = unittest.TestCase() + + +@pytest.mark.unit +def test_mapping(): + + model = models.resnet18(pretrained=False).eval().to("cuda") + model2 = models.resnet18(pretrained=True).eval().to("cuda") + inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] + trt_input = [ + torchtrt.Input(i.shape, dtype=torch.float, format=torch.contiguous_format) + for i in inputs + ] + enabled_precisions = {torch.float} 
+ debug = False + min_block_size = 0 + use_python_runtime = False + + exp_program = torch.export.export(model, tuple(inputs)) + exp_program2 = torch.export.export(model2, tuple(inputs)) + + trt_gm = torchtrt.dynamo.compile( + exp_program, + tuple(inputs), + use_python_runtime=use_python_runtime, + enabled_precisions=enabled_precisions, + debug=debug, + min_block_size=min_block_size, + refit=True, + ) + settings = trt_gm._run_on_acc_0.settings + runtime = trt.Runtime(TRT_LOGGER) + + engine_info = trt_gm._run_on_acc_0.engine.__getstate__()[0] + engine = get_engine_from_encoded_engine(engine_info[3], runtime) + + exp_program2 = exp_program2.run_decompositions( + get_decompositions(settings.enable_experimental_decompositions) + ) + new_gm = exp_program2.module() + new_gm = apply_lowering_passes(new_gm, inputs) + mapping = construct_refit_mapping(new_gm, trt_input, settings) + + refitter = trt.Refitter(engine, TRT_LOGGER) + weight_list = refitter.get_all_weights() + for weight in weight_list: + assertions.assertTrue( + weight in mapping, + msg=f"Weight is not found in mapping. Test failed", + ) + # Clean up model env + torch._dynamo.reset() + + +@pytest.mark.unit +def test_refit_one_engine(): + + model = models.resnet18(pretrained=False).eval().to("cuda") + model2 = models.resnet18(pretrained=True).eval().to("cuda") + inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] + enabled_precisions = {torch.float} + debug = False + min_block_size = 0 + use_python_runtime = False + + exp_program = torch.export.export(model, tuple(inputs)) + exp_program2 = torch.export.export(model2, tuple(inputs)) + + trt_gm = torchtrt.dynamo.compile( + exp_program, + tuple(inputs), + use_python_runtime=use_python_runtime, + enabled_precisions=enabled_precisions, + debug=debug, + min_block_size=min_block_size, + refit=True, + ) + + new_trt_gm = refit_module_weights( + compiled_module=trt_gm, + new_weight_module=exp_program2, + inputs=inputs, + ) + + # Check the output + expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( + *inputs + ) + for expected_output, refitted_output in zip(expected_outputs, refitted_outputs): + assertions.assertTrue( + torch.allclose(expected_output, refitted_output, 1e-2, 1e-2), + "Refit Result is not correct. 
Refit failed", + ) + # Clean up model env + + torch._dynamo.reset() + + +@pytest.mark.unit +def test_refit_multiple_engine(): + + class net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 12, 3, padding=1) + self.bn = nn.BatchNorm2d(12) + self.conv2 = nn.Conv2d(12, 12, 3, padding=1) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.bn(x) + x = self.conv2(x) + x = F.relu(x) + return x + + model = net().eval().to("cuda") + model2 = net().eval().to("cuda") + + inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] + enabled_precisions = {torch.float} + debug = False + min_block_size = 0 + use_python_runtime = False + + exp_program = torch.export.export(model, tuple(inputs)) + exp_program2 = torch.export.export(model2, tuple(inputs)) + + torch_executed_ops = {torch.ops.aten.convolution.default} + trt_gm = torchtrt.dynamo.compile( + exp_program, + tuple(inputs), + use_python_runtime=use_python_runtime, + enabled_precisions=enabled_precisions, + debug=debug, + min_block_size=min_block_size, + refit=True, + torch_executed_ops=torch_executed_ops, + ) + + new_trt_gm = refit_module_weights( + compiled_module=trt_gm, + new_weight_module=exp_program2, + inputs=inputs, + ) + + # Check the output + expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( + *inputs + ) + for expected_output, refitted_output in zip(expected_outputs, refitted_outputs): + assertions.assertTrue( + torch.allclose(expected_output, refitted_output, 1e-2, 1e-2), + "Refit Result is not correct. Refit failed", + ) + # Clean up model env + + torch._dynamo.reset() From 1b3a769ac8d12f54bec461b209e82d9169da26f2 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Fri, 14 Jun 2024 16:46:15 -0700 Subject: [PATCH 41/70] Supported torch ops in settings --- .../dynamo/runtime/_TorchTensorRTModule.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 0a67e12cc7..2df6337efd 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -105,17 +105,27 @@ def __init__( TorchTensorRTModule._pack_binding_names(self.input_binding_names), TorchTensorRTModule._pack_binding_names(self.output_binding_names), str(int(hardware_compatible)), - self.encode_settings(settings), + self.encode_metadata(settings), ] ) else: self.engine = None - def encode_settings(self, settings: Any) -> str: + def encode_metadata(self, settings: Any) -> str: + settings.torch_executed_ops = { + f"torch.ops.{op.__str__()}" for op in settings.torch_executed_ops + } dumped_settings = pickle.dumps(settings) encoded_settings = base64.b64encode(dumped_settings).decode("utf-8") return encoded_settings + @staticmethod + def decode_metadata(encoded_settings: str) -> Any: + dumped_settings = base64.b64decode(encoded_settings.encode("utf-8")) + settings = pickle.loads(dumped_settings) + settings.torch_executed_ops = {eval(op) for op in settings.torch_executed_ops} + return settings + def get_extra_state(self) -> SerializedTorchTensorRTModuleFmt: return ( self.name, @@ -150,6 +160,7 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None: self.hardware_compatible = ( bool(int(state[1][6])) if state[1] is not None else False ) + self.settings = TorchTensorRTModule.decode_metadata(serialized_engine_info[7]) def forward(self, *inputs: Any) -> torch.Tensor | 
Tuple[torch.Tensor, ...]:
        """Implementation of the forward pass for a TensorRT engine

From 1456ad959a8baff0e8681e94e8d8c1f4068cec7b Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Fri, 14 Jun 2024 17:06:01 -0700
Subject: [PATCH 42/70] Updated the example

---
 examples/dynamo/refit_engine_example.py | 46 +++----------------------
 1 file changed, 4 insertions(+), 42 deletions(-)

diff --git a/examples/dynamo/refit_engine_example.py b/examples/dynamo/refit_engine_example.py
index a6b41f895c..2fd1b81ec8 100644
--- a/examples/dynamo/refit_engine_example.py
+++ b/examples/dynamo/refit_engine_example.py
@@ -26,11 +26,9 @@
 
 import numpy as np
 import torch
-import torch.nn.functional as F
 import torch_tensorrt as trt
 import torchvision.models as models
-from torch import nn
-from torch_tensorrt.dynamo._refit import refit_module_weights
+from torch_tensorrt.dynamo import refit_module_weights
 
 np.random.seed(0)
 torch.manual_seed(0)
@@ -61,6 +59,7 @@
 )  # Output is a torch.fx.GraphModule
 
 # Save the graph module as an exported program
+# This is only supported when use_python_runtime = False
 trt.save(trt_gm, "./compiled.ep", inputs=inputs)
 
 
@@ -95,42 +94,5 @@
 # Alternative Workflow using Python Runtime
 # -----------------------------
 
-# Currently python runtime does not support engine serialization. So the refitt will be done in the same runtime.
-# This usecase is more useful when you need to switch different weights during runtime, such as using Stable Diffusion.
-
-model = models.resnet18(pretrained=False).eval().to("cuda")
-exp_program = torch.export.export(model, tuple(inputs))
-enabled_precisions = {torch.float}
-debug = False
-workspace_size = 20 << 30
-min_block_size = 0
-use_python_runtime = True
-torch_executed_ops = {}
-trt_gm = trt.dynamo.compile(
-    exp_program,
-    tuple(inputs),
-    use_python_runtime=use_python_runtime,
-    enabled_precisions=enabled_precisions,
-    debug=debug,
-    min_block_size=min_block_size,
-    torch_executed_ops=torch_executed_ops,
-    refit=True,
-)
-
-model2 = models.resnet18(pretrained=True).eval().to("cuda")
-exp_program2 = torch.export.export(model2, tuple(inputs))
-
-new_trt_gm = refit_module_weights(
-    compiled_module=compiled_trt_gm,
-    new_weight_module=exp_program2,
-    inputs=inputs,
-)
-
-# Check the output
-expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(*inputs)
-for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
-    assert torch.allclose(
-        expected_output, refitted_output, 1e-2, 1e-2
-    ), "Refit Result is not correct. Refit failed"
-
-print("Refit successfully!")
+# Currently the Python runtime does not support engine serialization, so the refitting will be done in the same runtime.
+# This use case is more useful when you need to switch different weights in the same runtime, such as using Stable Diffusion.
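Before the test patches that follow, the settings round-trip built up over the last several patches is easiest to see in isolation. Below is a minimal sketch using only the standard library; the function names are illustrative rather than the shipped API, and a plain dict stands in for a picklable CompilationSettings object:

    import base64
    import pickle

    def encode_metadata(settings: object) -> str:
        # Pickle the settings, then base64-encode the bytes so the result can
        # occupy the single string slot reserved for metadata in the engine's
        # serialization tuple.
        return base64.b64encode(pickle.dumps(settings)).decode("utf-8")

    def decode_metadata(encoded: str) -> object:
        # Invert the encoding: base64-decode back to bytes, then unpickle.
        return pickle.loads(base64.b64decode(encoded.encode("utf-8")))

    settings = {"min_block_size": 0, "refit": True}  # stand-in for CompilationSettings
    assert decode_metadata(encode_metadata(settings)) == settings

The shipped encode_metadata/decode_metadata pair additionally rewrites torch_executed_ops into strings of the form "torch.ops.<op>" before pickling and re-evaluates them on load, presumably because operator overload objects do not pickle cleanly.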
From 1acfe31e6b8ad76b42d1d44f823a3c66a60d1bdd Mon Sep 17 00:00:00 2001 From: cehongwang Date: Fri, 14 Jun 2024 17:30:25 -0700 Subject: [PATCH 43/70] Wrote 6 test cases for refitting feature to cover different scenarios --- tests/py/dynamo/models/test_model_refit.py | 144 ++++++++++++++++++++- 1 file changed, 143 insertions(+), 1 deletion(-) diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index 9dfc7207ce..92202648ff 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -1,3 +1,4 @@ +import os import time import unittest @@ -18,6 +19,7 @@ ) from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions from torch_tensorrt.logging import TRT_LOGGER +from transformers import BertModel assertions = unittest.TestCase() @@ -117,6 +119,142 @@ def test_refit_one_engine(): torch._dynamo.reset() +@pytest.mark.unit +def test_refit_one_engine_bert(): + inputs = [ + torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"), + ] + model = BertModel.from_pretrained("bert-base-uncased").eval().to("cuda") + model2 = BertModel.from_pretrained("bert-base-uncased").eval().to("cuda") + nn.init.xavier_normal_(model2.embeddings.word_embeddings.weight) + enabled_precisions = {torch.float} + debug = False + min_block_size = 0 + use_python_runtime = False + + exp_program = torch.export.export(model, tuple(inputs)) + exp_program2 = torch.export.export(model2, tuple(inputs)) + + trt_gm = torchtrt.dynamo.compile( + exp_program, + tuple(inputs), + use_python_runtime=use_python_runtime, + enabled_precisions=enabled_precisions, + debug=debug, + min_block_size=min_block_size, + refit=True, + ) + + new_trt_gm = refit_module_weights( + compiled_module=trt_gm, + new_weight_module=exp_program2, + inputs=inputs, + ) + + # Check the output + expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( + *inputs + ) + for expected_output, refitted_output in zip(expected_outputs, refitted_outputs): + assertions.assertTrue( + torch.allclose(expected_output, refitted_output, 1e-2, 1e-2), + "Refit Result is not correct. Refit failed", + ) + # Clean up model env + + torch._dynamo.reset() + + +@pytest.mark.unit +def test_refit_one_engine_inline_runtime(): + + model = models.resnet18(pretrained=False).eval().to("cuda") + model2 = models.resnet18(pretrained=True).eval().to("cuda") + inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] + enabled_precisions = {torch.float} + debug = False + min_block_size = 0 + use_python_runtime = False + + exp_program = torch.export.export(model, tuple(inputs)) + exp_program2 = torch.export.export(model2, tuple(inputs)) + + trt_gm = torchtrt.dynamo.compile( + exp_program, + tuple(inputs), + use_python_runtime=use_python_runtime, + enabled_precisions=enabled_precisions, + debug=debug, + min_block_size=min_block_size, + refit=True, + ) + trt.save(trt_gm, "./compiled.ep", inputs=inputs) + trt_gm = torch.export.load("./compiled.ep") + os.remove("./compiled.ep") + new_trt_gm = refit_module_weights( + compiled_module=trt_gm, + new_weight_module=exp_program2, + inputs=inputs, + ) + + # Check the output + expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( + *inputs + ) + for expected_output, refitted_output in zip(expected_outputs, refitted_outputs): + assertions.assertTrue( + torch.allclose(expected_output, refitted_output, 1e-2, 1e-2), + "Refit Result is not correct. 
Refit failed", + ) + # Clean up model env + + torch._dynamo.reset() + + +@pytest.mark.unit +def test_refit_one_engine_python_runtime(): + + model = models.resnet18(pretrained=False).eval().to("cuda") + model2 = models.resnet18(pretrained=True).eval().to("cuda") + inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] + enabled_precisions = {torch.float} + debug = False + min_block_size = 0 + use_python_runtime = True + + exp_program = torch.export.export(model, tuple(inputs)) + exp_program2 = torch.export.export(model2, tuple(inputs)) + + trt_gm = torchtrt.dynamo.compile( + exp_program, + tuple(inputs), + use_python_runtime=use_python_runtime, + enabled_precisions=enabled_precisions, + debug=debug, + min_block_size=min_block_size, + refit=True, + ) + + new_trt_gm = refit_module_weights( + compiled_module=trt_gm, + new_weight_module=exp_program2, + inputs=inputs, + ) + + # Check the output + expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( + *inputs + ) + for expected_output, refitted_output in zip(expected_outputs, refitted_outputs): + assertions.assertTrue( + torch.allclose(expected_output, refitted_output, 1e-2, 1e-2), + "Refit Result is not correct. Refit failed", + ) + # Clean up model env + + torch._dynamo.reset() + + @pytest.mark.unit def test_refit_multiple_engine(): @@ -126,14 +264,18 @@ def __init__(self): self.conv1 = nn.Conv2d(3, 12, 3, padding=1) self.bn = nn.BatchNorm2d(12) self.conv2 = nn.Conv2d(12, 12, 3, padding=1) + self.fc1 = nn.Linear(12 * 56 * 56, 10) def forward(self, x): x = self.conv1(x) x = F.relu(x) x = self.bn(x) + x = F.max_pool2d(x, (2, 2)) x = self.conv2(x) x = F.relu(x) - return x + x = F.max_pool2d(x, (2, 2)) + x = torch.flatten(x, 1) + return self.fc1(x) model = net().eval().to("cuda") model2 = net().eval().to("cuda") From d38e42257a64d576be4bffa8dc495e198977efb4 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Fri, 14 Jun 2024 17:40:32 -0700 Subject: [PATCH 44/70] Fixed a bug in tests --- tests/py/dynamo/models/test_model_refit.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index 92202648ff..4d73da554c 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -156,6 +156,10 @@ def test_refit_one_engine_bert(): *inputs ) for expected_output, refitted_output in zip(expected_outputs, refitted_outputs): + if not isinstance(expected_output, torch.Tensor) or not isinstance( + refitted_output, torch.Tensor + ): + continue assertions.assertTrue( torch.allclose(expected_output, refitted_output, 1e-2, 1e-2), "Refit Result is not correct. 
Refit failed", @@ -188,7 +192,7 @@ def test_refit_one_engine_inline_runtime(): min_block_size=min_block_size, refit=True, ) - trt.save(trt_gm, "./compiled.ep", inputs=inputs) + torchtrt.save(trt_gm, "./compiled.ep", inputs=inputs) trt_gm = torch.export.load("./compiled.ep") os.remove("./compiled.ep") new_trt_gm = refit_module_weights( From 880afde4480493b8a522968aed4d18f2919eb3c5 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Fri, 14 Jun 2024 17:40:47 -0700 Subject: [PATCH 45/70] Delete settings check --- py/torch_tensorrt/dynamo/_refit.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 1046a7896f..70a7b30cf7 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -188,8 +188,6 @@ def refit_module_weights( submodule, (PythonTorchTensorRTModule, TorchTensorRTModule) ): continue - if settings is not None: - assert settings == submodule.settings settings = submodule.settings if settings.debug: From 2dc5bfa4fe680794faa9daa547d0a7dd4212ee3b Mon Sep 17 00:00:00 2001 From: cehongwang Date: Fri, 14 Jun 2024 17:41:04 -0700 Subject: [PATCH 46/70] Fixed a bug of modifing settings inplace --- py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 2df6337efd..89d23b6da4 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -1,6 +1,7 @@ from __future__ import annotations import base64 +import copy import logging import pickle from typing import Any, List, Optional, Tuple @@ -94,7 +95,7 @@ def __init__( ) self.name = name self.hardware_compatible = hardware_compatible - self.settings = settings + self.settings = copy.deepcopy(settings) if serialized_engine is not None: self.engine = torch.classes.tensorrt.Engine( [ @@ -112,6 +113,7 @@ def __init__( self.engine = None def encode_metadata(self, settings: Any) -> str: + settings = copy.deepcopy(settings) settings.torch_executed_ops = { f"torch.ops.{op.__str__()}" for op in settings.torch_executed_ops } From 410689c958260ac47e02d751a17a2946acfe9d45 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Mon, 17 Jun 2024 10:02:38 -0700 Subject: [PATCH 47/70] added it to //docsrc/py_api/dynamo.rst so that it gets rendered in the docs --- docsrc/py_api/dynamo.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docsrc/py_api/dynamo.rst b/docsrc/py_api/dynamo.rst index 6b4a527663..12fa5e76c1 100644 --- a/docsrc/py_api/dynamo.rst +++ b/docsrc/py_api/dynamo.rst @@ -24,7 +24,7 @@ Functions .. autofunction:: convert_module_to_trt_engine - +.. autofunction:: refit_module_weights Classes -------- From 381f14abb817042e185de0d1c7685780f80d5671 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Tue, 18 Jun 2024 11:18:13 -0700 Subject: [PATCH 48/70] Added reference to doc --- examples/dynamo/README.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/dynamo/README.rst b/examples/dynamo/README.rst index 7191c02fa0..574d891a93 100644 --- a/examples/dynamo/README.rst +++ b/examples/dynamo/README.rst @@ -11,3 +11,4 @@ a number of ways you can leverage this backend to accelerate inference. 
* :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
* :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
* :ref:`custom_kernel_plugins`: Creating a plugin to use a custom kernel inside TensorRT engines
+* :ref:`refit_engine_example`: Refitting a compiled TensorRT Graph Module with updated weights

From eebe883f1e9dcf61bd280ec63f0c00b19611634e Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Tue, 18 Jun 2024 16:25:10 -0700
Subject: [PATCH 49/70] Changed the default output check to False

---
 py/torch_tensorrt/dynamo/_refit.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py
index 70a7b30cf7..056a8d4235 100644
--- a/py/torch_tensorrt/dynamo/_refit.py
+++ b/py/torch_tensorrt/dynamo/_refit.py
@@ -140,7 +140,7 @@ def refit_module_weights(
     compiled_module: torch.fx.GraphModule | ExportedProgram,
     new_weight_module: ExportedProgram,
     inputs: Tuple[Any, ...],
-    verify_output: bool = True,
+    verify_output: bool = False,
 ) -> torch.fx.GraphModule:
     """
     Refit a compiled graph module with ExportedProgram. This performs weight updates in compiled_module without recompiling the engine.

From 2a3d567f519fa5a0d57b29e3e7a4c788966f7ccb Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Tue, 18 Jun 2024 16:26:36 -0700
Subject: [PATCH 50/70] Changed the assertion

---
 py/torch_tensorrt/dynamo/_refit.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py
index 056a8d4235..fb331b6259 100644
--- a/py/torch_tensorrt/dynamo/_refit.py
+++ b/py/torch_tensorrt/dynamo/_refit.py
@@ -255,11 +255,11 @@ def refit_module_weights(
     if inline_module:
         assert {sm[0] for sm in new_partitioned_module.named_children()} == set(
             compiled_submodules_map.keys()
-        ), "The compiled module is incompatible with the new module!"
+        ), "New weights module is not compatible with previously compiled Torch-TensorRT module"
     else:
         assert {sm[0] for sm in new_partitioned_module.named_children()} == {
             sm[0] for sm in compiled_module.named_children()
-        }, "The compiled module is incompatible with the new module!"
+        }, "New weights module is not compatible with previously compiled Torch-TensorRT module"

    # 2. 
TODO: Check the hash of source fx.Graph and new fx.Graph

     # Iterate over all components that can be accelerated

From 003380af89ef76589bd433a6de037b64f8b121d6 Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Tue, 18 Jun 2024 16:27:21 -0700
Subject: [PATCH 51/70] Renamed the imported name

---
 examples/dynamo/refit_engine_example.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/examples/dynamo/refit_engine_example.py b/examples/dynamo/refit_engine_example.py
index 2fd1b81ec8..80c194cd16 100644
--- a/examples/dynamo/refit_engine_example.py
+++ b/examples/dynamo/refit_engine_example.py
@@ -26,7 +26,7 @@
 
 import numpy as np
 import torch
-import torch_tensorrt as trt
+import torch_tensorrt as torch_trt
 import torchvision.models as models
 from torch_tensorrt.dynamo import refit_module_weights
 
@@ -47,7 +47,7 @@
 min_block_size = 0
 use_python_runtime = False
 torch_executed_ops = {}
-trt_gm = trt.dynamo.compile(
+trt_gm = torch_trt.dynamo.compile(
     exp_program,
     tuple(inputs),
     use_python_runtime=use_python_runtime,
@@ -60,7 +60,7 @@
 
 # Save the graph module as an exported program
 # This is only supported when use_python_runtime = False
-trt.save(trt_gm, "./compiled.ep", inputs=inputs)
+torch_trt.save(trt_gm, "./compiled.ep", inputs=inputs)
 
 
 # %%
@@ -72,7 +72,7 @@
 exp_program2 = torch.export.export(model2, tuple(inputs))
 
 
-compiled_trt_gm = trt.load("./compiled.ep")
+compiled_trt_gm = torch_trt.load("./compiled.ep")
 
 # This returns a new module with updated weights
 new_trt_gm = refit_module_weights(

From 323db97575a0878a7c58575741b6b268a82c97dc Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Tue, 18 Jun 2024 16:33:38 -0700
Subject: [PATCH 52/70] Renamed compiled_trt_gm to compiled_trt_ep

---
 examples/dynamo/refit_engine_example.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/dynamo/refit_engine_example.py b/examples/dynamo/refit_engine_example.py
index 80c194cd16..36d8353a75 100644
--- a/examples/dynamo/refit_engine_example.py
+++ b/examples/dynamo/refit_engine_example.py
@@ -72,11 +72,11 @@
 exp_program2 = torch.export.export(model2, tuple(inputs))
 
 
-compiled_trt_gm = torch_trt.load("./compiled.ep")
+compiled_trt_ep = torch_trt.load("./compiled.ep")
 
 # This returns a new module with updated weights
 new_trt_gm = refit_module_weights(
-    compiled_module=compiled_trt_gm,
+    compiled_module=compiled_trt_ep,
     new_weight_module=exp_program2,
     inputs=inputs,
 )

From de0ab949eb6d4f43d0b42d48b240d93e017197c8 Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Tue, 18 Jun 2024 16:44:58 -0700
Subject: [PATCH 53/70] Fixed a bug in the serialized info signature

---
 py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
index 89d23b6da4..d015d32188 100644
--- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
@@ -12,7 +12,7 @@
 logger = logging.getLogger(__name__)
 
 SerializedTensorRTEngineFmt = Tuple[
-    str, str, str, bytes, str, str, bytes, str
+    str, str, str, bytes, str, str, str, bytes
 ]  # Defined in //core/runtime/register_jit_hooks.cpp
 SerializedTorchTensorRTModuleFmt = Tuple[
     str, Optional[SerializedTensorRTEngineFmt], List[str], List[str]
 ]
@@ -122,7 +122,7 @@ def encode_metadata(self, settings: Any) -> str:
         return encoded_settings
 
     @staticmethod
-    def decode_metadata(encoded_settings: str) -> Any:
+    def decode_metadata(encoded_settings: bytes) -> Any: 
dumped_settings = base64.b64decode(encoded_settings.encode("utf-8")) settings = pickle.loads(dumped_settings) settings.torch_executed_ops = {eval(op) for op in settings.torch_executed_ops} From de3da26423cb04f08d86c8547d19707318885ab1 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Tue, 18 Jun 2024 16:55:46 -0700 Subject: [PATCH 54/70] Changed the refit condition check --- py/torch_tensorrt/dynamo/_refit.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index fb331b6259..d050e4dd9a 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -163,7 +163,7 @@ def refit_module_weights( compiled_module = copy.deepcopy(compiled_module) # Get the settings and check the setting to be uniform - settings: Any = None + settings: CompilationSettings = None if inline_module: # Obtain the settings @@ -175,7 +175,7 @@ def refit_module_weights( encoded_settings = compiled_submodules[0][1].__getstate__()[0][7] assert ( encoded_settings != "" - ), "Settings are not saved in the engine. Please recompile the engine." + ), "Settings are not saved in the engine. Please recompile the engine with refit=True." settings = TorchTensorRTModule.decode_metadata(encoded_settings) # Handle torch modules compiled_submodules_map = dict(compiled_submodules) @@ -190,6 +190,10 @@ def refit_module_weights( continue settings = submodule.settings + assert ( + settings.refit + ), "Refitting is not enabled. Please recompile the engine with refit=True." + if settings.debug: set_log_level(logger.parent, logging.DEBUG) From 91c6036d83b6fb56a578ef2628f8007f1a53e36f Mon Sep 17 00:00:00 2001 From: cehongwang Date: Tue, 18 Jun 2024 16:58:04 -0700 Subject: [PATCH 55/70] Changed the file path in test file --- tests/py/dynamo/models/test_model_refit.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index 4d73da554c..4493a4ca40 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -1,4 +1,5 @@ import os +import tempfile import time import unittest @@ -171,7 +172,7 @@ def test_refit_one_engine_bert(): @pytest.mark.unit def test_refit_one_engine_inline_runtime(): - + trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep") model = models.resnet18(pretrained=False).eval().to("cuda") model2 = models.resnet18(pretrained=True).eval().to("cuda") inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] @@ -192,9 +193,8 @@ def test_refit_one_engine_inline_runtime(): min_block_size=min_block_size, refit=True, ) - torchtrt.save(trt_gm, "./compiled.ep", inputs=inputs) - trt_gm = torch.export.load("./compiled.ep") - os.remove("./compiled.ep") + torchtrt.save(trt_gm, trt_ep_path, inputs=inputs) + trt_gm = torch.export.load(trt_ep_path) new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, From bd43882972c135d3be724f86d3a6a5e1095bca45 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Tue, 18 Jun 2024 17:00:07 -0700 Subject: [PATCH 56/70] Fixed minor format --- tests/py/dynamo/models/test_model_refit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index 4493a4ca40..d6b892e16e 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -254,8 +254,8 @@ def 
test_refit_one_engine_python_runtime(): torch.allclose(expected_output, refitted_output, 1e-2, 1e-2), "Refit Result is not correct. Refit failed", ) - # Clean up model env + # Clean up model env torch._dynamo.reset() From 5ef9af75d1e4a28b7d19e47c8bb94a08f818e501 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Tue, 18 Jun 2024 17:08:21 -0700 Subject: [PATCH 57/70] Deleted setting repetitions --- py/torch_tensorrt/dynamo/conversion/_conversion.py | 4 ---- .../dynamo/runtime/_PythonTorchTensorRTModule.py | 12 ++++++------ .../dynamo/runtime/_TorchTensorRTModule.py | 9 +++++---- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index dfcfbbf48a..b335756068 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -113,8 +113,6 @@ def convert_module( engine=interpreter_result.engine, input_names=list(interpreter_result.input_names), output_names=list(interpreter_result.output_names), - target_device=settings.device, - profiling_enabled=settings.debug, settings=settings, ) @@ -130,7 +128,5 @@ def convert_module( name=name, input_binding_names=list(interpreter_result.input_names), output_binding_names=list(interpreter_result.output_names), - target_device=settings.device, - hardware_compatible=settings.hardware_compatible, settings=settings, ) diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index e315e5fe6e..9e0a78e9e7 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -32,8 +32,6 @@ def __init__( engine: bytes, input_names: Optional[List[str]] = None, output_names: Optional[List[str]] = None, - target_device: Device = Device._current_device(), - profiling_enabled: Optional[bool] = None, settings: Any = None, ): super(PythonTorchTensorRTModule, self).__init__() @@ -46,13 +44,15 @@ def __init__( self.input_names = input_names if input_names is not None else [] self.output_names = output_names if output_names is not None else [] self.initialized = False - self.target_device_id = target_device.gpu_id + self.target_device_id = ( + settings.device.gpu_id + if settings.device is not None + else Device._current_device().gpu_id + ) self.target_device_properties = torch.cuda.get_device_properties( self.target_device_id ) - self.profiling_enabled = ( - profiling_enabled if profiling_enabled is not None else False - ) + self.profiling_enabled = settings.debug if settings.debug is not None else False self.settings = settings self._initialize() diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index d015d32188..c24fb24394 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -45,8 +45,6 @@ def __init__( name: str = "", input_binding_names: Optional[List[str]] = None, output_binding_names: Optional[List[str]] = None, - target_device: Device = Device._current_device(), - hardware_compatible: bool = False, settings: Any = None, ): """__init__ method for torch_tensorrt.dynamo.runtime._TorchTensorRTModule.TorchTensorRTModule @@ -94,7 +92,10 @@ def __init__( output_binding_names if output_binding_names is not None else [] ) self.name = name - self.hardware_compatible = hardware_compatible + 
target_device = (
+            settings.device if settings.device is not None else Device._current_device()
+        )
+        self.hardware_compatible = settings.hardware_compatible
         self.settings = copy.deepcopy(settings)
         if serialized_engine is not None:
             self.engine = torch.classes.tensorrt.Engine(
                 [
@@ -105,7 +106,7 @@ def __init__(
                     serialized_engine,
                     TorchTensorRTModule._pack_binding_names(self.input_binding_names),
                     TorchTensorRTModule._pack_binding_names(self.output_binding_names),
-                    str(int(hardware_compatible)),
+                    str(int(self.hardware_compatible)),
                     self.encode_metadata(settings),
                 ]
             )

From 888242558f97f797ad97ca83b2318809c7299ada Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Tue, 18 Jun 2024 17:23:44 -0700
Subject: [PATCH 58/70] Changed min_block_size to 1

---
 tests/py/dynamo/models/test_model_refit.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py
index d6b892e16e..efa6b1fb03 100644
--- a/tests/py/dynamo/models/test_model_refit.py
+++ b/tests/py/dynamo/models/test_model_refit.py
@@ -37,7 +37,7 @@ def test_mapping():
     ]
     enabled_precisions = {torch.float}
     debug = False
-    min_block_size = 0
+    min_block_size = 1
     use_python_runtime = False
 
     exp_program = torch.export.export(model, tuple(inputs))
@@ -84,7 +84,7 @@ def test_refit_one_engine():
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     enabled_precisions = {torch.float}
     debug = False
-    min_block_size = 0
+    min_block_size = 1
     use_python_runtime = False
 
     exp_program = torch.export.export(model, tuple(inputs))
@@ -130,7 +130,7 @@ def test_refit_one_engine_bert():
     nn.init.xavier_normal_(model2.embeddings.word_embeddings.weight)
     enabled_precisions = {torch.float}
     debug = False
-    min_block_size = 0
+    min_block_size = 1
     use_python_runtime = False
 
     exp_program = torch.export.export(model, tuple(inputs))
@@ -178,7 +178,7 @@ def test_refit_one_engine_inline_runtime():
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     enabled_precisions = {torch.float}
     debug = False
-    min_block_size = 0
+    min_block_size = 1
     use_python_runtime = False
 
     exp_program = torch.export.export(model, tuple(inputs))
@@ -223,7 +223,7 @@ def test_refit_one_engine_python_runtime():
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     enabled_precisions = {torch.float}
     debug = False
-    min_block_size = 0
+    min_block_size = 1
     use_python_runtime = True
 
     exp_program = torch.export.export(model, tuple(inputs))
@@ -287,7 +287,7 @@ def forward(self, x):
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     enabled_precisions = {torch.float}
     debug = False
-    min_block_size = 0
+    min_block_size = 1
     use_python_runtime = False
 
     exp_program = torch.export.export(model, tuple(inputs))

From 7f1f958f31a846d2391a7af8022b16f3103f9e07 Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Tue, 18 Jun 2024 17:27:22 -0700
Subject: [PATCH 59/70] Added comments

---
 py/torch_tensorrt/dynamo/_refit.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py
index d050e4dd9a..c09598ce2a 100644
--- a/py/torch_tensorrt/dynamo/_refit.py
+++ b/py/torch_tensorrt/dynamo/_refit.py
@@ -86,6 +86,9 @@
         layer = net[i]
         layer_type: str = layer.type.name
         if layer_type in MODULE_MAP:
+            # Cast the parent class to child class to access attributes
+            # For example: ILayer does not have ILayer.kernel/ILayer.bias
+            # So we cast it to IConvolutionLayer and access the attributes
             layer.__class__ = MODULE_MAP[layer_type][0]
             for weight_type, 
weight_name in MODULE_MAP[layer_type][1]:
                 weight = layer.__getattribute__(weight_type).copy()

From b8e023dd70ca2451f9d20cafae949b35fa653c74 Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Mon, 24 Jun 2024 13:55:33 -0700
Subject: [PATCH 60/70] Merged two if statements

---
 py/torch_tensorrt/dynamo/_refit.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py
index c09598ce2a..b84eb80c49 100644
--- a/py/torch_tensorrt/dynamo/_refit.py
+++ b/py/torch_tensorrt/dynamo/_refit.py
@@ -256,10 +256,7 @@ def refit_module_weights(
         # Preprocess the partitioned module to be in the same format as the inline module
         inline_torch_modules(new_partitioned_module)
         new_partitioned_module.delete_all_unused_submodules()
-
-    # Check whether two modules have the same subcomponents
-    # 1. Check the number of partitions and name
-    if inline_module:
+        # Check the number of partitions and name
         assert {sm[0] for sm in new_partitioned_module.named_children()} == set(
             compiled_submodules_map.keys()
         ), "New weights module is not compatible with previously compiled Torch-TensorRT module"

From df9cd39c0f23b5e893d903ed797c2671693c4054 Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Wed, 26 Jun 2024 15:27:57 -0700
Subject: [PATCH 61/70] Changed the weight type

---
 py/torch_tensorrt/dynamo/_refit.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py
index b84eb80c49..1dd8786914 100644
--- a/py/torch_tensorrt/dynamo/_refit.py
+++ b/py/torch_tensorrt/dynamo/_refit.py
@@ -92,11 +92,7 @@
             layer.__class__ = MODULE_MAP[layer_type][0]
             for weight_type, weight_name in MODULE_MAP[layer_type][1]:
                 weight = layer.__getattribute__(weight_type).copy()
-                weight_dtype = (
-                    layer.precision
-                    if layer.precision_is_set
-                    else dtype.try_from(weight.dtype).to(trt.DataType)
-                )
+                weight_dtype = dtype.try_from(weight.dtype).to(trt.DataType)
                 weight_map[f"{layer.name} {weight_name}"] = (
                     weight,
                     weight_dtype,

From b33fa0f77d4cc91eeaadab48c6f3990cd170a8a1 Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Wed, 26 Jun 2024 15:59:56 -0700
Subject: [PATCH 62/70] Fixed hardcoded indices

---
 core/runtime/register_jit_hooks.cpp | 9 +++++++
 py/torch_tensorrt/dynamo/_refit.py | 9 +++++--
 .../dynamo/runtime/_TorchTensorRTModule.py | 24 +++++++++++++------
 3 files changed, 33 insertions(+), 9 deletions(-)

diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp
index dd82839642..9ac5af5d05 100644
--- a/core/runtime/register_jit_hooks.cpp
+++ b/core/runtime/register_jit_hooks.cpp
@@ -127,6 +127,15 @@ TORCH_LIBRARY(tensorrt, m) {
   });
   m.def(
       "get_logging_level", []() -> int64_t { return int64_t(util::logging::get_logger().get_reportable_log_level()); });
+  m.def("ABI_TARGET_IDX", []() -> int64_t { return ABI_TARGET_IDX; });
+  m.def("NAME_IDX", []() -> int64_t { return NAME_IDX; });
+  m.def("DEVICE_IDX", []() -> int64_t { return DEVICE_IDX; });
+  m.def("ENGINE_IDX", []() -> int64_t { return ENGINE_IDX; });
+  m.def("INPUT_BINDING_NAMES_IDX", []() -> int64_t { return INPUT_BINDING_NAMES_IDX; });
+  m.def("OUTPUT_BINDING_NAMES_IDX", []() -> int64_t { return OUTPUT_BINDING_NAMES_IDX; });
+  m.def("HW_COMPATIBLE_IDX", []() -> int64_t { return HW_COMPATIBLE_IDX; });
+  m.def("SERIALIZED_METADATA_IDX", []() -> int64_t { return SERIALIZED_METADATA_IDX; });
+  m.def("SERIALIZATION_LEN", []() -> int64_t { return SERIALIZATION_LEN; });
 }
 
 } // namespace 
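Note: the getters registered above expose the serialized-engine tuple layout to Python, so call sites can query indices from the C++ runtime instead of hardcoding them (e.g. engine_info[3]). A minimal sketch of the intended consumption, assuming the layout defined in core/runtime/register_jit_hooks.cpp; unpack_engine_info is a hypothetical helper for illustration, not part of this patch:

    import torch

    # Query the tuple layout from the C++ runtime at import time
    ENGINE_IDX = torch.ops.tensorrt.ENGINE_IDX()
    SERIALIZED_METADATA_IDX = torch.ops.tensorrt.SERIALIZED_METADATA_IDX()

    def unpack_engine_info(engine_info):
        # engine_info is torch.classes.tensorrt.Engine.__getstate__()[0]
        serialized_engine = engine_info[ENGINE_IDX]
        encoded_settings = engine_info[SERIALIZED_METADATA_IDX]
        return serialized_engine, encoded_settings

The Python-side diffs below adopt the same pattern at the real call sites.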
diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 1dd8786914..c0ad0453df 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -24,7 +24,10 @@ from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import ( PythonTorchTensorRTModule, ) -from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule +from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( + SERIALIZED_METADATA_IDX, + TorchTensorRTModule, +) from torch_tensorrt.dynamo.utils import ( check_output, get_torch_inputs, @@ -171,7 +174,9 @@ def refit_module_weights( for name, engine in compiled_module.__dict__.items() if "engine" in name ] - encoded_settings = compiled_submodules[0][1].__getstate__()[0][7] + encoded_settings = compiled_submodules[0][1].__getstate__()[0][ + SERIALIZED_METADATA_IDX + ] assert ( encoded_settings != "" ), "Settings are not saved in the engine. Please recompile the engine with refit=True." diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index c24fb24394..64ebe696f4 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -18,6 +18,16 @@ str, Optional[SerializedTensorRTEngineFmt], List[str], List[str] ] +ABI_TARGET_IDX = torch.ops.tensorrt.ABI_TARGET_IDX() +NAME_IDX = torch.ops.tensorrt.NAME_IDX() +DEVICE_IDX = torch.ops.tensorrt.DEVICE_IDX() +ENGINE_IDX = torch.ops.tensorrt.ENGINE_IDX() +INPUT_BINDING_NAMES_IDX = torch.ops.tensorrt.INPUT_BINDING_NAMES_IDX() +OUTPUT_BINDING_NAMES_IDX = torch.ops.tensorrt.OUTPUT_BINDING_NAMES_IDX() +HW_COMPATIBLE_IDX = torch.ops.tensorrt.HW_COMPATIBLE_IDX() +SERIALIZED_METADATA_IDX = torch.ops.tensorrt.SERIALIZED_METADATA_IDX() +SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN() + class TorchTensorRTModule(torch.nn.Module): # type: ignore[misc] """TorchTensorRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine. 
@@ -145,14 +155,14 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None:
             serialized_engine = base64.b64decode(serialized_engine_info[3])
             self.engine = torch.classes.tensorrt.Engine(
                 [
-                    serialized_engine_info[0],
-                    serialized_engine_info[1],
-                    serialized_engine_info[2],
+                    serialized_engine_info[ABI_TARGET_IDX],
+                    serialized_engine_info[NAME_IDX],
+                    serialized_engine_info[DEVICE_IDX],
                     serialized_engine,
-                    serialized_engine_info[4],
-                    serialized_engine_info[5],
-                    serialized_engine_info[6],
-                    serialized_engine_info[7],
+                    serialized_engine_info[INPUT_BINDING_NAMES_IDX],
+                    serialized_engine_info[OUTPUT_BINDING_NAMES_IDX],
+                    serialized_engine_info[HW_COMPATIBLE_IDX],
+                    serialized_engine_info[SERIALIZED_METADATA_IDX],
                 ]
             )
         else:

From 911984decbf658564541c407f59a3e64ce2f52c9 Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Wed, 26 Jun 2024 21:03:17 -0700
Subject: [PATCH 63/70] Fixed a typo causing extra overhead

---
 py/torch_tensorrt/dynamo/_refit.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py
index c0ad0453df..7bfd720f6e 100644
--- a/py/torch_tensorrt/dynamo/_refit.py
+++ b/py/torch_tensorrt/dynamo/_refit.py
@@ -346,7 +346,7 @@ def refit_module_weights(
         if check_output(
             new_module=new_gm,
             refitted_module=compiled_module,
-            inputs=get_torch_inputs(inputs, device),
+            inputs=torch_inputs,
         ):
             logger.info("Refitting Succeed!")
         else:

From 0a1c8ca12eb2ba2edf056ca0793319fd2418b194 Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Thu, 27 Jun 2024 14:21:14 -0700
Subject: [PATCH 64/70] Added comments and replaced the indices with enums

---
 py/torch_tensorrt/dynamo/_refit.py | 13 +++++++++----
 .../dynamo/runtime/_TorchTensorRTModule.py | 18 +++++++++---------
 2 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py
index 7bfd720f6e..1567ef4614 100644
--- a/py/torch_tensorrt/dynamo/_refit.py
+++ b/py/torch_tensorrt/dynamo/_refit.py
@@ -25,6 +25,7 @@
     PythonTorchTensorRTModule,
 )
 from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import (
+    ENGINE_IDX,
     SERIALIZED_METADATA_IDX,
     TorchTensorRTModule,
 )
@@ -283,14 +284,18 @@ def refit_module_weights(
                     continue
                 else:
                     engine_info = compiled_submodule.__getstate__()[0]
-                    engine = get_engine_from_encoded_engine(engine_info[3], runtime)
+                    engine = get_engine_from_encoded_engine(
+                        engine_info[ENGINE_IDX], runtime
+                    )
             else:
                 compiled_submodule = getattr(compiled_module, name)
                 if isinstance(compiled_submodule, PythonTorchTensorRTModule):
                     engine = compiled_submodule.engine
                 elif isinstance(compiled_submodule, TorchTensorRTModule):
                     engine_info = compiled_submodule.engine.__getstate__()[0]
-                    engine = get_engine_from_encoded_engine(engine_info[3], runtime)
+                    engine = get_engine_from_encoded_engine(
+                        engine_info[ENGINE_IDX], runtime
+                    )
                 elif isinstance(compiled_submodule, torch.fx.graph_module.GraphModule):
                     # This is graph break resulted by unsupported ops
                     compiled_submodule.load_state_dict(new_submodule.state_dict())
@@ -331,14 +336,14 @@ def refit_module_weights(
         if isinstance(compiled_submodule, TorchTensorRTModule):
             serialized_engine = bytes(engine.serialize())
             new_engine_info = list(engine_info)
-            new_engine_info[3] = serialized_engine
+            new_engine_info[ENGINE_IDX] = serialized_engine
             refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info))
             compiled_submodule.engine = refitted_engine

        elif inline_module:
            serialized_engine = bytes(engine.serialize())
            new_engine_info = 
list(engine_info)
-            new_engine_info[3] = serialized_engine
+            new_engine_info[ENGINE_IDX] = serialized_engine
             refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info))
             setattr(compiled_module, f"{name}_engine", refitted_engine)

diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
index 64ebe696f4..5dbaeb0874 100644
--- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
@@ -18,15 +18,15 @@
     str, Optional[SerializedTensorRTEngineFmt], List[str], List[str]
 ]
 
-ABI_TARGET_IDX = torch.ops.tensorrt.ABI_TARGET_IDX()
-NAME_IDX = torch.ops.tensorrt.NAME_IDX()
-DEVICE_IDX = torch.ops.tensorrt.DEVICE_IDX()
-ENGINE_IDX = torch.ops.tensorrt.ENGINE_IDX()
-INPUT_BINDING_NAMES_IDX = torch.ops.tensorrt.INPUT_BINDING_NAMES_IDX()
-OUTPUT_BINDING_NAMES_IDX = torch.ops.tensorrt.OUTPUT_BINDING_NAMES_IDX()
-HW_COMPATIBLE_IDX = torch.ops.tensorrt.HW_COMPATIBLE_IDX()
-SERIALIZED_METADATA_IDX = torch.ops.tensorrt.SERIALIZED_METADATA_IDX()
-SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN()
+ABI_TARGET_IDX = torch.ops.tensorrt.ABI_TARGET_IDX()  # 0
+NAME_IDX = torch.ops.tensorrt.NAME_IDX()  # 1
+DEVICE_IDX = torch.ops.tensorrt.DEVICE_IDX()  # 2
+ENGINE_IDX = torch.ops.tensorrt.ENGINE_IDX()  # 3
+INPUT_BINDING_NAMES_IDX = torch.ops.tensorrt.INPUT_BINDING_NAMES_IDX()  # 4
+OUTPUT_BINDING_NAMES_IDX = torch.ops.tensorrt.OUTPUT_BINDING_NAMES_IDX()  # 5
+HW_COMPATIBLE_IDX = torch.ops.tensorrt.HW_COMPATIBLE_IDX()  # 6
+SERIALIZED_METADATA_IDX = torch.ops.tensorrt.SERIALIZED_METADATA_IDX()  # 7
+SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN()  # 8
 
 
 class TorchTensorRTModule(torch.nn.Module):  # type: ignore[misc]

From 257db2619927cf6a174294283812e36c45bd2f06 Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Thu, 27 Jun 2024 14:38:12 -0700
Subject: [PATCH 65/70] Fixed inline module check

---
 py/torch_tensorrt/dynamo/_refit.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py
index 1567ef4614..15259127cf 100644
--- a/py/torch_tensorrt/dynamo/_refit.py
+++ b/py/torch_tensorrt/dynamo/_refit.py
@@ -160,9 +160,11 @@ def refit_module_weights(
     """
     inline_module = False
     if isinstance(compiled_module, ExportedProgram):
-        inline_module = True
         compiled_module = compiled_module.module()
 
+    if len(list(compiled_module.named_children())) == 0:
+        inline_module = True
+
     compiled_module = copy.deepcopy(compiled_module)
 
     # Get the settings and check the setting to be uniform

From fef67668a40f0591182e4d92ec565823acd1804e Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Fri, 28 Jun 2024 13:33:07 -0700
Subject: [PATCH 66/70] Added deprecation warning. 
Renamed refit flag to make_refitable --- examples/dynamo/refit_engine_example.py | 2 +- py/torch_tensorrt/dynamo/_compiler.py | 21 +++++++++++++++---- py/torch_tensorrt/dynamo/_defaults.py | 2 +- py/torch_tensorrt/dynamo/_refit.py | 4 ++-- py/torch_tensorrt/dynamo/_settings.py | 4 ++-- .../dynamo/conversion/_TRTInterpreter.py | 2 +- tests/py/dynamo/models/test_model_refit.py | 12 +++++------ 7 files changed, 30 insertions(+), 17 deletions(-) diff --git a/examples/dynamo/refit_engine_example.py b/examples/dynamo/refit_engine_example.py index 36d8353a75..c841c5f57a 100644 --- a/examples/dynamo/refit_engine_example.py +++ b/examples/dynamo/refit_engine_example.py @@ -55,7 +55,7 @@ debug=debug, min_block_size=min_block_size, torch_executed_ops=torch_executed_ops, - refit=True, + make_refitable=True, ) # Output is a torch.fx.GraphModule # Save the graph module as an exported program diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 32b0ca65d7..456b546447 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -52,7 +52,7 @@ def compile( Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype] ) = _defaults.ENABLED_PRECISIONS, engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, - refit: bool = _defaults.REFIT, + make_refitable: bool = _defaults.MAKE_REFITABLE, debug: bool = _defaults.DEBUG, num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, workspace_size: int = _defaults.WORKSPACE_SIZE, @@ -152,6 +152,13 @@ def compile( stacklevel=2, ) + if "refit" in kwargs.keys(): + warnings.warn( + "Refit is deprecated. Please use make_refitable=True if you want to enable refitting of the engine.", + DeprecationWarning, + stacklevel=2, + ) + engine_capability = EngineCapability._from(engine_capability) if torch_executed_modules is not None and torch_executed_modules: @@ -206,7 +213,7 @@ def compile( "require_full_compilation": require_full_compilation, "disable_tf32": disable_tf32, "sparse_weights": sparse_weights, - "refit": refit, + "make_refitable": make_refitable, "engine_capability": engine_capability, "dla_sram_size": dla_sram_size, "dla_local_dram_size": dla_local_dram_size, @@ -458,7 +465,7 @@ def convert_module_to_trt_engine( require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION, disable_tf32: bool = _defaults.DISABLE_TF32, sparse_weights: bool = _defaults.SPARSE_WEIGHTS, - refit: bool = _defaults.REFIT, + make_refitable: bool = _defaults.MAKE_REFITABLE, engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, dla_sram_size: int = _defaults.DLA_SRAM_SIZE, @@ -540,6 +547,12 @@ def convert_module_to_trt_engine( DeprecationWarning, stacklevel=2, ) + if "refit" in kwargs.keys(): + warnings.warn( + "Refit is deprecated. 
Please use make_refitable=True if you want to enable refitting of the engine.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
 
     input_list = list(inputs) if inputs is not None else []
     torch_executed_ops = torch_executed_ops if torch_executed_ops is not None else set()
@@ -567,7 +580,7 @@ def convert_module_to_trt_engine(
         "require_full_compilation": require_full_compilation,
         "disable_tf32": disable_tf32,
         "sparse_weights": sparse_weights,
-        "refit": refit,
+        "make_refitable": make_refitable,
         "engine_capability": engine_capability,
         "num_avg_timing_iters": num_avg_timing_iters,
         "dla_sram_size": dla_sram_size,
diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py
index 7931dc865c..356eecc01f 100644
--- a/py/torch_tensorrt/dynamo/_defaults.py
+++ b/py/torch_tensorrt/dynamo/_defaults.py
@@ -22,7 +22,7 @@
 USE_PYTHON_RUNTIME = False
 USE_FAST_PARTITIONER = True
 ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
-REFIT = False
+MAKE_REFITABLE = False
 REQUIRE_FULL_COMPILATION = False
 DRYRUN = False
 HARDWARE_COMPATIBLE = False
diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py
index 15259127cf..df8c4f177d 100644
--- a/py/torch_tensorrt/dynamo/_refit.py
+++ b/py/torch_tensorrt/dynamo/_refit.py
@@ -182,7 +182,7 @@ def refit_module_weights(
         ]
         assert (
             encoded_settings != ""
-        ), "Settings are not saved in the engine. Please recompile the engine with refit=True."
+        ), "Settings are not saved in the engine. Please recompile the engine with make_refitable=True."
         settings = TorchTensorRTModule.decode_metadata(encoded_settings)
         # Handle torch modules
         compiled_submodules_map = dict(compiled_submodules)
@@ -198,7 +198,7 @@
             continue
         settings = submodule.settings
 
         assert (
-            settings.refit
-        ), "Refitting is not enabled. Please recompile the engine with refit=True."
+            settings.make_refitable
+        ), "Refitting is not enabled. Please recompile the engine with make_refitable=True."
 
if settings.debug: diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 9592bc1fd5..4ca1d91726 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -15,12 +15,12 @@ ENABLED_PRECISIONS, ENGINE_CAPABILITY, HARDWARE_COMPATIBLE, + MAKE_REFITABLE, MAX_AUX_STREAMS, MIN_BLOCK_SIZE, NUM_AVG_TIMING_ITERS, OPTIMIZATION_LEVEL, PASS_THROUGH_BUILD_FAILURES, - REFIT, REQUIRE_FULL_COMPILATION, SPARSE_WEIGHTS, TRUNCATE_DOUBLE, @@ -88,7 +88,7 @@ class CompilationSettings: require_full_compilation: bool = REQUIRE_FULL_COMPILATION disable_tf32: bool = DISABLE_TF32 sparse_weights: bool = SPARSE_WEIGHTS - refit: bool = REFIT + make_refitable: bool = MAKE_REFITABLE engine_capability: EngineCapability = field( default_factory=lambda: ENGINE_CAPABILITY ) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index d3a2c00e13..785f148398 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -251,7 +251,7 @@ def _populate_trt_builder_config( if self.compilation_settings.disable_tf32: builder_config.clear_flag(trt.BuilderFlag.TF32) - if self.compilation_settings.refit: + if self.compilation_settings.make_refitable: builder_config.set_flag(trt.BuilderFlag.REFIT) if strict_type_constraints: diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index efa6b1fb03..36297f2f30 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -50,7 +50,7 @@ def test_mapping(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - refit=True, + make_refitable=True, ) settings = trt_gm._run_on_acc_0.settings runtime = trt.Runtime(TRT_LOGGER) @@ -97,7 +97,7 @@ def test_refit_one_engine(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - refit=True, + make_refitable=True, ) new_trt_gm = refit_module_weights( @@ -143,7 +143,7 @@ def test_refit_one_engine_bert(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - refit=True, + make_refitable=True, ) new_trt_gm = refit_module_weights( @@ -191,7 +191,7 @@ def test_refit_one_engine_inline_runtime(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - refit=True, + make_refitable=True, ) torchtrt.save(trt_gm, trt_ep_path, inputs=inputs) trt_gm = torch.export.load(trt_ep_path) @@ -236,7 +236,7 @@ def test_refit_one_engine_python_runtime(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - refit=True, + make_refitable=True, ) new_trt_gm = refit_module_weights( @@ -301,7 +301,7 @@ def forward(self, x): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - refit=True, + make_refitable=True, torch_executed_ops=torch_executed_ops, ) From 7381221805d1e0fc0bd5684af1fbb2eab0bb4b7c Mon Sep 17 00:00:00 2001 From: cehongwang Date: Mon, 1 Jul 2024 11:03:42 -0700 Subject: [PATCH 67/70] Updated lowering process to conform with latest main branch --- py/torch_tensorrt/dynamo/_refit.py | 13 +++++++++---- tests/py/dynamo/models/test_model_refit.py | 9 +++++++-- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index df8c4f177d..38810e59b3 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ 
b/py/torch_tensorrt/dynamo/_refit.py
@@ -20,7 +20,11 @@
 )
 from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter
 from torch_tensorrt.dynamo.conversion.truncate_double import repair_double_inputs
-from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
+from torch_tensorrt.dynamo.lowering import (
+    get_decompositions,
+    post_lowering,
+    pre_export_lowering,
+)
 from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import (
     PythonTorchTensorRTModule,
 )
@@ -210,19 +214,21 @@
     # Prepare torch_trt inputs
     inputs = prepare_inputs(inputs)
     device = to_torch_tensorrt_device(settings.device)
+    torch_inputs = get_torch_inputs(inputs, device)
     runtime = trt.Runtime(TRT_LOGGER)
     if not isinstance(new_weight_module, ExportedProgram):
         raise AssertionError(
             f"Input graph should be an ExportedProgram but got type {type(new_weight_module)}"
         )
+    new_weight_module = pre_export_lowering(new_weight_module, torch_inputs)
     new_weight_module = new_weight_module.run_decompositions(
         get_decompositions(settings.enable_experimental_decompositions)
     )
     new_gm = new_weight_module.module()
     logger.debug("Input graph: " + str(new_gm.graph))
     # Apply lowering on the graph module
-    torch_inputs = get_torch_inputs(inputs, device)
-    new_gm = apply_lowering_passes(new_gm, torch_inputs)
+
+    new_gm = post_lowering(new_gm, torch_inputs)
 
     logger.info("Compilation Settings: %s\n", settings)
 
@@ -245,7 +251,6 @@
                 exc_info=True,
             )
 
-            fast_partitioner_failed = True
             settings.use_fast_partitioner = False
 
     if not settings.use_fast_partitioner:
diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py
index 36297f2f30..36999eb499 100644
--- a/tests/py/dynamo/models/test_model_refit.py
+++ b/tests/py/dynamo/models/test_model_refit.py
@@ -18,7 +18,11 @@
     construct_refit_mapping,
     get_engine_from_encoded_engine,
 )
-from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
+from torch_tensorrt.dynamo.lowering import (
+    get_decompositions,
+    post_lowering,
+    pre_export_lowering,
+)
 from torch_tensorrt.logging import TRT_LOGGER
 from transformers import BertModel
 
@@ -58,11 +62,12 @@ def test_mapping():
 
     engine_info = trt_gm._run_on_acc_0.engine.__getstate__()[0]
     engine = get_engine_from_encoded_engine(engine_info[3], runtime)
 
+    exp_program2 = pre_export_lowering(exp_program2, inputs)
     exp_program2 = exp_program2.run_decompositions(
         get_decompositions(settings.enable_experimental_decompositions)
     )
     new_gm = exp_program2.module()
-    new_gm = apply_lowering_passes(new_gm, inputs)
+    new_gm = post_lowering(new_gm, inputs)
 
     mapping = construct_refit_mapping(new_gm, trt_input, settings)
     refitter = trt.Refitter(engine, TRT_LOGGER)

From e7768f7cd67e06240d263aa5cb270d96fd67aced Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Mon, 1 Jul 2024 14:21:17 -0700
Subject: [PATCH 68/70] Handled default setting use cases

---
 py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py | 3 ++-
 py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
index f217e4a9cf..3b4bf160a8 100644
--- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -10,6 +10,7 @@
 from torch.nn import Module
 from torch_tensorrt._Device import Device
 from 
torch_tensorrt._enums import dtype +from torch_tensorrt.dynamo import CompilationSettings from torch_tensorrt.dynamo.runtime.tools import ( _is_switch_required, _select_rt_device, @@ -33,7 +34,7 @@ def __init__( engine: bytes, input_names: Optional[List[str]] = None, output_names: Optional[List[str]] = None, - settings: Any = None, + settings: CompilationSettings = CompilationSettings(), ): super(PythonTorchTensorRTModule, self).__init__() self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict) diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index d768ca11af..cb74921a03 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -8,6 +8,7 @@ import torch from torch_tensorrt._Device import Device +from torch_tensorrt.dynamo import CompilationSettings logger = logging.getLogger(__name__) @@ -55,7 +56,7 @@ def __init__( name: str = "", input_binding_names: Optional[List[str]] = None, output_binding_names: Optional[List[str]] = None, - settings: Any = None, + settings: CompilationSettings = CompilationSettings(), ): """__init__ method for torch_tensorrt.dynamo.runtime._TorchTensorRTModule.TorchTensorRTModule From 51a03c9c691c5cd33dd3f15ba4d407b1dd393379 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Mon, 1 Jul 2024 14:48:04 -0700 Subject: [PATCH 69/70] Fixed circular import bugs --- py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py | 2 +- py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 3b4bf160a8..b5365bf208 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -10,7 +10,7 @@ from torch.nn import Module from torch_tensorrt._Device import Device from torch_tensorrt._enums import dtype -from torch_tensorrt.dynamo import CompilationSettings +from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.runtime.tools import ( _is_switch_required, _select_rt_device, diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index cb74921a03..1449d4ae36 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -8,7 +8,7 @@ import torch from torch_tensorrt._Device import Device -from torch_tensorrt.dynamo import CompilationSettings +from torch_tensorrt.dynamo._settings import CompilationSettings logger = logging.getLogger(__name__) From 33bde0f287b74b48c67e30a5e2c00cdc00b10be5 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Mon, 1 Jul 2024 17:14:18 -0700 Subject: [PATCH 70/70] Changed deprecated behavior --- py/torch_tensorrt/dynamo/_compiler.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 5c4e741cbd..e854b04d42 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -166,6 +166,10 @@ def compile( DeprecationWarning, stacklevel=2, ) + if make_refitable: + raise ValueError("Use flag make_refitable only. 
Flag refit is deprecated.") + else: + make_refitable = kwargs["refit"] engine_capability = EngineCapability._from(engine_capability)
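Taken together, the series enables the following end-to-end workflow. A minimal sketch mirroring examples/dynamo/refit_engine_example.py as updated above; the model choice and input shape are illustrative:

    import torch
    import torch_tensorrt as torch_trt
    import torchvision.models as models
    from torch_tensorrt.dynamo import refit_module_weights

    inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
    model = models.resnet18(pretrained=False).eval().to("cuda")
    model2 = models.resnet18(pretrained=True).eval().to("cuda")

    # Compile once with make_refitable=True (formerly refit=True) so the
    # builder sets the REFIT flag and the settings metadata is serialized
    exp_program = torch.export.export(model, tuple(inputs))
    trt_gm = torch_trt.dynamo.compile(
        exp_program,
        tuple(inputs),
        min_block_size=1,
        make_refitable=True,
    )

    # Swap in the weights of model2 without rebuilding the TensorRT engine
    exp_program2 = torch.export.export(model2, tuple(inputs))
    new_trt_gm = refit_module_weights(
        compiled_module=trt_gm,
        new_weight_module=exp_program2,
        inputs=inputs,
        verify_output=True,  # opt-in, since the default was changed to False
    )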