From 2a23b504188b92eff2a5f98a65c047ba64ea7058 Mon Sep 17 00:00:00 2001 From: John Welsh Date: Tue, 1 Dec 2020 22:57:27 +0000 Subject: [PATCH] allow direct method setting to support custom layers --- torch2trt/torch2trt.py | 68 ++++++++++++++++++++++++++++-------------- 1 file changed, 46 insertions(+), 22 deletions(-) diff --git a/torch2trt/torch2trt.py b/torch2trt/torch2trt.py index 7051e629..f5c9326f 100644 --- a/torch2trt/torch2trt.py +++ b/torch2trt/torch2trt.py @@ -1,9 +1,10 @@ import torch import tensorrt as trt -from copy import copy +import copy import numpy as np import io from collections import defaultdict +import importlib from .calibration import ( TensorBatchDataset, @@ -297,30 +298,24 @@ def wrapper(*args, **kwargs): class ConversionHook(object): """Attaches TensorRT converter to PyTorch method call""" - def __init__(self, ctx, method, converter): + def __init__(self, ctx, key, converter): self.ctx = ctx - self.method_str = method + self.key = key self.converter = converter def _set_method(self, method): - exec("%s = method" % self.method_str) + module = self.converter['module'] + exec('module.%s = method' % self.converter['qual_name']) def __enter__(self): - try: - self.method_impl = eval(self.method_str) - except AttributeError: - self.method_impl = None - - if self.method_impl: - self._set_method( - attach_converter( - self.ctx, self.method_impl, self.converter, self.method_str - ) + self._set_method( + attach_converter( + self.ctx, self.converter['method_impl'], self.converter, self.converter['method_str'] ) + ) def __exit__(self, type, val, tb): - if self.method_impl: - self._set_method(self.method_impl) + self._set_method(self.converter['method_impl']) def default_input_names(num_inputs): return ["input_%d" % i for i in range(num_inputs)] @@ -369,8 +364,8 @@ def __init__(self, network, converters=CONVERTERS): self.method_kwargs = None self.method_return = None self.hooks = [ - ConversionHook(self, method, converter) - for method, converter in converters.items() + ConversionHook(self, key, converter) + for key, converter in converters.items() ] def __enter__(self): @@ -569,11 +564,40 @@ def torch2trt(module, # DEFINE ALL CONVERSION FUNCTIONS +def get_module_qualname(name): + s = name.split('.') + + for i in range(len(s)): + idx = len(s) - i - 1 + modulename, qualname = ".".join(s[:idx]), ".".join(s[idx:]) + try: + module = importlib.import_module(modulename) + return module, modulename, qualname + except: + pass + + raise RuntimeError("Could not import module") + -def tensorrt_converter(method, is_real=True, enabled=True): - +def tensorrt_converter(method, is_real=True, enabled=True, imports=[]): + + if isinstance(method, str): + module, module_name, qual_name = get_module_qualname(method) + else: + module, module_name, qual_name = importlib.import_module(method.__module__), method.__module__, method.__qualname__ + + method_impl = eval('copy.deepcopy(module.%s)' % qual_name) + def register_converter(converter): - CONVERTERS[method] = {"converter": converter, "is_real": is_real} + CONVERTERS[method] = { + "converter": converter, + "is_real": is_real, + "module": module, + "module_name": module_name, + "qual_name": qual_name, + "method_str": module_name + '.' + qual_name, + "method_impl": method_impl + } return converter def pass_converter(converter): @@ -584,4 +608,4 @@ def pass_converter(converter): else: return pass_converter - return register_converter + return register_converter \ No newline at end of file