Skip to content

Commit

Permalink
Revert "allow direct method setting to support custom layers (#460)"
Browse files Browse the repository at this point in the history
This reverts commit 81024cc.
  • Loading branch information
jaybdub authored Dec 2, 2020
1 parent 81024cc commit bde4c86
Showing 1 changed file with 22 additions and 46 deletions.
68 changes: 22 additions & 46 deletions torch2trt/torch2trt.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import torch
import tensorrt as trt
import copy
from copy import copy
import numpy as np
import io
from collections import defaultdict
import importlib

from .calibration import (
TensorBatchDataset,
Expand Down Expand Up @@ -298,24 +297,30 @@ def wrapper(*args, **kwargs):
class ConversionHook(object):
"""Attaches TensorRT converter to PyTorch method call"""

def __init__(self, ctx, key, converter):
def __init__(self, ctx, method, converter):
self.ctx = ctx
self.key = key
self.method_str = method
self.converter = converter

def _set_method(self, method):
module = self.converter['module']
exec('module.%s = method' % self.converter['qual_name'])
exec("%s = method" % self.method_str)

def __enter__(self):
self._set_method(
attach_converter(
self.ctx, self.converter['method_impl'], self.converter, self.converter['method_str']
try:
self.method_impl = eval(self.method_str)
except AttributeError:
self.method_impl = None

if self.method_impl:
self._set_method(
attach_converter(
self.ctx, self.method_impl, self.converter, self.method_str
)
)
)

def __exit__(self, type, val, tb):
self._set_method(self.converter['method_impl'])
if self.method_impl:
self._set_method(self.method_impl)

def default_input_names(num_inputs):
return ["input_%d" % i for i in range(num_inputs)]
Expand Down Expand Up @@ -364,8 +369,8 @@ def __init__(self, network, converters=CONVERTERS):
self.method_kwargs = None
self.method_return = None
self.hooks = [
ConversionHook(self, key, converter)
for key, converter in converters.items()
ConversionHook(self, method, converter)
for method, converter in converters.items()
]

def __enter__(self):
Expand Down Expand Up @@ -564,40 +569,11 @@ def torch2trt(module,

# DEFINE ALL CONVERSION FUNCTIONS

def get_module_qualname(name):
s = name.split('.')

for i in range(len(s)):
idx = len(s) - i - 1
modulename, qualname = ".".join(s[:idx]), ".".join(s[idx:])
try:
module = importlib.import_module(modulename)
return module, modulename, qualname
except:
pass

raise RuntimeError("Could not import module")


def tensorrt_converter(method, is_real=True, enabled=True, imports=[]):

if isinstance(method, str):
module, module_name, qual_name = get_module_qualname(method)
else:
module, module_name, qual_name = importlib.import_module(method.__module__), method.__module__, method.__qualname__

method_impl = eval('copy.deepcopy(module.%s)' % qual_name)

def tensorrt_converter(method, is_real=True, enabled=True):

def register_converter(converter):
CONVERTERS[method] = {
"converter": converter,
"is_real": is_real,
"module": module,
"module_name": module_name,
"qual_name": qual_name,
"method_str": module_name + '.' + qual_name,
"method_impl": method_impl
}
CONVERTERS[method] = {"converter": converter, "is_real": is_real}
return converter

def pass_converter(converter):
Expand All @@ -608,4 +584,4 @@ def pass_converter(converter):
else:
return pass_converter

return register_converter
return register_converter

0 comments on commit bde4c86

Please sign in to comment.