Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Additional converters for floordiv, mod, ne, and torch::tensor() operations #505

Merged
merged 4 commits into from
Feb 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@
- Added GroupNorm plugin which internally uses PyTorch aten::group_norm
- Replaced Tensor.ndim references with len(tensor.shape) to support older pytorch versions
- Added reduced precision documentation page
- Added converters for ``floordiv``, ``mod``, ``ne``, and ``torch.tensor`` operations
- Extended ``relu`` converter to support ``Tensor.relu`` operation
- Extended ``sigmoid`` converter to support ``Tensor.sigmoid`` operation
4 changes: 4 additions & 0 deletions torch2trt/converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .compare import *
from .div import *
from .expand import *
from .floordiv import *
from .getitem import *
from .identity import *
from .instance_norm import *
Expand All @@ -35,8 +36,10 @@
from .max_pool2d import *
from .mean import *
from .min import *
from .mod import *
from .mul import *
from .normalize import *
from .ne import *
from .narrow import *
from .pad import *
from .permute import *
Expand All @@ -52,6 +55,7 @@
from .sub import *
from .sum import *
from .tanh import *
from .tensor import *
from .transpose import *
from .unary import *
from .view import *
81 changes: 81 additions & 0 deletions torch2trt/converters/floordiv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


@tensorrt_converter('torch.Tensor.__floordiv__')
@tensorrt_converter('torch.Tensor.__ifloordiv__')
@tensorrt_converter('torch.floor_divide')
def convert_floordiv(ctx):
input_a = ctx.method_args[0]
input_b = ctx.method_args[1]
output = ctx.method_return
input_a_trt, input_b_trt = add_missing_trt_tensors(ctx.network, [input_a, input_b])
input_a_trt, input_b_trt = broadcast_trt_tensors(ctx.network, [input_a_trt, input_b_trt], len(output.shape) - 1)
# we can not use ElementWiseOperation.FLOOR_DIV directly because Torch truncate negative result toward 0
# but TensorRT FLOOR_DIV op toward -Inf
# sign = ab / |ab|
# floordiv result: sign * (|a| // |b|)
ab_layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, trt.ElementWiseOperation.PROD)
abs_ab_layer = ctx.network.add_unary(ab_layer.get_output(0), trt.UnaryOperation.ABS)
sign_layer = ctx.network.add_elementwise(ab_layer.get_output(0), abs_ab_layer.get_output(0),
trt.ElementWiseOperation.DIV)
abs_a_layer = ctx.network.add_unary(input_a_trt, trt.UnaryOperation.ABS)
abs_b_layer = ctx.network.add_unary(input_b_trt, trt.UnaryOperation.ABS)
abs_floor_layer = ctx.network.add_elementwise(abs_a_layer.get_output(0), abs_b_layer.get_output(0),
trt.ElementWiseOperation.FLOOR_DIV)
out_layer = ctx.network.add_elementwise(sign_layer.get_output(0), abs_floor_layer.get_output(0),
trt.ElementWiseOperation.PROD)
output._trt = out_layer.get_output(0)


class FloorDiv(torch.nn.Module):
def __init__(self):
super(FloorDiv, self).__init__()

def forward(self, x, y):
return x // y


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 40, 20), (1, 3, 1, 20)])
def test_floordiv_op():
return FloorDiv()


class FloorDivAssign (torch.nn.Module):
def __init__(self):
super(FloorDivAssign, self).__init__()

def forward(self, x, y):
x //= y
return x


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 40, 20), (1, 3, 1, 20)])
def test_floordiv_op_assign():
return FloorDivAssign()


class FloorDivConst(torch.nn.Module):
def __init__(self):
super(FloorDivConst, self).__init__()

def forward(self, x):
return x // 2.


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 40, 20)])
def test_floordiv_op_const():
return FloorDivConst()


class TorchFloorDiv(torch.nn.Module):
def __init__(self):
super(TorchFloorDiv, self).__init__()

