Skip to content

Commit

Permalink
Fixed bugs after rebase
Browse files (browse the repository at this point in the history)
  • Loading branch information
cehongwang committed Aug 8, 2024
1 parent 6f3142b commit 6588edb
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 23 deletions.
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/_refit.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def refit_module_weights(
if not weight_name_map:
use_weight_map_cache = False
logger.warning(
"Fast refitting is not supported in this module. Use regular refitting."
"This engine does not have a weight map cache. Rebuilding the weight map"
)
else:
compiled_submodule = getattr(compiled_module, name)
Expand All @@ -385,7 +385,7 @@ def refit_module_weights(
weight_name_map = compiled_submodule.weight_name_map
except AttributeError:
logger.warning(
"The module was compiled wit an old version of Torch-TensorRT. Rebuilding the weight map."
"The module was compiled with an old version of Torch-TensorRT. Rebuilding the weight map."
)
if not weight_name_map:
use_weight_map_cache = False
Expand Down
6 changes: 4 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple

import numpy as np
import tensorrt as trt
import torch
import torch.fx
from torch.fx.node import _get_qualified_name
Expand All @@ -29,7 +30,6 @@
from torch_tensorrt.fx.observer import Observer
from torch_tensorrt.logging import TRT_LOGGER

import tensorrt as trt
from packaging import version

_LOGGER: logging.Logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -504,7 +504,9 @@ def run(
engine_bytes.write(serialized_engine)
engine_str = engine_bytes.getvalue()

return TRTInterpreterResult(engine_str, self._input_names, self._output_names, self.weight_name_map)
return TRTInterpreterResult(
engine_str, self._input_names, self._output_names, self.weight_name_map
)

def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
self._cur_node_name = get_node_name(n)
Expand Down
6 changes: 4 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,9 @@ def convert_module(
from torch_tensorrt.logging import TRT_LOGGER

runtime = trt.Runtime(TRT_LOGGER)
refit_test_engine = runtime.deserialize_cuda_engine(interpreter_result.engine)
refit_test_engine = runtime.deserialize_cuda_engine(
interpreter_result.serialized_engine
)
weight_name_map: Any = None
# Do the test refit with cached map if make_refitable is enabled
if settings.make_refitable:
Expand Down Expand Up @@ -169,5 +171,5 @@ def convert_module(
output_binding_names=list(interpreter_result.output_names),
name=name,
settings=settings,
weight_name_map = weight_name_map
weight_name_map=weight_name_map,
)
4 changes: 1 addition & 3 deletions (file path not captured in this extract)
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from contextlib import nullcontext
from typing import Any, Dict, List, Optional, Sequence, Tuple

import tensorrt as trt
import torch
import torch_tensorrt
from torch.nn import Module
Expand All @@ -18,8 +19,6 @@
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
from torch_tensorrt.logging import TRT_LOGGER

import tensorrt as trt

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -104,7 +103,6 @@ def __init__(
self.settings = settings
self.engine = None
self.weight_name_map = weight_name_map
self._initialize()

if self.serialized_engine is not None and not self.settings.lazy_engine_init:
self.setup_engine()
Expand Down
3 changes: 1 addition & 2 deletions py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,7 @@ def setup_engine(self) -> None:
self.encode_metadata(metadata),
]
)



def encode_metadata(self, metadata: Any) -> str:
metadata = copy.deepcopy(metadata)
metadata["settings"].torch_executed_ops = {
Expand Down
24 changes: 12 additions & 12 deletions tests/py/dynamo/models/test_model_refit.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def test_fast_refit_one_engine():
new_trt_gm = refit_module_weights(
compiled_module=trt_gm,
new_weight_module=exp_program2,
inputs=inputs,
arg_inputs=inputs,
use_weight_map_cache=True,
)

Expand Down Expand Up @@ -155,7 +155,7 @@ def test_fast_refit_one_engine_no_map():
new_trt_gm = refit_module_weights(
compiled_module=trt_gm,
new_weight_module=exp_program2,
inputs=inputs,
arg_inputs=inputs,
use_weight_map_cache=True,
)

Expand Down Expand Up @@ -206,7 +206,7 @@ def test_fast_refit_one_engine_wrong_map():
new_trt_gm = refit_module_weights(
compiled_module=trt_gm,
new_weight_module=exp_program2,
inputs=inputs,
arg_inputs=inputs,
use_weight_map_cache=True,
)

Expand Down Expand Up @@ -253,7 +253,7 @@ def test_fast_refit_one_engine_bert():
new_trt_gm = refit_module_weights(
compiled_module=trt_gm,
new_weight_module=exp_program2,
inputs=inputs,
arg_inputs=inputs,
use_weight_map_cache=True,
)

Expand Down Expand Up @@ -303,7 +303,7 @@ def test_fast_refit_one_engine_inline_runtime():
new_trt_gm = refit_module_weights(
compiled_module=trt_gm,
new_weight_module=exp_program2,
inputs=inputs,
arg_inputs=inputs,
use_weight_map_cache=True,
)

Expand Down Expand Up @@ -348,7 +348,7 @@ def test_fast_refit_one_engine_python_runtime():
new_trt_gm = refit_module_weights(
compiled_module=trt_gm,
new_weight_module=exp_program2,
inputs=inputs,
arg_inputs=inputs,
use_weight_map_cache=True,
)

Expand Down Expand Up @@ -415,7 +415,7 @@ def forward(self, x):
new_trt_gm = refit_module_weights(
compiled_module=trt_gm,
new_weight_module=exp_program2,
inputs=inputs,
arg_inputs=inputs,
use_weight_map_cache=True,
)

Expand Down Expand Up @@ -460,7 +460,7 @@ def test_refit_one_engine():
new_trt_gm = refit_module_weights(
compiled_module=trt_gm,
new_weight_module=exp_program2,
inputs=inputs,
arg_inputs=inputs,
use_weight_map_cache=False,
)

Expand Down Expand Up @@ -507,7 +507,7 @@ def test_refit_one_engine_bert():
new_trt_gm = refit_module_weights(
compiled_module=trt_gm,
new_weight_module=exp_program2,
inputs=inputs,
arg_inputs=inputs,
use_weight_map_cache=False,
)

Expand Down Expand Up @@ -557,7 +557,7 @@ def test_refit_one_engine_inline_runtime():
new_trt_gm = refit_module_weights(
compiled_module=trt_gm,
new_weight_module=exp_program2,
inputs=inputs,
arg_inputs=inputs,
use_weight_map_cache=False,
)

Expand Down Expand Up @@ -602,7 +602,7 @@ def test_refit_one_engine_python_runtime():
new_trt_gm = refit_module_weights(
compiled_module=trt_gm,
new_weight_module=exp_program2,
inputs=inputs,
arg_inputs=inputs,
use_weight_map_cache=False,
)

Expand Down Expand Up @@ -669,7 +669,7 @@ def forward(self, x):
new_trt_gm = refit_module_weights(
compiled_module=trt_gm,
new_weight_module=exp_program2,
inputs=inputs,
arg_inputs=inputs,
use_weight_map_cache=False,
)

Expand Down

0 comments on commit 6588edb

Please sign in to comment.