diff --git a/torch2trt/torch2trt.py b/torch2trt/torch2trt.py index f5c9326f..7051e629 100644 --- a/torch2trt/torch2trt.py +++ b/torch2trt/torch2trt.py @@ -1,10 +1,9 @@ import torch import tensorrt as trt -import copy +from copy import copy import numpy as np import io from collections import defaultdict -import importlib from .calibration import ( TensorBatchDataset, @@ -298,24 +297,30 @@ def wrapper(*args, **kwargs): class ConversionHook(object): """Attaches TensorRT converter to PyTorch method call""" - def __init__(self, ctx, key, converter): + def __init__(self, ctx, method, converter): self.ctx = ctx - self.key = key + self.method_str = method self.converter = converter def _set_method(self, method): - module = self.converter['module'] - exec('module.%s = method' % self.converter['qual_name']) + exec("%s = method" % self.method_str) def __enter__(self): - self._set_method( - attach_converter( - self.ctx, self.converter['method_impl'], self.converter, self.converter['method_str'] + try: + self.method_impl = eval(self.method_str) + except AttributeError: + self.method_impl = None + + if self.method_impl: + self._set_method( + attach_converter( + self.ctx, self.method_impl, self.converter, self.method_str + ) ) - ) def __exit__(self, type, val, tb): - self._set_method(self.converter['method_impl']) + if self.method_impl: + self._set_method(self.method_impl) def default_input_names(num_inputs): return ["input_%d" % i for i in range(num_inputs)] @@ -364,8 +369,8 @@ def __init__(self, network, converters=CONVERTERS): self.method_kwargs = None self.method_return = None self.hooks = [ - ConversionHook(self, key, converter) - for key, converter in converters.items() + ConversionHook(self, method, converter) + for method, converter in converters.items() ] def __enter__(self): @@ -564,40 +569,11 @@ def torch2trt(module, # DEFINE ALL CONVERSION FUNCTIONS -def get_module_qualname(name): - s = name.split('.') - - for i in range(len(s)): - idx = len(s) - i - 1 - modulename, qualname = ".".join(s[:idx]), ".".join(s[idx:]) - try: - module = importlib.import_module(modulename) - return module, modulename, qualname - except: - pass - - raise RuntimeError("Could not import module") - -def tensorrt_converter(method, is_real=True, enabled=True, imports=[]): - - if isinstance(method, str): - module, module_name, qual_name = get_module_qualname(method) - else: - module, module_name, qual_name = importlib.import_module(method.__module__), method.__module__, method.__qualname__ - - method_impl = eval('copy.deepcopy(module.%s)' % qual_name) - +def tensorrt_converter(method, is_real=True, enabled=True): + def register_converter(converter): - CONVERTERS[method] = { - "converter": converter, - "is_real": is_real, - "module": module, - "module_name": module_name, - "qual_name": qual_name, - "method_str": module_name + '.' + qual_name, - "method_impl": method_impl - } + CONVERTERS[method] = {"converter": converter, "is_real": is_real} return converter def pass_converter(converter): @@ -608,4 +584,4 @@ def pass_converter(converter): else: return pass_converter - return register_converter \ No newline at end of file + return register_converter