def forward(self, x, y):
return torch.floor_divide(x, y)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 40, 20), (1, 3, 1, 20)])
def test_floordiv_func():
return TorchFloorDiv()
99 changes: 99 additions & 0 deletions torch2trt/converters/mod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


@tensorrt_converter('torch.fmod')
def convert_mod(ctx):
input_a = ctx.method_args[0]
input_b = ctx.method_args[1]
output = ctx.method_return
input_a_trt, input_b_trt = add_missing_trt_tensors(ctx.network, [input_a, input_b])
input_a_trt, input_b_trt = broadcast_trt_tensors(ctx.network, [input_a_trt, input_b_trt], len(output.shape) - 1)
# we can not use ElementWiseOperation.FLOOR_DIV directly because Torch truncate negative result toward 0
# but TensorRT FLOOR_DIV op toward -Inf
# sign = ab / |ab|
# floordiv result: sign * (|a| // |b|)
ab_layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, trt.ElementWiseOperation.PROD)
abs_ab_layer = ctx.network.add_unary(ab_layer.get_output(0), trt.UnaryOperation.ABS)
sign_layer = ctx.network.add_elementwise(ab_layer.get_output(0), abs_ab_layer.get_output(0),
trt.ElementWiseOperation.DIV)
abs_a_layer = ctx.network.add_unary(input_a_trt, trt.UnaryOperation.ABS)
abs_b_layer = ctx.network.add_unary(input_b_trt, trt.UnaryOperation.ABS)
abs_floor_layer = ctx.network.add_elementwise(abs_a_layer.get_output(0), abs_b_layer.get_output(0),
trt.ElementWiseOperation.FLOOR_DIV)
# a % b = a - (a//b) * b
floordiv_layer = ctx.network.add_elementwise(sign_layer.get_output(0), abs_floor_layer.get_output(0),
trt.ElementWiseOperation.PROD)
prod_layer = ctx.network.add_elementwise(floordiv_layer.get_output(0), input_b_trt, trt.ElementWiseOperation.PROD)
sub_layer = ctx.network.add_elementwise(input_a_trt, prod_layer.get_output(0), trt.ElementWiseOperation.SUB)
output._trt = sub_layer.get_output(0)


@tensorrt_converter('torch.Tensor.__mod__')
# we need separate converter for operator because for some reason Torch use truncation toward -Inf for this op.
# bug is filed: https://github.com/pytorch/pytorch/issues/52425
# but for now we have to convert model exactly
def convert_mod(ctx):
input_a = ctx.method_args[0]
input_b = ctx.method_args[1]
output = ctx.method_return
input_a_trt, input_b_trt = add_missing_trt_tensors(ctx.network, [input_a, input_b])
input_a_trt, input_b_trt = broadcast_trt_tensors(ctx.network, [input_a_trt, input_b_trt], len(output.shape) - 1)
# a % b = a - (a//b) * b
floordiv_layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, trt.ElementWiseOperation.FLOOR_DIV)
prod_layer = ctx.network.add_elementwise(floordiv_layer.get_output(0), input_b_trt, trt.ElementWiseOperation.PROD)
mod_layer = ctx.network.add_elementwise(input_a_trt, prod_layer.get_output(0), trt.ElementWiseOperation.SUB)
output._trt = mod_layer.get_output(0)


class Mod(torch.nn.Module):
def __init__(self):
super(Mod, self).__init__()

def forward(self, x, y):
return x % y


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 40, 20), (1, 3, 1, 20)])
def test_mod_op():
return Mod()


class ModAssign(torch.nn.Module):
def __init__(self):
super(ModAssign, self).__init__()

def forward(self, x, y):
x %= y
return x


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 40, 20), (1, 3, 1, 20)])
def test_mod_op_assign():
return ModAssign()


class ModConst(torch.nn.Module):
def __init__(self):
super(ModConst, self).__init__()

def forward(self, x):
return x % 2.


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 40, 20)])
def test_mod_op_const():
return ModConst()


class TorchMod(torch.nn.Module):
def __init__(self):
super(TorchMod, self).__init__()

