Reorg for converters in hardtanh (FX Converter Refactor [5/N]) <Target: converter_reorg_proto> (#1901)
apbose authored and narendasan committed Jun 3, 2023
1 parent 7864865 commit ae564d7
Showing 5 changed files with 117 additions and 12 deletions.
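For orientation: this reorg keeps each frontend converter (acc, aten, nn) as a thin wrapper that forwards to a shared hardtanh implementation in py/torch_tensorrt/fx/converters/impl/activation.py, tagging the call with its SourceIR. A minimal sketch of that call shape, using names visible in the diffs below (the exact import paths are assumptions, not part of this commit):

# Sketch of the delegation pattern only; the real converter code is in the diffs below.
from torch_tensorrt.fx.converters.impl import activation  # assumed import path
from torch_tensorrt.fx.converters.converter_utils import SourceIR  # assumed import path


def acc_ops_hardtanh(network, target, args, kwargs, name):
    # The acc-tracer converter just forwards to the shared implementation.
    return activation.hardtanh(
        network,
        target,
        SourceIR.ACC,
        name,
        kwargs["input"],
        kwargs["min_val"],
        kwargs["max_val"],
    )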
16 changes: 4 additions & 12 deletions py/torch_tensorrt/fx/converters/acc_ops_converters.py
@@ -3582,23 +3582,15 @@ def acc_ops_hardtanh(
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]

    if not isinstance(input_val, TRTTensor):
        raise RuntimeError(
            f"hardtanh received input {input_val} that is not part "
            "of the TensorRT region!"
        )

    return activation.convert_activation(
    return activation.hardtanh(
        network,
        target,
        SourceIR.ACC,
        name,
        trt.ActivationType.CLIP,
        input_val,
        alpha=kwargs["min_val"],
        beta=kwargs["max_val"],
        kwargs["input"],
        kwargs["min_val"],
        kwargs["max_val"],
    )


14 changes: 14 additions & 0 deletions py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -201,6 +201,20 @@ def aten_ops_fmod(
    return acc_ops_converters.acc_ops_fmod(network, target, None, kwargs_new, name)


@tensorrt_converter(torch.ops.aten.hardtanh.default)
def aten_ops_hardtanh(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:

    return activation.hardtanh(
        network, target, SourceIR.ATEN, name, args[0], args[1], args[2]
    )


@tensorrt_converter(torch.ops.aten.linear)
def aten_ops_linear(
    network: TRTNetwork,
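A note on the positional indexing in aten_ops_hardtanh above: torch.ops.aten.hardtanh.default takes (input, min_val, max_val) positionally, so args[0], args[1], args[2] line up with the shared helper's input_val, alpha, beta. A quick, illustrative way to confirm the schema (not part of the commit):

import torch

# Prints roughly: hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor
print(torch.ops.aten.hardtanh.default._schema)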
31 changes: 31 additions & 0 deletions py/torch_tensorrt/fx/converters/impl/activation.py
@@ -69,6 +69,37 @@ def convert_activation(
    return layer.get_output(0)


def hardtanh(
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input_val: TRTTensor,
    alpha: Optional[Any] = None,
    beta: Optional[Any] = None,
):
    operation_type = trt.ActivationType.CLIP

    def hardtanh_dyn_range_fn(dyn_range):
        def hardtanh_fn(x):
            # TODO: calls torch.nn.functional.hardtanh with its default min/max values
            return torch.nn.functional.hardtanh(x)

        return hardtanh_fn(dyn_range[0]), hardtanh_fn(dyn_range[1])

    return convert_activation(
        network,
        target,
        source_ir,
        name,
        operation_type,
        input_val,
        alpha,
        beta,
        dyn_range_fn=hardtanh_dyn_range_fn,
    )


def relu(
    network: TRTNetwork,
    target: Target,
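For reference on the new hardtanh helper above: TensorRT's CLIP activation with alpha = min_val and beta = max_val is an elementwise clamp, which is exactly what torch.nn.functional.hardtanh computes. A small sanity check in plain PyTorch (illustrative only, not part of the commit):

import torch

x = torch.randn(4)
min_val, max_val = -1.0, 1.0
# hardtanh(x, min_val, max_val) == clamp(x, min_val, max_val) elementwise,
# which is what the CLIP activation (alpha=min_val, beta=max_val) implements in TensorRT.
assert torch.equal(
    torch.nn.functional.hardtanh(x, min_val, max_val),
    torch.clamp(x, min=min_val, max=max_val),
)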
15 changes: 15 additions & 0 deletions py/torch_tensorrt/fx/converters/nn_ops_converters.py
@@ -36,3 +36,18 @@ def sigmoid(network, submod, args, kwargs, layer_name):
        name=layer_name,
        input_val=kwargs["input"],
    )


@tensorrt_converter(torch.nn.functional.hardtanh)
@tensorrt_converter(torch.nn.modules.activation.Hardtanh)
def hardtanh(network, submod, args, kwargs, layer_name):
    # args/kwargs should have already been normalized to kwargs
    assert len(args) == 0

    return activation.hardtanh(
        network=network,
        target="torch.nn.modules.activation.Hardtanh",
        source_ir=SourceIR.NN,
        name=layer_name,
        input_val=kwargs["input"],
    )
53 changes: 53 additions & 0 deletions py/torch_tensorrt/fx/test/converters/aten_op/test_hardtanh_aten.py
@@ -0,0 +1,53 @@
import torch
import torch.nn as nn
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec


class TestHardTanHConverter(DispatchTestCase):
    def test_hardtanh(self):
        class TestModule(nn.Module):
            def forward(self, x):
                return nn.functional.hardtanh(x)

        inputs = [torch.randn(1, 10)]
        self.run_test(
            TestModule(), inputs, expected_ops={torch.ops.aten.hardtanh.default}
        )

    def test_hardtanh_with_dynamic_shape(self):
        class TestModule(nn.Module):
            def forward(self, x):
                return nn.functional.hardtanh(x)

        input_specs = [
            InputTensorSpec(
                shape=(-1, -1, -1),
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
            ),
        ]
        self.run_test_with_dynamic_shape(
            TestModule(), input_specs, expected_ops={torch.ops.aten.hardtanh.default}
        )

    def test_hardtanh_with_dynamic_shape_four_dimensions(self):
        class TestModule(nn.Module):
            def forward(self, x):
                return nn.functional.hardtanh(x)

        input_specs = [
            InputTensorSpec(
                shape=(-1, -1, -1, -1),
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
            ),
        ]

        self.run_test_with_dynamic_shape(
            TestModule(), input_specs, expected_ops={torch.ops.aten.hardtanh.default}
        )


if __name__ == "__main__":
    run_tests()
