Fixed the issue in comments
cehongwang committed Feb 21, 2025
1 parent ec2d674 commit a8e0b48
Showing 5 changed files with 246 additions and 85 deletions.
47 changes: 44 additions & 3 deletions examples/dynamo/mutable_torchtrt_module_example.py
@@ -14,6 +14,7 @@
1. Sample workflow of Mutable Torch TensorRT Module with ResNet 18
2. Save a Mutable Torch TensorRT Module
3. Integration with Huggingface pipeline in LoRA use case
4. Usage of dynamic shape with Mutable Torch TensorRT Module
"""

import numpy as np
@@ -63,16 +64,14 @@
# Saving Mutable Torch TensorRT Module
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

- # Currently, saving is only enabled for C++ runtime, not python runtime.
+ # Currently, saving is only enabled when "use_python" = False in settings
torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl")
reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl")
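A minimal sketch of what the comment above implies, continuing with the names already used in this example; the `use_python_runtime` keyword is an assumption about what the constructor forwards to the compilation settings, not something shown here:

import torchvision.models as models

cpp_runtime_module = torch_trt.MutableTorchTensorRTModule(
    models.resnet18().eval().to("cuda"),
    use_python_runtime=False,  # assumed: request the C++ runtime so save() is allowed
)
cpp_runtime_module(torch.rand(1, 3, 224, 224).to("cuda"))  # first call triggers compilation
torch_trt.MutableTorchTensorRTModule.save(cpp_runtime_module, "cpp_runtime_module.pkl")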

# %%
# Stable Diffusion with Huggingface
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# The LoRA checkpoint is from https://civitai.com/models/12597/moxin

from diffusers import DiffusionPipeline

with torch.no_grad():
@@ -111,3 +110,45 @@
# Refit triggered
image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
image.save("./with_LoRA_mutable.jpg")


# %%
# Use Mutable Torch TensorRT module with dynamic shape
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
class Model(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, a, b, c={}):
x = torch.matmul(a, b)
x = torch.matmul(c["a"], c["b"].T)
print(c["b"][0])
x = 2 * c["b"]
return x


device = "cuda:0"
model = Model().eval().to(device)
inputs = (torch.rand(10, 3).to(device), torch.rand(3, 30).to(device))
kwargs = {
"c": {"a": torch.rand(10, 30).to(device), "b": torch.rand(10, 30).to(device)},
}
dim_0 = torch.export.Dim("dim", min=1, max=50)
dim_1 = torch.export.Dim("dim", min=1, max=50)
dim_2 = torch.export.Dim("dim2", min=1, max=50)
args_dynamic_shapes = ({1: dim_1}, {0: dim_0})
kwarg_dynamic_shapes = {
"c": {"a": {}, "b": {0: dim_2}},
}
# Export the model first with custom dynamic shape constraints
model = torch_trt.MutableTorchTensorRTModule(model, debug=True, min_block_size=1)
model.set_expected_dynamic_shape_range(args_dynamic_shapes, kwarg_dynamic_shapes)
# Compile
model(*inputs, **kwargs)
# Change input shape
inputs_2 = (torch.rand(10, 5).to(device), torch.rand(10, 30).to(device))
kwargs_2 = {
"c": {"a": torch.rand(10, 30).to(device), "b": torch.rand(5, 30).to(device)},
}
# Run without recompiling
model(*inputs_2, **kwargs_2)
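The ranges declared above are what let the second call reuse the compiled engine. Extending the example, a dimension outside the declared 1-50 range for dim_2 would presumably force a re-export and recompilation (this behavior is an assumption, not shown in the diff):

inputs_3 = (torch.rand(10, 7).to(device), torch.rand(7, 30).to(device))
kwargs_3 = {
    "c": {"a": torch.rand(10, 30).to(device), "b": torch.rand(60, 30).to(device)},
}
# "b" now has 60 rows, above the max of 50 declared for dim_2, so a recompile is expected
model(*inputs_3, **kwargs_3)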
9 changes: 6 additions & 3 deletions py/torch_tensorrt/dynamo/_refit.py
@@ -395,9 +395,12 @@ def refit_module_weights(
try:
weight_name_map = compiled_submodule.weight_name_map
except AttributeError:
- logger.warning(
-     "The module was compiled with an old version of Torch-TensorRT. Rebuilding the weight map."
- )
+ if not isinstance(
+     compiled_submodule, torch.fx.graph_module.GraphModule
+ ):
+     logger.warning(
+         "The module was compiled with an old version of Torch-TensorRT. Rebuilding the weight map."
+     )
if not weight_name_map:
use_weight_map_cache = False
logger.warning(
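The new guard narrows the warning: plain torch.fx GraphModule submodules (segments that fall back to eager PyTorch after partitioning) never carry a weight_name_map, so their AttributeError does not indicate an old Torch-TensorRT version. A condensed paraphrase of the resulting control flow (not the file's exact contents):

try:
    weight_name_map = compiled_submodule.weight_name_map
except AttributeError:
    weight_name_map = None
    if not isinstance(compiled_submodule, torch.fx.graph_module.GraphModule):
        logger.warning(
            "The module was compiled with an old version of Torch-TensorRT. "
            "Rebuilding the weight map."
        )
if not weight_name_map:
    use_weight_map_cache = False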
19 changes: 12 additions & 7 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -375,7 +375,10 @@ def _construct_trt_network_def(self) -> None:

@staticmethod
def find_weight(
- weight_name: str, np_map: dict[str, Any], state_dict: dict[str, Any]
+ weight_name: str,
+ np_map: dict[str, Any],
+ state_dict: dict[str, Any],
+ device: torch.device,
) -> str:
"""
We need to build map from engine weight name to state_dict weight name.
@@ -385,19 +385,21 @@ def find_weight(
np_map: the map from weight name to np values in INetworkDefinition
state_dict: state of the graph module
"""
- network_weight = torch.from_numpy(np_map[weight_name]).cuda()
+ network_weight = torch.from_numpy(np_map[weight_name]).to(device)
for sd_w_name, sd_weight in state_dict.items():
- if TRTInterpreter.check_weight_equal(sd_weight, network_weight):
+ if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device):
del state_dict[sd_w_name]
return sd_w_name
return ""

@staticmethod
def check_weight_equal(
- sd_weight: torch.tensor, network_weight: Union[torch.Tensor, np.ndarray]
+ sd_weight: torch.tensor,
+ network_weight: Union[torch.Tensor, np.ndarray],
+ device: torch.device,
) -> Any:
if not isinstance(network_weight, torch.Tensor):
- network_weight = torch.from_numpy(network_weight).cuda()
+ network_weight = torch.from_numpy(network_weight).to(device)
try:
return sd_weight.shape == network_weight.shape and torch.all(
torch.abs(sd_weight - network_weight) < 0.01
@@ -530,10 +535,10 @@ def _save_weight_mapping(self) -> None:
# There is no direct connection in batch_norm layer. So skip it
pass
elif sd_weight_name not in sd or not TRTInterpreter.check_weight_equal(
- sd[sd_weight_name], np_map[engine_weight_name]
+ sd[sd_weight_name], np_map[engine_weight_name], torch_device
):
weight_name_map[engine_weight_name] = TRTInterpreter.find_weight(
- engine_weight_name, np_map, sd
+ engine_weight_name, np_map, sd, torch_device
)
if (
weight_name_map[engine_weight_name] != ""
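Threading device through replaces the hard-coded .cuda(), so the weight map can be rebuilt when the graph lives on a non-default GPU or another device. A self-contained sketch of the comparison these helpers perform, with illustrative names rather than the library's exact code:

import numpy as np
import torch

def weights_match(sd_weight: torch.Tensor, network_weight: np.ndarray, device: torch.device) -> bool:
    # Move the engine-side weight to the caller-supplied device before comparing,
    # instead of assuming cuda:0.
    candidate = torch.from_numpy(network_weight).to(device)
    return sd_weight.shape == candidate.shape and bool(
        torch.all(torch.abs(sd_weight.to(device) - candidate) < 0.01)
    )

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(weights_match(torch.ones(4, 4, device=device), np.ones((4, 4), dtype=np.float32), device))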