def forward(self, x, y):
return torch.fmod(x, y)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 40, 20), (1, 3, 40, 20)])
def test_mod_func():
return TorchMod()
54 changes: 54 additions & 0 deletions torch2trt/converters/ne.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


@tensorrt_converter('torch.ne')
@tensorrt_converter('torch.Tensor.__ne__')
def convert_ne(ctx):
input_a = ctx.method_args[0]
input_b = ctx.method_args[1]
output = ctx.method_return
input_a_trt, input_b_trt = add_missing_trt_tensors(ctx.network, [input_a, input_b])
input_a_trt, input_b_trt = broadcast_trt_tensors(ctx.network, [input_a_trt, input_b_trt], len(output.shape) - 1)
layer_1 = ctx.network.add_elementwise(input_a_trt, input_b_trt, trt.ElementWiseOperation.EQUAL)
layer_2 = ctx.network.add_unary(layer_1.get_output(0), trt.UnaryOperation.NOT)
output._trt = layer_2.get_output(0)


class NotEqual(torch.nn.Module):
def __init__(self):
super(NotEqual, self).__init__()

def forward(self, x, y):
return x != y


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 40, 20), (1, 3, 1, 20)])
def test_ne_op():
return NotEqual()


class NotEqualConst(torch.nn.Module):
def __init__(self):
super(NotEqualConst, self).__init__()

def forward(self, x):
return x != 13.62


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 40, 20)])
def test_ne_op_const():
return NotEqualConst()


class TorchNotEqual(torch.nn.Module):
def __init__(self):
super(TorchNotEqual, self).__init__()

def forward(self, x, y):
return torch.ne(x, y)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 40, 20), (1, 3, 1, 20)])
def test_ne_torch():
return TorchNotEqual()
16 changes: 15 additions & 1 deletion torch2trt/converters/relu.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
@tensorrt_converter('torch.relu_')
@tensorrt_converter('torch.nn.functional.relu')
@tensorrt_converter('torch.nn.functional.relu_')
@tensorrt_converter('torch.Tensor.relu')
def convert_functional_relu(ctx):
ctx.method_args = (torch.nn.ReLU(),) + ctx.method_args
convert_relu(ctx)
Expand All @@ -32,4 +33,17 @@ def forward(self, x):

@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5)])
def test_functional_relu_basic():
return FunctionalRelu()
return FunctionalRelu()


class TensorRelu(torch.nn.Module):
def __init__(self):
super(TensorRelu, self).__init__()

def forward(self, x):
return x.relu()


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 40, 20)])
def test_tensor_relu():
return TensorRelu()
16 changes: 15 additions & 1 deletion torch2trt/converters/sigmoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

@tensorrt_converter('torch.nn.functional.sigmoid')
@tensorrt_converter('torch.sigmoid')
@tensorrt_converter('torch.Tensor.sigmoid')
def convert_sigmoid(ctx):
input = ctx.method_args[0]
input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
Expand All @@ -15,4 +16,17 @@ def convert_sigmoid(ctx):

@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)])
def test_sigmoid_basic():
return torch.nn.Sigmoid()
return torch.nn.Sigmoid()


class TensorSigmoid(torch.nn.Module):
def __init__(self):
super(TensorSigmoid, self).__init__()

def forward(self, x):
return x.sigmoid()


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 40, 20)])
def test_tensor_sigmoid():
return TensorSigmoid()
22 changes: 22 additions & 0 deletions torch2trt/converters/tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


@tensorrt_converter('torch.tensor')
def convert_mod(ctx):
output = ctx.method_return
layer = ctx.network.add_constant(tuple(output.shape), output.detach().cpu().numpy() )
output._trt = layer.get_output(0)


class TorchTensor(torch.nn.Module):
def __init__(self):
super(TorchTensor, self).__init__()

def forward(self, x):
return x + torch.tensor([[1., 2., 3.], [4., 5., 6.]], device=torch.device('cuda'))


@add_module_test(torch.float32, torch.device('cuda'), [(1, 2, 3)])
def test_tensor_creation():
return TorchTensor()