diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 8acf7130d7..e393645645 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2572,3 +2572,27 @@ def aten_ops_copy(
         src.dtype,
         force_layer=True,
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.remainder.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.remainder.Tensor)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_remainder(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.remainder(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
index a69fca944b..d0f6d29482 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
@@ -178,6 +178,41 @@ def fmod(
     return sub_value
 
 
+def remainder(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    other: TRTTensor,
+) -> TRTTensor:
+    fmod1_value = fmod(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_fmod1",
+        input,
+        other,
+    )
+    added_value = add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_add",
+        fmod1_value,
+        other,
+    )
+    fmod2_value = fmod(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_fmod2",
+        added_value,
+        other,
+    )
+    return fmod2_value
+
+
 def clamp(
     ctx: ConversionContext,
     target: Target,
diff --git a/tests/py/dynamo/conversion/test_remainder_aten.py b/tests/py/dynamo/conversion/test_remainder_aten.py
new file mode 100644
index 0000000000..8b3035e90f
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_remainder_aten.py
@@ -0,0 +1,60 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
+
+from .harness import DispatchTestCase
+
+
+class TestRemainderConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("1d", (5,), 3),
+            ("2d", (2, 1), 1.0),
+            ("3d", (2, 1, 2), 2),
+        ]
+    )
+    def test_remainder_scalar(self, _, shape, scalar):
+        class Remainder(nn.Module):
+            def forward(self, lhs_val):
+                return torch.ops.aten.remainder.Scalar(lhs_val, scalar)
+
+        inputs = [torch.randn(shape)]
+        self.run_test(
+            Remainder(),
+            inputs,
+        )
+
+    def test_remainder_scalar_int(self, scalar=3):
+        class Remainder(nn.Module):
+            def forward(self, lhs_val):
+                return torch.ops.aten.remainder.Scalar(lhs_val, scalar)
+
+        inputs = [torch.tensor([0, 1, 2, 3, 4, -1, -2, -3, -4], dtype=torch.float32)]
+        self.run_test(
+            Remainder(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            ("1d", (5,)),
+            ("2d", (2, 1)),
+            ("3d", (2, 1, 2)),
+        ]
+    )
+    def test_remainder_tensor(self, _, shape):
+        class Remainder(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return torch.ops.aten.remainder.Tensor(lhs_val, rhs_val)
+
+        inputs = [torch.randn(shape), torch.randn(shape)]
+        self.run_test(
+            Remainder(),
+            inputs,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
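
For reference, the new impl.elementwise.remainder builds Python-style (sign-of-divisor) remainder out of two fmod calls via the identity remainder(a, b) = fmod(fmod(a, b) + b, b) for b != 0: the inner fmod carries the sign of the dividend, adding b shifts any negative result toward the divisor's sign, and the outer fmod folds values whose magnitude reaches |b| back into range. A minimal sketch checking that identity with plain PyTorch (the tensor values below are illustrative, not taken from the PR):

import torch

# Same layering as the converter: fmod -> add -> fmod.
a = torch.tensor([5.0, -5.0, 5.0, -5.0])
b = torch.tensor([3.0, 3.0, -3.0, -3.0])

expected = torch.remainder(a, b)                # tensor([ 2.,  1., -1., -2.])
emulated = torch.fmod(torch.fmod(a, b) + b, b)  # identity used by impl.elementwise.remainder
assert torch.equal(expected, emulated)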