Skip to content

Commit

Permalink
allow direct method setting to support custom layers (#460)
Browse files Browse the repository at this point in the history
  • Loading branch information
jaybdub authored Dec 2, 2020
1 parent adccbf1 commit 81024cc
Showing 1 changed file with 46 additions and 22 deletions.
68 changes: 46 additions & 22 deletions torch2trt/torch2trt.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import torch
import tensorrt as trt
from copy import copy
import copy
import numpy as np
import io
from collections import defaultdict
import importlib

from .calibration import (
TensorBatchDataset,
Expand Down Expand Up @@ -297,30 +298,24 @@ def wrapper(*args, **kwargs):
class ConversionHook(object):
"""Attaches TensorRT converter to PyTorch method call"""

def __init__(self, ctx, method, converter):
def __init__(self, ctx, key, converter):
self.ctx = ctx
self.method_str = method
self.key = key
self.converter = converter

def _set_method(self, method):
exec("%s = method" % self.method_str)
module = self.converter['module']
exec('module.%s = method' % self.converter['qual_name'])

def __enter__(self):
try:
self.method_impl = eval(self.method_str)
except AttributeError:
self.method_impl = None

if self.method_impl:
self._set_method(
attach_converter(
self.ctx, self.method_impl, self.converter, self.method_str
)
self._set_method(
attach_converter(
self.ctx, self.converter['method_impl'], self.converter, self.converter['method_str']
)
)

def __exit__(self, type, val, tb):
if self.method_impl:
self._set_method(self.method_impl)
self._set_method(self.converter['method_impl'])

def default_input_names(num_inputs):
return ["input_%d" % i for i in range(num_inputs)]
Expand Down Expand Up @@ -369,8 +364,8 @@ def __init__(self, network, converters=CONVERTERS):
self.method_kwargs = None
self.method_return = None
self.hooks = [
ConversionHook(self, method, converter)
for method, converter in converters.items()
ConversionHook(self, key, converter)
for key, converter in converters.items()
]

def __enter__(self):
Expand Down Expand Up @@ -569,11 +564,40 @@ def torch2trt(module,

# DEFINE ALL CONVERSION FUNCTIONS

def get_module_qualname(name):
s = name.split('.')

for i in range(len(s)):
idx = len(s) - i - 1
modulename, qualname = ".".join(s[:idx]), ".".join(s[idx:])
try:
module = importlib.import_module(modulename)
return module, modulename, qualname
except:
pass

raise RuntimeError("Could not import module")


def tensorrt_converter(method, is_real=True, enabled=True):

def tensorrt_converter(method, is_real=True, enabled=True, imports=[]):

if isinstance(method, str):
module, module_name, qual_name = get_module_qualname(method)
else:
module, module_name, qual_name = importlib.import_module(method.__module__), method.__module__, method.__qualname__

method_impl = eval('copy.deepcopy(module.%s)' % qual_name)

def register_converter(converter):
CONVERTERS[method] = {"converter": converter, "is_real": is_real}
CONVERTERS[method] = {
"converter": converter,
"is_real": is_real,
"module": module,
"module_name": module_name,
"qual_name": qual_name,
"method_str": module_name + '.' + qual_name,
"method_impl": method_impl
}
return converter

def pass_converter(converter):
Expand All @@ -584,4 +608,4 @@ def pass_converter(converter):
else:
return pass_converter

return register_converter
return register_converter

0 comments on commit 81024cc

Please sign in to comment.