Commit

remove duplicate filenames (#448)
* Remove duplicate filenames which do not work on Windows by merging files

* Fix

* relu tests

Co-authored-by: Koen van de Sande <koen@keplervision.eu>
jaybdub and koenvandesande authored Nov 17, 2020
1 parent d1fa6f9 commit adccbf1
Showing 7 changed files with 73 additions and 55 deletions.
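
The motivation: on a case-insensitive filesystem (Windows, and macOS by default) `Identity.py` and `identity.py` resolve to the same path, so a Git checkout cannot hold both converter files side by side. A minimal sketch of the collision, assuming it is run on such a filesystem; the file names here are purely illustrative:

# Illustrative only: two names differing only in case map to one file on a
# case-insensitive filesystem, which is why the duplicate converters are merged.
from pathlib import Path

Path("Identity.py").write_text("# upper-case variant\n")
print(Path("identity.py").exists())      # True on a case-insensitive filesystem
print(Path("identity.py").read_text())   # shows the content written under the other name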
11 changes: 0 additions & 11 deletions torch2trt/converters/Identity.py

This file was deleted.

11 changes: 0 additions & 11 deletions torch2trt/converters/ReLU.py

This file was deleted.

23 changes: 0 additions & 23 deletions torch2trt/converters/ReLU6.py

This file was deleted.

3 changes: 0 additions & 3 deletions torch2trt/converters/__init__.py
@@ -12,11 +12,8 @@
 from .Conv2d import *
 from .ConvTranspose import *
 from .ConvTranspose2d import *
-from .Identity import *
 from .Linear import *
 from .LogSoftmax import *
-from .ReLU import *
-from .ReLU6 import *
 from .activation import *
 from .adaptive_avg_pool2d import *
 from .adaptive_max_pool2d import *
12 changes: 11 additions & 1 deletion torch2trt/converters/identity.py
@@ -5,8 +5,18 @@
 @tensorrt_converter('torch.nn.functional.dropout')
 @tensorrt_converter('torch.nn.functional.dropout2d')
 @tensorrt_converter('torch.nn.functional.dropout3d')
-def convert_identity(ctx):
+def convert_functional_identity(ctx):
     input = ctx.method_args[0]
     input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
     output = ctx.method_return
     output._trt = input_trt
+
+
+@tensorrt_converter('torch.nn.Dropout.forward')
+@tensorrt_converter('torch.nn.Dropout2d.forward')
+@tensorrt_converter('torch.nn.Dropout3d.forward')
+def convert_identity(ctx):
+    input = ctx.method_args[1]
+    input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
+    output = ctx.method_return
+    output._trt = input_trt
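
The merged file now registers both entry points: the functional form reads its tensor from `ctx.method_args[0]`, while the `Module.forward` form reads `ctx.method_args[1]` (index 0 is `self`). A hedged sketch of how such a converter could be exercised with the same `add_module_test` helper used by the ReLU tests below; this test is not part of the commit and the shape/dtype are illustrative:

import torch
from torch2trt.module_test import add_module_test

# Illustrative test (not in this commit): Dropout converts to an identity
# layer in the TensorRT graph, so eval-mode outputs should match exactly.
@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5)])
def test_dropout_identity():
    return torch.nn.Dropout(p=0.5)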
30 changes: 27 additions & 3 deletions torch2trt/converters/relu.py
@@ -1,11 +1,35 @@
 from torch2trt.torch2trt import *
-from .ReLU import *
+from torch2trt.module_test import add_module_test
 
 
 @tensorrt_converter('torch.relu')
 @tensorrt_converter('torch.relu_')
 @tensorrt_converter('torch.nn.functional.relu')
 @tensorrt_converter('torch.nn.functional.relu_')
-def convert_relu(ctx):
+def convert_functional_relu(ctx):
     ctx.method_args = (torch.nn.ReLU(),) + ctx.method_args
-    convert_ReLU(ctx)
+    convert_relu(ctx)
+
+
+@tensorrt_converter('torch.nn.ReLU.forward')
+def convert_relu(ctx):
+    input = ctx.method_args[1]
+    input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
+    output = ctx.method_return
+    layer = ctx.network.add_activation(
+        input=input_trt, type=trt.ActivationType.RELU)
+    output._trt = layer.get_output(0)
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5)])
+def test_relu_basic():
+    return torch.nn.ReLU()
+
+
+class FunctionalRelu(torch.nn.Module):
+    def forward(self, x):
+        return torch.nn.functional.relu(x)
+
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5)])
+def test_functional_relu_basic():
+    return FunctionalRelu()
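
Note how `convert_functional_relu` simply prepends a `torch.nn.ReLU()` instance to `ctx.method_args`, so the module converter's `method_args[1]` indexing lines up for both call paths. A small usage sketch, assuming a CUDA device and the public `torch2trt` entry point; the model and tensor names are illustrative, not part of this commit:

import torch
from torch2trt import torch2trt

class FunctionalReluNet(torch.nn.Module):
    def forward(self, x):
        # dispatches through convert_functional_relu, which delegates to convert_relu
        return torch.nn.functional.relu(x)

model = FunctionalReluNet().cuda().eval()
x = torch.randn(1, 3, 4, 5).cuda()
model_trt = torch2trt(model, [x])                      # build the TensorRT engine
print(torch.max(torch.abs(model(x) - model_trt(x))))   # difference should be ~0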
38 changes: 35 additions & 3 deletions torch2trt/converters/relu6.py
@@ -1,8 +1,40 @@
 from torch2trt.torch2trt import *
-from .ReLU6 import *
+from torch2trt.module_test import add_module_test
 
 
 @tensorrt_converter('torch.nn.functional.relu6')
-def convert_relu6(ctx):
+def convert_functional_relu6(ctx):
     ctx.method_args = (torch.nn.ReLU6(),) + ctx.method_args
-    convert_ReLU6(ctx)
+    convert_relu6(ctx)
+
+
+@tensorrt_converter('torch.nn.ReLU6.forward')
+def convert_relu6(ctx):
+    input = ctx.method_args[1]
+    output = ctx.method_return
+
+    input_a_trt, input_b_trt = add_missing_trt_tensors(ctx.network, [input, 6])
+    input_a_trt, input_b_trt = broadcast_trt_tensors(ctx.network, [input_a_trt, input_b_trt], len(output.shape) - 1)
+
+    layer = ctx.network.add_activation(
+        input=input_a_trt, type=trt.ActivationType.RELU)
+    layer = ctx.network.add_elementwise(
+        layer.get_output(0), input_b_trt, trt.ElementWiseOperation.MIN)
+
+    output._trt = layer.get_output(0)
+
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5)])
+def test_relu6_basic():
+    return torch.nn.ReLU6()
+
+
+class FunctionalRelu6(torch.nn.Module):
+    def forward(self, x):
+        return torch.nn.functional.relu6(x)
+
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5)])
+def test_functional_relu6_basic():
+    return FunctionalRelu6()
+
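
The ReLU6 converter builds the op from two TensorRT layers: a RELU activation followed by an elementwise MIN against a broadcast constant 6, relying on the identity relu6(x) == min(relu(x), 6). A quick PyTorch-side sanity sketch of that identity; illustrative only, not part of the commit:

import torch

x = torch.randn(1, 3, 4, 5) * 10
expected = torch.nn.functional.relu6(x)
composed = torch.minimum(torch.relu(x), torch.tensor(6.0))  # RELU, then MIN with 6
assert torch.equal(expected, composed)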
