From b68ccacc47fb29dbcbac5951aec83c8ff838cecb Mon Sep 17 00:00:00 2001
From: kee hyun an
Date: Sat, 27 Jul 2024 00:31:01 +0900
Subject: [PATCH 1/2] chore: dynamic shape support for flip ops

---
 .../dynamo/conversion/aten_ops_converters.py  |  2 +-
 .../dynamo/conversion/impl/slice/ops.py       | 31 +++++++++--
 tests/py/dynamo/conversion/test_flip_aten.py  | 53 +++++++++++++++++++
 3 files changed, 82 insertions(+), 4 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index bc01c52db9..8841a50caa 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -3259,7 +3259,7 @@ def aten_ops_pdist(
     )
-@dynamo_tensorrt_converter(torch.ops.aten.flip.default)
+@dynamo_tensorrt_converter(torch.ops.aten.flip.default, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
     }
 )
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
index 776e2bec8e..7d4d607949 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -446,13 +446,22 @@ def flip(
     output_shape = list(input.shape)
     stride_slice = []
 
+    dynamic_shape = has_dynamic_shape(input.shape)
+
     shape = input.shape
     rank = len(shape)
     dims = get_positive_dim(dims, rank)
 
     for i in range(rank):
         if i in dims:
-            start_slice.append(shape[i] - 1)
+            if shape[i] == DYNAMIC_DIM:
+                dim = get_shape(ctx, target, source_ir, f"{name}_shape_dim", input, i)
+                last_element_index = impl.elementwise.sub(
+                    ctx, target, source_ir, f"{name}_sub", dim, 1
+                )
+                start_slice.append(last_element_index)
+            else:
+                start_slice.append(shape[i] - 1)
             stride_slice.append(-1)
         else:
             start_slice.append(0)
@@ -460,10 +469,26 @@ def flip(
 
     layer = ctx.net.add_slice(
         input,
-        start=start_slice,
-        shape=output_shape,
+        start=[] if dynamic_shape else start_slice,
+        shape=[] if dynamic_shape else output_shape,
         stride=stride_slice,
     )
+    if dynamic_shape:
+        output_shape = get_shape_with_dynamic_shape(
+            ctx, target, source_ir, f"{name}_shape", output_shape, input
+        )
+
+        start_slice_tensor = cat(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_start_slice_concat",
+            start_slice,
+            0,
+        )
+        layer.set_input(1, start_slice_tensor)
+        layer.set_input(2, output_shape)
+
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
 
diff --git a/tests/py/dynamo/conversion/test_flip_aten.py b/tests/py/dynamo/conversion/test_flip_aten.py
index aa4a2cd374..489aed6030 100644
--- a/tests/py/dynamo/conversion/test_flip_aten.py
+++ b/tests/py/dynamo/conversion/test_flip_aten.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -33,5 +34,57 @@ def forward(self, x):
         self.run_test(Flip(), inputs)
 
 
+class TestFlipConverterDynamic(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (
+                "3d_dynamic",
+                (2, 1, 1),
+                (2, 2, 1),
+                (3, 2, 4),
+                [2, 1, 0],
+            ),
+            (
+                "3d_dynamic_negative_dim",
+                (2, 1, 1),
+                (2, 2, 1),
+                (3, 2, 4),
+                [-1, 1],
+            ),
+            (
+                "4d_dynamic_static_dim",
+                (3, 1, 1, 1),
+                (3, 2, 1, 2),
+                (3, 2, 4, 5),
+                [0, 2, 3],
+            ),
+            (
+                "3d_dynamic_no_dim",
+                (2, 1, 1),
+                (2, 2, 1),
+                (3, 2, 4),
+                [],
+            ),
+        ]
+    )
+    def test_flip_dynamic(self, _, min_shape, opt_shape, max_shape, dims):
+        class Flip(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.flip.default(x, dims)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=torch.float,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Flip(),
+            input_specs,
+        )
+
+
 if __name__ == "__main__":
     run_tests()

From 96977014ebb53f681a885982cdb57d78fef9aaf4 Mon Sep 17 00:00:00 2001
From: kee hyun an
Date: Wed, 31 Jul 2024 23:23:32 +0900
Subject: [PATCH 2/2] chore: suffix of name in loop

---
 py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
index 7d4d607949..cbd55d9d55 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -455,9 +455,11 @@ def flip(
     for i in range(rank):
         if i in dims:
             if shape[i] == DYNAMIC_DIM:
-                dim = get_shape(ctx, target, source_ir, f"{name}_shape_dim", input, i)
+                dim = get_shape(
+                    ctx, target, source_ir, f"{name}_shape_dim_{i}", input, i
+                )
                 last_element_index = impl.elementwise.sub(
-                    ctx, target, source_ir, f"{name}_sub", dim, 1
+                    ctx, target, source_ir, f"{name}_sub_{i}", dim, 1
                 )
                 start_slice.append(last_element_index)
             else:
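
Note (reviewer sketch, not part of the patch): the converter above maps aten.flip to a TensorRT slice layer with stride -1 along each flipped axis, starting at index size-1; for dynamic dimensions the start indices and output shape are computed at build time as tensors (get_shape plus an elementwise subtraction) and fed to the layer via set_input. The pure-PyTorch sketch below mirrors that start/stride bookkeeping so it can be checked against torch.flip. The helper name flip_via_slice is hypothetical and only illustrates the indexing scheme, not the TensorRT API.

import torch


def flip_via_slice(x: torch.Tensor, dims: list) -> torch.Tensor:
    # Normalize negative axes, mirroring get_positive_dim in the converter.
    rank = x.dim()
    dims = [d % rank for d in dims]
    out = x
    for d in dims:
        # Walk the flipped axis backwards: indices [size-1, size-2, ..., 0],
        # i.e. start = size - 1 with stride -1, which is what the slice layer
        # does. In the dynamic-shape path, size - 1 is computed at runtime.
        idx = torch.arange(out.shape[d] - 1, -1, -1, device=out.device)
        out = out.index_select(d, idx)
    return out


x = torch.arange(24, dtype=torch.float).reshape(2, 3, 4)
assert torch.equal(flip_via_slice(x, [-1, 1]), torch.flip(x, dims=[-1, 1]))
assert torch.equal(flip_via_slice(x, []), x)  # empty dims: nothing is flipped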