🐛 [Bug] Unable to compile the model using torch tensorrt #1565

Open
IamExperimenting opened this issue Dec 27, 2022 · 10 comments
Labels
bug: triaged [verified] (We can replicate the bug), bug (Something isn't working)

Comments


IamExperimenting commented Dec 27, 2022

Bug Description

Hi team, I have built an object detection model using the torchvision FasterRCNN model. I need to deploy this model on the NVIDIA Triton Inference Server, so I'm trying to compile the model using torch_tensorrt, but it fails.

@narendasan @gs-olive

To Reproduce

Steps to reproduce the behavior:

import torch, tensorrt, torch_tensorrt,torchvision
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn().eval()

trt_module = torch_tensorrt.compile(model,
    inputs = [torch_tensorrt.Input((1, 3, 720, 1280))], # input shape   
    enabled_precisions = {torch.half} # Run with FP16
)
# save the TensorRT embedded Torchscript
torch.jit.save(trt_module, "trt_torchscript_module.ts")

Expected behavior

The PyTorch model should compile successfully using the torch_tensorrt library.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 1.3.0
  • PyTorch Version (e.g. 1.0): 1.13.1
  • TensorRT Version: 8.5.2.2
  • CPU Architecture:
  • OS (e.g., Linux): Linux - Ubuntu 20.04
  • How you installed PyTorch (conda, pip, libtorch, source): conda
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.10.8
  • CUDA version: 11.6
  • GPU models and configuration: no
  • Any other relevant information:

Additional context

**Please find the error message below:**

RuntimeError                              Traceback (most recent call last)
Cell In[3], line 1
----> 1 trt_module = torch_tensorrt.compile(model,
      2     inputs = [torch_tensorrt.Input((1, 3, 720, 1280))], # input shape   
      3     enabled_precisions = {torch.half} # Run with FP16
      4 )
      5 # save the TensorRT embedded Torchscript
      6 torch.jit.save(trt_module, "trt_torchscript_module.ts")

File ~/miniconda3/envs/tensorrt/lib/python3.10/site-packages/torch_tensorrt/_compile.py:125, in compile(module, ir, inputs, enabled_precisions, **kwargs)
    120         logging.log(
    121             logging.Level.Info,
    122             "Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript",
    123         )
    124         ts_mod = torch.jit.script(module)
--> 125     return torch_tensorrt.ts.compile(
    126         ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
    127     )
    128 elif target_ir == _IRType.fx:
    129     if (
    130         torch.float16 in enabled_precisions
    131         or torch_tensorrt.dtype.half in enabled_precisions
    132     ):
...

RuntimeError: 
temporary: the only valid use of a module is looking up an attribute but found  = prim::SetAttr[name="_has_warned"](%self, %self.backbone.body.1.use_res_conne
IamExperimenting added the bug (Something isn't working) label on Dec 27, 2022
Collaborator

peri044 commented Dec 28, 2022

Can you enable debug logging and provide the full log?

with torch_tensorrt.logging.debug():
    trt_module = torch_tensorrt.compile(model,
        inputs = [torch_tensorrt.Input((1, 3, 720, 1280))], # input shape   
        enabled_precisions = {torch.half} # Run with FP16
    )

Author

IamExperimenting commented Dec 28, 2022

@peri044

Code:

with torch_tensorrt.logging.debug():
    trt_module = torch_tensorrt.compile(model,
        inputs = [torch_tensorrt.Input((1, 3, 720, 1280))], # input shape   
        enabled_precisions = {torch.half}, # Run with FP16
        debug =True
    )
# save the TensorRT embedded Torchscript
torch.jit.save(trt_module, "trt_torchscript_module.ts")

Please find the error log below:

INFO: [Torch-TensorRT] - ir was set to default, using TorchScript as ir
INFO: [Torch-TensorRT] - Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript
DEBUG: [Torch-TensorRT] - TensorRT Compile Spec: {
    "Inputs": [
Input(shape=(1,3,720,1280,), dtype=Unknown data type, format=Contiguous/Linear/NCHW)    ]
    "Enabled Precision": [Half, ]
    "TF32 Disabled": 0
    "Sparsity": 0
    "Refit": 0
    "Debug": 1
    "Device":  {
        "device_type": GPU
        "allow_gpu_fallback": False
        "gpu_id": 0
        "dla_core": -1
    }

    "Engine Capability": Default
    "Num Avg Timing Iters": 1
    "Workspace Size": 0
    "DLA SRAM Size": 1048576
    "DLA Local DRAM Size": 1073741824
    "DLA Global DRAM Size": 536870912
    "Truncate long and double": 0
    "Torch Fallback":  {
        "enabled": True
        "min_block_size": 3
        "forced_fallback_operators": [
        ]
        "forced_fallback_modules": [
        ]
    }
}
DEBUG: [Torch-TensorRT] - init_compile_spec with input vector
DEBUG: [Torch-TensorRT] - Settings requested for Lowering:
    torch_executed_modules: [
    ]

RuntimeError                              Traceback (most recent call last)
Cell In[3], line 2
      1 with torch_tensorrt.logging.debug():
----> 2     trt_module = torch_tensorrt.compile(model,
      3         inputs = [torch_tensorrt.Input((1, 3, 720, 1280))], # input shape
      4         enabled_precisions = {torch.half}, # Run with FP16
      5         debug =True
      6     )
      7 # save the TensorRT embedded Torchscript
      8 torch.jit.save(trt_module, "trt_torchscript_module.ts")

File ~/miniconda3/envs/tensorrt/lib/python3.10/site-packages/torch_tensorrt/_compile.py:125, in compile(module, ir, inputs, enabled_precisions, **kwargs)
    124         ts_mod = torch.jit.script(module)
--> 125     return torch_tensorrt.ts.compile(
    126         ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
    127     )
    128 elif target_ir == _IRType.fx:

File ~/miniconda3/envs/tensorrt/lib/python3.10/site-packages/torch_tensorrt/ts/_compiler.py:136, in compile(module, inputs, input_signature, device, disable_tf32, sparse_weights, enabled_precisions, refit, debug, capability, num_avg_timing_iters, workspace_size, dla_sram_size, dla_local_dram_size, dla_global_dram_size, calibrator, truncate_long_and_double, require_full_compilation, min_block_size, torch_executed_ops, torch_executed_modules)
--> 136 compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
    137 compiled_module = torch.jit._recursive.wrap_cpp_module(compiled_cpp_mod)
    138 return compiled_module

RuntimeError: 
temporary: the only valid use of a module is looking up an attribute but found  = prim::SetAttr[name="_has_warned"](%self, %self.backbone.body.1.use_res_connect)

Hope this helps.

@IamExperimenting
Author

@peri044 did you get a chance to look into this?

Collaborator

gs-olive commented Jan 11, 2023

I looked into the issue and I think the error could be related to the behavior of this model when scripted/traced. Since the model is passed in as an nn.Module, it is automatically scripted here:

if module_type == _ModuleType.nn:
    logging.log(
        logging.Level.Info,
        "Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript",
    )
    ts_mod = torch.jit.script(module)
    return torch_tensorrt.ts.compile(
        ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
    )

When I run scripted_model = torch.jit.script(model), and then call the scripted model on a tensor of shape (1, 3, 720, 1080), TorchScript throws an error, as it seems to expect a list of (C, H, W) images as input. Additionally, the output type of this model appears to be a Python dictionary, which may also be contributing to the issue. Will update with any further findings/workarounds.
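For reference, a minimal sketch of that scripting check (the list-of-images input and its size are assumptions based on the repro above):

import torch
import torchvision

# Sketch of the scripting check described above (input size assumed).
model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn().eval()
scripted_model = torch.jit.script(model)

# torchvision detection models expect a list of (C, H, W) tensors in eval mode,
# not a single batched (N, C, H, W) tensor; passing a 4D tensor raises an error.
images = [torch.randn(3, 720, 1280)]
with torch.no_grad():
    output = scripted_model(images)  # per-image detections come back as dicts of tensors
print(type(output))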

Author

IamExperimenting commented Jan 22, 2023

@peri044 @gs-olive, I can see that the compile function scripts the model and compiles it according to the specified precision. However, I'm using the built-in torchvision model here, model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn().eval(), and it is not working. You also mentioned its output is a dictionary; if the output is a dictionary, how do I compile this model?

Collaborator

gs-olive commented Jan 24, 2023

Thanks for the update. Upon further investigation, it seems that the dictionary output is not the root cause of the issue. The error occurs here, on line 172:

LOG_GRAPH("LibTorch Lowering");
auto graph_and_ivalues = torch::jit::LowerGraph(*g, lowered_mod._ivalue());

The Torch lowering code throws an error, shown here, because the model code itself sets class attributes from within the forward function, as shown in this snippet from the MobileNet V3 model code. Will update with any workarounds that make compilation functional.
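To illustrate the pattern with a hypothetical minimal module (not the actual torchvision code): assigning to one of the module's own attributes inside forward() scripts fine, but the scripted graph then contains a prim::SetAttr node, which is what the lowering step rejects.

import torch

# Hypothetical minimal module mirroring the pattern described above.
class SetsAttrInForward(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._has_warned = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self._has_warned:
            self._has_warned = True  # becomes prim::SetAttr in the scripted graph
        return x * 2

scripted = torch.jit.script(SetsAttrInForward())
print(scripted.graph)  # graph includes a prim::SetAttr[name="_has_warned"] node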

Collaborator

gs-olive commented Mar 6, 2023

As an update on this issue, we are investigating the FX path for this model and are currently addressing some failures with it (see pytorch/pytorch#96151).
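For anyone who wants to experiment, the FX frontend is selected via the ir="fx" argument to torch_tensorrt.compile. A rough, unverified sketch (this model still fails on that path at the time of writing, and the sample-tensor input form is an assumption):

import torch
import torch_tensorrt
import torchvision

model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn().eval().cuda()

# Sketch only: selects the FX frontend; expect failures with this model for now.
trt_fx_module = torch_tensorrt.compile(
    model,
    ir="fx",
    inputs=[torch.randn(1, 3, 720, 1280).cuda()],
    enabled_precisions={torch.half},
)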


github-actions bot commented Jun 5, 2023

This issue has not seen activity for 90 days. Remove the stale label or add a comment, or this will be closed in 10 days.


github-actions bot commented Sep 4, 2023

This issue has not seen activity for 90 days. Remove the stale label or add a comment, or this will be closed in 10 days.

@gs-olive
Collaborator

Hello - I have verified that this model is successfully compiling with our torch.compile backend, which can be invoked as follows:

import torch
import torch_tensorrt

...

optimized_model = torch.compile(detectron, backend="tensorrt", options={...})
optimized_model(*inputs)

The current version of fasterrcnn_mobilenet_v3_large_320_fpn, as tested, has many graph breaks, so it would not be exportable/traceable by Dynamo as-is, though Torch may provide utilities to assist with this.
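A slightly fuller sketch, assuming the same torchvision model and input size from the original repro (the enabled_precisions option and the list-of-images input are assumptions, not verified here):

import torch
import torch_tensorrt  # importing registers the "tensorrt" backend for torch.compile
import torchvision

model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn().eval().cuda()

# Detection models take a list of (C, H, W) tensors in eval mode.
inputs = [torch.randn(3, 720, 1280, device="cuda")]

optimized_model = torch.compile(model, backend="tensorrt",
                                options={"enabled_precisions": {torch.half}})
with torch.no_grad():
    detections = optimized_model(inputs)  # unsupported graph sections fall back to eager PyTorch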
