From e37812dd9a914c2645905a234b1ae6501c929c6f Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Tue, 9 Jul 2024 19:05:23 +0900 Subject: [PATCH 1/3] feat: dynamic shape support for aten.select.int --- .../dynamo/conversion/impl/select.py | 16 +----- .../py/dynamo/conversion/test_select_aten.py | 57 ++++++++++++++----- 2 files changed, 47 insertions(+), 26 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 6d9a86f89b..8a5a0baead 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -47,26 +47,16 @@ def select( if dynamic_shape: # Check whether slice target dim is dynamic shape dim assert input.shape[dim] != -1, "Can't select on negative shape dimension!" - index = index if index >= input.shape[dim]: raise RuntimeError( f"cannot have index greater than the dimension length! {input.shape[dim]}" ) - output_shape = list(input.shape) - output_shape[dim] = 1 - if dynamic_shape > 0: - output_shape = get_shape_with_dynamic_shape( - ctx, target, source_ir, name, output_shape, input - ) + index_value = np.array(index, dtype=np.int32) - indices_tensor = ctx.net.add_constant( - index_value.shape, to_numpy(index_value) - ).get_output(0) + indices_tensor = ctx.net.add_constant(index_value.shape, index_value).get_output(0) layer = ctx.net.add_gather(input, indices_tensor, dim) - out = layer.get_output(0) - if len(out.shape) != 1: - layer = ctx.net.add_shuffle(out) + return layer.get_output(0) diff --git a/tests/py/dynamo/conversion/test_select_aten.py b/tests/py/dynamo/conversion/test_select_aten.py index 4a9f0666a9..9beda81a42 100644 --- a/tests/py/dynamo/conversion/test_select_aten.py +++ b/tests/py/dynamo/conversion/test_select_aten.py @@ -1,4 +1,5 @@ import torch +import torch.nn as nn from parameterized import parameterized from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input @@ -13,7 +14,7 @@ class TestSelectConverterOne(DispatchTestCase): ] ) def test_select(self, _, dim, index): - class TestModule(torch.nn.Module): + class select(nn.Module): def __init__(self): super().__init__() @@ -22,7 +23,7 @@ def forward(self, input): input = [torch.randn(1, 2)] self.run_test( - TestModule(), + select(), input, ) @@ -34,7 +35,7 @@ class TestSelectConverterTwo(DispatchTestCase): ] ) def test_select(self, _, dim, index): - class TestModule(torch.nn.Module): + class select(nn.Module): def __init__(self): super().__init__() @@ -43,33 +44,63 @@ def forward(self, input): input = [torch.randn(4, 4, 4, 4)] self.run_test( - TestModule(), + select(), input, ) -class TestSelectConverterWithDynamicShape(DispatchTestCase): +class TestSelectConverterDynamicShape(DispatchTestCase): @parameterized.expand( [ - ("select_dim_index", 1, 0), + ( + "select_dim_index", + (1, 3, 3), + (2, 3, 3), + (3, 3, 3), + torch.int32, + 1, + 0, + ), + ( + "select_dim_index", + (1, 1, 3), + (2, 2, 3), + (3, 3, 3), + torch.float, + 2, + 0, + ), + ( + "select_dim_index", + (3, 1, 1), + (3, 2, 2), + (3, 3, 3), + torch.float, + 0, + 2, + ), ] ) - def test_select_with_dynamic_shape(self, _, dim, index): - class TestModule(torch.nn.Module): + def test_dynamic_shape_select( + self, _, min_shape, opt_shape, max_shape, type, dim, index + ): + class select(nn.Module): def __init__(self): super().__init__() def forward(self, input): return torch.ops.aten.select.int(input, dim, index) - input_spec = [ + input_specs = [ Input( - shape=(-1, 3, 3), - dtype=torch.float32, - shape_ranges=[((1, 3, 3), (3, 3, 3), (3, 3, 3))], + min_shape=min_shape, + opt_shape=opt_shape, + max_shape=max_shape, + dtype=type, ), ] - self.run_test_with_dynamic_shape(TestModule(), input_spec) + + self.run_test_with_dynamic_shape(select(), input_specs) if __name__ == "__main__": From 7f86c63afb605df0f96986b892f295c393bebe5a Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Wed, 10 Jul 2024 11:40:27 +0900 Subject: [PATCH 2/3] chore: mark dynamic support for select.int --- py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index c50fa57400..92f4b7e18e 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -794,7 +794,7 @@ def aten_ops_scatter( ) -@dynamo_tensorrt_converter(torch.ops.aten.select.int) +@dynamo_tensorrt_converter(torch.ops.aten.select.int, supports_dynamic_shapes=True) def aten_ops_select( ctx: ConversionContext, target: Target, From 20c547b4dd11eebc3da17bf3766739b89e825db6 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Sat, 20 Jul 2024 14:52:22 +0900 Subject: [PATCH 3/3] feat: fully dynamic support for aten.select.int --- .../dynamo/conversion/impl/select.py | 24 ++++------ .../py/dynamo/conversion/test_select_aten.py | 47 ++++++++++--------- 2 files changed, 34 insertions(+), 37 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 8a5a0baead..6653e9e1a5 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -1,5 +1,5 @@ import logging -from typing import Optional, Sequence, Union, cast +from typing import Optional, Sequence, Union import numpy as np import tensorrt as trt @@ -21,7 +21,7 @@ has_dynamic_shape, set_layer_name, ) -from torch_tensorrt.fx.types import Shape, TRTTensor +from torch_tensorrt.fx.types import TRTTensor _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -32,8 +32,8 @@ def select( source_ir: Optional[SourceIR], name: str, input: TRTTensor, - dim: Shape, - index: Shape, + dim: int, + index: int, ) -> TRTTensor: if not isinstance(input, TRTTensor): raise RuntimeError( @@ -42,19 +42,11 @@ def select( ) ranks = len(input.shape) - dim = get_positive_dim(cast(int, dim), ranks) - dynamic_shape = has_dynamic_shape(input.shape) - if dynamic_shape: - # Check whether slice target dim is dynamic shape dim - assert input.shape[dim] != -1, "Can't select on negative shape dimension!" - - if index >= input.shape[dim]: - raise RuntimeError( - f"cannot have index greater than the dimension length! {input.shape[dim]}" - ) + dim = get_positive_dim(dim, ranks) - index_value = np.array(index, dtype=np.int32) - indices_tensor = ctx.net.add_constant(index_value.shape, index_value).get_output(0) + indices_tensor = get_trt_tensor( + ctx, np.array(index, dtype=np.int32), f"{name}_indices_tensor" + ) layer = ctx.net.add_gather(input, indices_tensor, dim) return layer.get_output(0) diff --git a/tests/py/dynamo/conversion/test_select_aten.py b/tests/py/dynamo/conversion/test_select_aten.py index 9beda81a42..cce3fa0a30 100644 --- a/tests/py/dynamo/conversion/test_select_aten.py +++ b/tests/py/dynamo/conversion/test_select_aten.py @@ -10,10 +10,10 @@ class TestSelectConverterOne(DispatchTestCase): @parameterized.expand( [ - ("select_dim_index", 1, 0), + ("dim_index", 1, 0), ] ) - def test_select(self, _, dim, index): + def test_select_2d(self, _, dim, index): class select(nn.Module): def __init__(self): super().__init__() @@ -27,14 +27,12 @@ def forward(self, input): input, ) - -class TestSelectConverterTwo(DispatchTestCase): @parameterized.expand( [ - ("select_dim_index", 1, 0), + ("dim_index", 1, 0), ] ) - def test_select(self, _, dim, index): + def test_select_4d(self, _, dim, index): class select(nn.Module): def __init__(self): super().__init__() @@ -48,36 +46,43 @@ def forward(self, input): input, ) - -class TestSelectConverterDynamicShape(DispatchTestCase): @parameterized.expand( [ ( - "select_dim_index", - (1, 3, 3), - (2, 3, 3), + "partial_dynamic_static_dim", + (1, 1, 3), + (2, 2, 3), (3, 3, 3), - torch.int32, - 1, + torch.float, + 2, 0, ), ( - "select_dim_index", + "partial_dynamic_dynamic_dim", (1, 1, 3), (2, 2, 3), (3, 3, 3), torch.float, - 2, - 0, + 1, + 1, ), ( - "select_dim_index", - (3, 1, 1), - (3, 2, 2), + "fully_dynamic", + (1, 1, 1), + (2, 2, 2), (3, 3, 3), torch.float, - 0, - 2, + 1, + 1, + ), + ( + "fully_dynamic_neg_dim", + (1, 1, 1), + (2, 2, 2), + (3, 3, 3), + torch.float, + -1, + 1, ), ] )