diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index e706aec677..660cb8a875 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -375,7 +375,7 @@ def refit_module_weights( if not weight_name_map: use_weight_map_cache = False logger.warning( - "Fast refitting is not supported in this module. Use regular refitting." + "This engine does not have a weight map cache. Rebuilding the weight map" ) else: compiled_submodule = getattr(compiled_module, name) @@ -385,7 +385,7 @@ def refit_module_weights( weight_name_map = compiled_submodule.weight_name_map except AttributeError: logger.warning( - "The module was compiled wit an old version of Torch-TensorRT. Rebuilding the weight map." + "The module was compiled with an old version of Torch-TensorRT. Rebuilding the weight map." ) if not weight_name_map: use_weight_map_cache = False diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 28aa62072e..9a3cace599 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -6,6 +6,7 @@ from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple import numpy as np +import tensorrt as trt import torch import torch.fx from torch.fx.node import _get_qualified_name @@ -29,7 +30,6 @@ from torch_tensorrt.fx.observer import Observer from torch_tensorrt.logging import TRT_LOGGER -import tensorrt as trt from packaging import version _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -504,7 +504,9 @@ def run( engine_bytes.write(serialized_engine) engine_str = engine_bytes.getvalue() - return TRTInterpreterResult(engine_str, self._input_names, self._output_names, self.weight_name_map) + return TRTInterpreterResult( + engine_str, self._input_names, self._output_names, self.weight_name_map + ) def run_node(self, n: torch.fx.Node) -> torch.fx.Node: self._cur_node_name = get_node_name(n) diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 8316704c6f..57fa1749bf 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -131,7 +131,9 @@ def convert_module( from torch_tensorrt.logging import TRT_LOGGER runtime = trt.Runtime(TRT_LOGGER) - refit_test_engine = runtime.deserialize_cuda_engine(interpreter_result.engine) + refit_test_engine = runtime.deserialize_cuda_engine( + interpreter_result.serialized_engine + ) weight_name_map: Any = None # Do the test refit with cached map if make_refitable is enabled if settings.make_refitable: @@ -169,5 +171,5 @@ def convert_module( output_binding_names=list(interpreter_result.output_names), name=name, settings=settings, - weight_name_map = weight_name_map + weight_name_map=weight_name_map, ) diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index e21e83aaac..d5da83488a 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -4,6 +4,7 @@ from contextlib import nullcontext from typing import Any, Dict, List, Optional, Sequence, Tuple +import tensorrt as trt import torch import torch_tensorrt from torch.nn import Module @@ -18,8 +19,6 @@ from torch_tensorrt.dynamo.utils import DYNAMIC_DIM from torch_tensorrt.logging import TRT_LOGGER -import tensorrt as trt - logger = logging.getLogger(__name__) @@ -104,7 +103,6 @@ def __init__( self.settings = settings self.engine = None self.weight_name_map = weight_name_map - self._initialize() if self.serialized_engine is not None and not self.settings.lazy_engine_init: self.setup_engine() diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index d3216177db..fe3974ff96 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -145,8 +145,7 @@ def setup_engine(self) -> None: self.encode_metadata(metadata), ] ) - - + def encode_metadata(self, metadata: Any) -> str: metadata = copy.deepcopy(metadata) metadata["settings"].torch_executed_ops = { diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index 4b994c6ad5..e803c7fad6 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -108,7 +108,7 @@ def test_fast_refit_one_engine(): new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, - inputs=inputs, + arg_inputs=inputs, use_weight_map_cache=True, ) @@ -155,7 +155,7 @@ def test_fast_refit_one_engine_no_map(): new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, - inputs=inputs, + arg_inputs=inputs, use_weight_map_cache=True, ) @@ -206,7 +206,7 @@ def test_fast_refit_one_engine_wrong_map(): new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, - inputs=inputs, + arg_inputs=inputs, use_weight_map_cache=True, ) @@ -253,7 +253,7 @@ def test_fast_refit_one_engine_bert(): new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, - inputs=inputs, + arg_inputs=inputs, use_weight_map_cache=True, ) @@ -303,7 +303,7 @@ def test_fast_refit_one_engine_inline_runtime(): new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, - inputs=inputs, + arg_inputs=inputs, use_weight_map_cache=True, ) @@ -348,7 +348,7 @@ def test_fast_refit_one_engine_python_runtime(): new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, - inputs=inputs, + arg_inputs=inputs, use_weight_map_cache=True, ) @@ -415,7 +415,7 @@ def forward(self, x): new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, - inputs=inputs, + arg_inputs=inputs, use_weight_map_cache=True, ) @@ -460,7 +460,7 @@ def test_refit_one_engine(): new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, - inputs=inputs, + arg_inputs=inputs, use_weight_map_cache=False, ) @@ -507,7 +507,7 @@ def test_refit_one_engine_bert(): new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, - inputs=inputs, + arg_inputs=inputs, use_weight_map_cache=False, ) @@ -557,7 +557,7 @@ def test_refit_one_engine_inline_runtime(): new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, - inputs=inputs, + arg_inputs=inputs, use_weight_map_cache=False, ) @@ -602,7 +602,7 @@ def test_refit_one_engine_python_runtime(): new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, - inputs=inputs, + arg_inputs=inputs, use_weight_map_cache=False, ) @@ -669,7 +669,7 @@ def forward(self, x): new_trt_gm = refit_module_weights( compiled_module=trt_gm, new_weight_module=exp_program2, - inputs=inputs, + arg_inputs=inputs, use_weight_map_cache=False, )