
🐛 [Bug] AssertionError: to_numpy can only be called on None or a torch.Tensor, got: <tensorrt_bindings.tensorrt.ITensor #2061

Closed
peri044 opened this issue Jun 27, 2023 · 3 comments
Labels: bug (Something isn't working), No Activity

Comments

@peri044
Collaborator

peri044 commented Jun 27, 2023

Bug Description

AssertionError: to_numpy can only be called on None or a torch.Tensor, got: <tensorrt_bindings.tensorrt.ITensor object at 0x7f72c6108d30> While executing %batch_norm

This uses the new export workflow from the https://github.com/pytorch/TensorRT/tree/dynamo_export_refactor branch.
The issue appears to come from partitioning (in the torch.compile workflow): when the graph is copied, all constants are re-registered as placeholders. As a result, constants such as weights and biases are now treated as ITensors, while the batch norm converter expects them to be constants (torch.Tensors) that it can pass to to_numpy.
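The failure mode can be sketched roughly as follows. The classes below are illustrative stand-ins (not the real torch.Tensor or tensorrt.ITensor), but the guard mirrors the assertion in the traceback:

```python
# Illustrative sketch only: stand-in classes, not Torch-TensorRT's actual code.
class TorchTensorStandIn:
    """Stands in for torch.Tensor (a true constant the converter can handle)."""
    pass

class ITensorStandIn:
    """Stands in for tensorrt.ITensor (a constant re-registered as a placeholder)."""
    pass

def to_numpy(value):
    # Mirrors the guard from the traceback: only None or a torch.Tensor is accepted.
    assert value is None or isinstance(value, TorchTensorStandIn), (
        f"to_numpy can only be called on None or a torch.Tensor, got: {value}"
    )
    return value

# A weight that survives partitioning as a constant converts fine:
to_numpy(TorchTensorStandIn())

# But after the graph copy the same weight arrives as a placeholder (ITensor)
# and trips the assertion, as in the reported error:
try:
    to_numpy(ITensorStandIn())
except AssertionError as e:
    print("reproduced:", type(e).__name__)
```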

To Reproduce

import torch
import torch_tensorrt

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=False)
        self.bn = torch.nn.BatchNorm2d(16)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        return out
 
model = MyModule().eval().cuda()
input = torch.randn((1, 3, 224, 224)).to("cuda")
compile_spec = {
    "inputs": [
        torch_tensorrt.Input(
            input.shape, dtype=torch.float, format=torch.contiguous_format
        )
    ],
    "enabled_precisions": {torch.float},
    "debug": True,
    "is_aten": True,
    "min_block_size": 1,
}

trt_mod = torch_tensorrt.dynamo.export.compile(model, **compile_spec)

Expected behavior

Compilation should succeed without the assertion error.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0):
  • PyTorch Version (e.g. 1.0):
  • CPU Architecture:
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, libtorch, source):
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version:
  • CUDA version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

@peri044 peri044 added the bug Something isn't working label Jun 27, 2023
@gs-olive
Collaborator

gs-olive commented Jun 27, 2023

A similar issue is fixed by #1955

@github-actions

This issue has not seen activity for 90 days. Remove the stale label or add a comment, or it will be closed in 10 days.

@gs-olive
Collaborator

Fixed on main, with the new get_trt_tensor system.
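The fix direction can be sketched as a single helper that normalizes whatever a converter receives, whether a raw constant or an already-created ITensor, into a TRT tensor. This is a hypothetical illustration of that pattern with stand-in types, not the actual get_trt_tensor implementation:

```python
# Hypothetical sketch of a get_trt_tensor-style helper (stand-in types,
# not the real Torch-TensorRT code).
class ITensorStandIn:
    """Stands in for tensorrt.ITensor."""
    def __init__(self, name):
        self.name = name

def get_trt_tensor(add_constant, value, name):
    """Normalize `value` to a TRT tensor.

    If the value is already an ITensor (e.g. a constant that became a
    placeholder after the graph copy), pass it through; otherwise register
    it as a constant via `add_constant` (standing in for the network's
    constant-layer API). Converters that call this instead of to_numpy
    no longer assert on ITensor inputs.
    """
    if isinstance(value, ITensorStandIn):
        return value
    return add_constant(value, name)

# Usage with a fake constant-registration callback:
registered = []
def fake_add_constant(value, name):
    registered.append(name)
    return ITensorStandIn(name)

t1 = get_trt_tensor(fake_add_constant, [1.0, 2.0], "bn_weight")      # raw constant
t2 = get_trt_tensor(fake_add_constant, ITensorStandIn("bn_bias"), "bn_bias")  # passthrough
print(registered)  # only the raw constant was registered
```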
