Skip to content

Commit

Permalink
Add layer names (#432)
Browse files Browse the repository at this point in the history
* Auto-generate custom layer names

* fixed layer name count key

* updated changelog for adding layer names

Co-authored-by: Alex Sergeev <asergeev@maka-ars.com>
  • Loading branch information
jaybdub and alsrgv authored Oct 19, 2020
1 parent b0cc8e7 commit a9a6a53
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@

### Added

- Added names for TensorRT layers
- Replaced Tensor.ndim references with len(tensor.shape) to support older pytorch versions
- Added reduced precision documentation page
38 changes: 36 additions & 2 deletions torch2trt/torch2trt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from copy import copy
import numpy as np
import io
from collections import defaultdict

from .calibration import (
TensorBatchDataset,
Expand Down Expand Up @@ -326,10 +327,43 @@ def default_input_names(num_inputs):

def default_output_names(num_outputs):
return ["output_%d" % i for i in range(num_outputs)]



class LayerNamingNetworkWrapper(object):
def __init__(self, ctx, network):
self._ctx = ctx
self._network = network
self._layer_counts = defaultdict(lambda: 0)

def _set_layer_name(self, layer):
def arg_str(arg):
if isinstance(arg, torch.Tensor):
return "tensor(shape=%s, dtype=%s)" % (str(list(arg.shape)), str(arg.dtype))
return str(arg)

self._layer_counts[layer.type.name] += 1
args = [arg_str(arg) for arg in self._ctx.method_args]
kwargs = ["%s=%s" % (key, arg_str(arg)) for key, arg in self._ctx.method_kwargs.items()]
layer.name = "[%s #%d] %s(%s)" % (layer.type.name, self._layer_counts[layer.type.name],
self._ctx.method_str, ", ".join(args + kwargs))

def __getattr__(self, name):
attr = getattr(self._network, name)
if callable(attr):
def wrapper(*args, **kwargs):
ret = attr(*args, **kwargs)
if isinstance(ret, trt.ILayer):
self._set_layer_name(ret)
return ret

return wrapper
else:
return attr


class ConversionContext(object):
def __init__(self, network, converters=CONVERTERS):
self.network = network
self.network = LayerNamingNetworkWrapper(self, network)
self.lock = False
self.method_args = None
self.method_kwargs = None
Expand Down

0 comments on commit a9a6a53

Please sign in to comment.