From 6abe7cee410dd56f4b55c482462084ae83f9c717 Mon Sep 17 00:00:00 2001 From: apbose Date: Tue, 19 Mar 2024 14:06:34 -0700 Subject: [PATCH] empty_permute decomposition --- .../dynamo/lowering/_decomposition_groups.py | 1 + .../dynamo/lowering/_decompositions.py | 12 ++++ .../py/dynamo/lowering/test_decompositions.py | 65 +++++++++++++++++++ 3 files changed, 78 insertions(+) diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py index de791851db..98c25a1f54 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py +++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py @@ -37,6 +37,7 @@ aten.elu_backward, aten._embedding_bag, aten.embedding_dense_backward, + aten.empty_like, aten._euclidean_dist.default, aten.expand_as, aten.eye, diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index 981c80f9fa..9ba7ec964b 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -162,6 +162,18 @@ def var_decomposition( return variance +@register_torch_trt_decomposition( + torch.ops.aten.empty_permuted.default, registry=TORCH_TRT_DECOMPOSITIONS +) +def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor: + empty_size = args[0] + empty_permute = args[1] + perm = [0] * len(empty_size) + for permute_index, permute_element in enumerate(empty_permute): + perm[permute_element] = permute_index + return torch.empty([empty_size[l] for l in empty_permute], **kwargs).permute(perm) + + def get_decompositions( enable_experimental_decompositions: bool = False, ) -> Dict[OpOverload, Callable[[Any], Any]]: diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py index 84e8d11585..457e9e2e81 100644 --- a/tests/py/dynamo/lowering/test_decompositions.py +++ b/tests/py/dynamo/lowering/test_decompositions.py @@ -420,6 +420,71 @@ def forward(self, x): f"MaxPool3d TRT outputs don't match with the original model.", ) + def test_lowering_empty_like_module(self): + class emptyLike(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x): + c = torch.ops.aten.add(x, x) + y = torch.ops.aten.empty_like.default(c) + d = y + c + return d + + # Operations expected to be removed in the traced graph after decompositions + expected_ops = {torch.ops.aten.add.Tensor} + unexpected_ops = { + torch.ops.aten.empty_like.default, + torch.ops.aten.empty_permuted.default, + } + + inputs = [torch.zeros(3, 2).cuda()] + + fx_graph = torch.fx.symbolic_trace(emptyLike()) + unexpected_ops_seen, expected_ops_unseen = lower_graph_testing( + fx_graph, + inputs, + expected_ops=expected_ops, + unexpected_ops=unexpected_ops, + min_block_size=1, + ) + + self.assertEquals( + len(unexpected_ops_seen), + 0, + f"The following unexpected ops were encountered: {unexpected_ops_seen}", + ) + + self.assertEquals( + len(expected_ops_unseen), + 0, + f"The following expected ops were not encountered: {expected_ops_unseen}", + ) + + torch._dynamo.reset() + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + truncate_long_and_double=True, + pass_through_build_failures=True, + ) + optimized_model_results = optimized_model(*inputs).detach().cpu() + torch_model_results = fx_graph(*inputs).detach().cpu() + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + f"Select_scatter TRT outputs don't match with the original model.", + ) + if __name__ == "__main__": run_tests()