Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow direct method setting to support custom layers #460

Merged
merged 1 commit into from
Dec 2, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 46 additions & 22 deletions torch2trt/torch2trt.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import torch
import tensorrt as trt
from copy import copy
import copy
import numpy as np
import io
from collections import defaultdict
import importlib

from .calibration import (
TensorBatchDataset,
Expand Down Expand Up @@ -297,30 +298,24 @@ def wrapper(*args, **kwargs):
class ConversionHook(object):
"""Attaches TensorRT converter to PyTorch method call"""

def __init__(self, ctx, method, converter):
def __init__(self, ctx, key, converter):
self.ctx = ctx
self.method_str = method
self.key = key
self.converter = converter

def _set_method(self, method):
exec("%s = method" % self.method_str)
module = self.converter['module']
exec('module.%s = method' % self.converter['qual_name'])

def __enter__(self):
try:
self.method_impl = eval(self.method_str)
except AttributeError:
self.method_impl = None

if self.method_impl:
self._set_method(
attach_converter(
self.ctx, self.method_impl, self.converter, self.method_str
)
self._set_method(
attach_converter(
self.ctx, self.converter['method_impl'], self.converter, self.converter['method_str']
)
)

def __exit__(self, type, val, tb):
if self.method_impl:
self._set_method(self.method_impl)
self._set_method(self.converter['method_impl'])

def default_input_names(num_inputs):
return ["input_%d" % i for i in range(num_inputs)]
Expand Down Expand Up @@ -369,8 +364,8 @@ def __init__(self, network, converters=CONVERTERS):
self.method_kwargs = None
self.method_return = None
self.hooks = [
ConversionHook(self, method, converter)
for method, converter in converters.items()
ConversionHook(self, key, converter)
for key, converter in converters.items()
]

def __enter__(self):
Expand Down Expand Up @@ -569,11 +564,40 @@ def torch2trt(module,

# DEFINE ALL CONVERSION FUNCTIONS

def get_module_qualname(name):
s = name.split('.')

for i in range(len(s)):
idx = len(s) - i - 1
modulename, qualname = ".".join(s[:idx]), ".".join(s[idx:])
try:
module = importlib.import_module(modulename)
return module, modulename, qualname
except:
pass

raise RuntimeError("Could not import module")


def tensorrt_converter(method, is_real=True, enabled=True):

def tensorrt_converter(method, is_real=True, enabled=True, imports=[]):

if isinstance(method, str):
module, module_name, qual_name = get_module_qualname(method)
else:
module, module_name, qual_name = importlib.import_module(method.__module__), method.__module__, method.__qualname__

method_impl = eval('copy.deepcopy(module.%s)' % qual_name)

def register_converter(converter):
CONVERTERS[method] = {"converter": converter, "is_real": is_real}
CONVERTERS[method] = {
"converter": converter,
"is_real": is_real,
"module": module,
"module_name": module_name,
"qual_name": qual_name,
"method_str": module_name + '.' + qual_name,
"method_impl": method_impl
}
return converter

def pass_converter(converter):
Expand All @@ -584,4 +608,4 @@ def pass_converter(converter):
else:
return pass_converter

return register_converter
return register_converter