Skip to content

Commit

Permalink
added expand converter (#487)
Browse files Browse the repository at this point in the history
  • Loading branch information
jaybdub authored Jan 13, 2021
1 parent 2b1827e commit 033df0c
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

### Added

- Added converter for ``torch.Tensor.expand``
- Added support for custom converters for methods defined outside of ``torch`` module
- Added names for TensorRT layers
- Added GroupNorm plugin which internally uses PyTorch aten::group_norm
Expand Down
1 change: 1 addition & 0 deletions torch2trt/converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .clamp import *
from .compare import *
from .div import *
from .expand import *
from .getitem import *
from .identity import *
from .instance_norm import *
Expand Down
43 changes: 43 additions & 0 deletions torch2trt/converters/expand.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


@tensorrt_converter('torch.Tensor.expand')
def convert_expand(ctx):
input = ctx.method_args[0]
sizes = ctx.method_args[1:]
output = ctx.method_return

inshape = tuple(input.shape)[1:] # exclude batch
shape = tuple(output.shape)[1:]
ndim = len(shape)
start = tuple([0]*ndim)
stride = tuple([int(i == o) for i, o in zip(inshape, shape)]) # stride == 1 if dimensions match, 0 otherwise

layer = ctx.network.add_slice(input._trt, start, shape, stride)

output._trt = layer.get_output(0)


class ExpandModule(torch.nn.Module):
def __init__(self, *sizes):
super(ExpandModule, self).__init__()
self.sizes = sizes

def forward(self, x):
return x.expand(*self.sizes)


@add_module_test(torch.float32, torch.device('cuda'), [(1,1,3,3)])
def test_tensor_expand_singledim():
return ExpandModule(1, 3, 3, 3)


@add_module_test(torch.float32, torch.device('cuda'), [(1,1,1,3)])
def test_tensor_expand_multidim():
return ExpandModule(1, 3, 3, 3)


@add_module_test(torch.float32, torch.device('cuda'), [(1,1,1,3)])
def test_tensor_expand_inferdim():
return ExpandModule(1, 3, -1, -1)

0 comments on commit 033df0c

Please sign in to comment.