torch_tensorrt.dynamo.compile saved exported programs cannot be loaded #3108

Closed
kacper-kleczewski opened this issue Aug 21, 2024 · 2 comments
Labels: bug (Something isn't working)


@kacper-kleczewski

Bug Description

Models that are exported with torch.export.export, saved, reloaded, compiled with torch_tensorrt.dynamo.compile, and then saved with torch_tensorrt.save cannot be loaded again with torch.export.load. Loading fails with the following error:

W0820 14:11:53.628000 139673176707712 torch/fx/experimental/symbolic_shapes.py:4424] s0 is not in var_ranges, defaulting to unknown range.
E0820 14:11:53.628000 139673176707712 torch/fx/experimental/recording.py:280] failed while running evaluate_expr(*(s0 >= 0, True), **{'fx_node': None})

To Reproduce

The code below should reproduce the issue. It can also be observed with more complex models such as EfficientNet.

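import torch
import torch_tensorrt

model = torch.nn.Linear(5, 7).eval()
sample = torch.randn(3, 5)

# Export, save, and reload the program with torch.export
ep = torch.export.export(model, (sample,))  # example inputs must be passed as a tuple
torch.export.save(ep, "model.ep")
ep_loaded = torch.export.load("model.ep")

# Compile the reloaded program with Torch-TensorRT and save the result
compiled = torch_tensorrt.dynamo.compile(ep_loaded, [sample])
torch_tensorrt.save(compiled, "model_compiled.ep", inputs=[sample])

# Reloading the compiled program fails with the error above
loaded_torch_tensorrt = torch.export.load("model_compiled.ep")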

Expected behavior

The compiled model loads successfully.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

NVIDIA PyTorch container 24.07
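The container tag alone does not pin down the exact torch and Torch-TensorRT builds, and the follow-up discussion turns on the torch version. As a minimal sketch (not part of the original report), the installed versions can be collected with the standard library's importlib.metadata. The helper name `package_versions` is hypothetical, and the distribution names `torch`, `torch_tensorrt`, and `tensorrt` are assumptions about how the packages are installed.

```python
from importlib.metadata import PackageNotFoundError, version


def package_versions(names=("torch", "torch_tensorrt", "tensorrt")):
    """Return a dict mapping each (assumed) distribution name to its version.

    Packages that are not installed map to None instead of raising.
    """
    versions = {}
    for name in names:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, ver in package_versions().items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```

Reading metadata rather than importing the packages means the check still works when an import itself fails, for example because of a CUDA or TensorRT mismatch.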

@kacper-kleczewski kacper-kleczewski added the bug Something isn't working label Aug 21, 2024
@peri044
Collaborator

peri044 commented Aug 23, 2024

Hello @kacper-kleczewski, I tried your script on the main branch (which is on 2.5.0.dev20240822+cu124) and it works fine. Here's the slightly modified script that I tried:

import torch
import torch_tensorrt

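# TensorRT runs on the GPU, so the model and inputs are moved to CUDA
model = torch.nn.Linear(5, 7).eval().cuda()
sample = torch.randn(3, 5).cuda()
pyt_out = model(sample)  # PyTorch reference output
ep = torch.export.export(model, (sample,))
torch.export.save(ep, "model.ep")

ep_loaded = torch.export.load("model.ep")
# min_block_size=1 lets this very small graph be converted to TensorRT instead of falling back to PyTorch
compiled = torch_tensorrt.dynamo.compile(ep_loaded, [sample], min_block_size=1)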

torch_tensorrt.save(compiled, "model_compiled.ep", inputs=[sample])

loaded_torch_tensorrt = torch.export.load("model_compiled.ep")
trt_gm = loaded_torch_tensorrt.module()
trt_out = trt_gm(sample)

print("Diff: ", torch.mean(torch.abs(pyt_out-trt_out)))

I recall we had some serialization issues with torch 2.4, which were resolved recently.

@peri044
Collaborator

peri044 commented Dec 12, 2024

The fixes are already in main. Closing this now. Feel free to reopen if you encounter this again.

@peri044 peri044 closed this as completed Dec 12, 2024