From 6d739eefadffcfc1d4643fa52ab8d22c2437a0f8 Mon Sep 17 00:00:00 2001
From: kee hyun an
Date: Sun, 21 Jul 2024 15:56:25 +0900
Subject: [PATCH 1/2] chore: dynamic shape support for any/sort/trunc ops

---
 .../dynamo/conversion/aten_ops_converters.py  |  21 ++-
 .../dynamo/conversion/impl/topk.py            |  38 ++++-
 py/torch_tensorrt/dynamo/utils.py             |   3 +-
 tests/py/dynamo/conversion/test_any.py        | 148 ++++++++++++++++++
 tests/py/dynamo/conversion/test_sort_aten.py  |  50 ++++++
 tests/py/dynamo/conversion/test_trunc_aten.py |  44 ++++++
 6 files changed, 294 insertions(+), 10 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 3da1b09fba..a135f5c688 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -19,6 +19,7 @@
     get_positive_dim,
     is_only_operator_on_placeholder,
 )
+from torch_tensorrt.dynamo.utils import TRT_TOPK_MAX_ELEMENT
 from torch_tensorrt.fx.types import TRTTensor

 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -2648,6 +2649,10 @@ def topk_validator(node: Node) -> bool:


 def sort_validator(node: Node) -> bool:
+    # if meta data is not available(e.g. dynamic shape), validate k value during runtime.
+    if not node.args[0].meta:
+        return True
+
     shape = node.args[0].meta.get("tensor_meta").shape
     dim = node.args[1]
     dim = get_positive_dim(dim, len(shape))
@@ -2656,9 +2661,9 @@ def sort_validator(node: Node) -> bool:


 def topk_sort_validator(k: int) -> bool:
-    if k > 3840:
+    if k > TRT_TOPK_MAX_ELEMENT:
         _LOGGER.debug(
-            f"Currently only topk values up to 3840 are supported, got k={k}."
+            f"Currently only topk values up to {TRT_TOPK_MAX_ELEMENT} are supported, got k={k}."
         )
         return False
     return True
@@ -3103,7 +3108,9 @@ def aten_ops_topk(


 @dynamo_tensorrt_converter(
-    torch.ops.aten.sort.default, capability_validator=sort_validator
+    torch.ops.aten.sort.default,
+    capability_validator=sort_validator,
+    supports_dynamic_shapes=True,
 )
 @enforce_tensor_types(
     {
@@ -3128,7 +3135,7 @@ def aten_ops_sort(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.trunc.default)
+@dynamo_tensorrt_converter(torch.ops.aten.trunc.default, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
@@ -3204,9 +3211,9 @@ def aten_ops_remainder(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.any.default)
-@dynamo_tensorrt_converter(torch.ops.aten.any.dim)
-@dynamo_tensorrt_converter(torch.ops.aten.any.dims)
+@dynamo_tensorrt_converter(torch.ops.aten.any.default, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.any.dim, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.any.dims, supports_dynamic_shapes=True)
 def aten_ops_any(
     ctx: ConversionContext,
     target: Target,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/topk.py b/py/torch_tensorrt/dynamo/conversion/impl/topk.py
index 78dd25d5a1..a779682419 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/topk.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/topk.py
@@ -10,9 +10,12 @@
     flatten_dims,
     get_axes_for_reduce_op,
     get_positive_dim,
+    set_layer_name,
 )
-from torch_tensorrt.fx.converters.converter_utils import set_layer_name
-from torch_tensorrt.fx.types import TRTTensor
+from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import le
+from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
+from torch_tensorrt.dynamo.types import TRTTensor
+from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, TRT_TOPK_MAX_ELEMENT


 def argmax_argmin(
@@ -155,6 +158,37 @@ def topk(
         k,
         get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape))),
     )
+    if k == DYNAMIC_DIM:
+        output_shape = get_shape_with_dynamic_shape(
+            ctx, target, source_ir, name, input.shape, input
+        )
+        layer = ctx.net.add_slice(
+            output_shape,
+            start=[dim],
+            shape=[1],
+            stride=[1],
+        )
+        set_layer_name(layer, target, name)
+
+        # Get scalar tensor from 1d tensor
+        shuffle_layer = ctx.net.add_shuffle(layer.get_output(0))
+        shuffle_layer.reshape_dims = trt.Dims()
+        set_layer_name(shuffle_layer, target, name, source_ir)
+
+        cond = le(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_k_cond",
+            shuffle_layer.get_output(0),
+            TRT_TOPK_MAX_ELEMENT,
+        )
+        ctx.net.add_assertion(
+            cond,
+            message=f"Currently only topk values up to {TRT_TOPK_MAX_ELEMENT} are supported",
+        )
+
+        topk_layer.set_input(1, shuffle_layer.get_output(0))
     # TensorRT ITopKLayer does not have a sorted flag, it is always returning the sorted topk elements
     # so here no matter sorted is True or False the returned the topk Tensor object is always sorted
     set_layer_name(topk_layer, target, name, source_ir)
diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index acfb2b0094..b556dd0756 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -6,6 +6,7 @@
 from typing import Any, Callable, Dict, Optional, Sequence, Union

 import numpy as np
+import tensorrt as trt
 import torch
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import dtype
@@ -13,7 +14,6 @@
 from torch_tensorrt.dynamo import _defaults
 from torch_tensorrt.dynamo._settings import CompilationSettings

-import tensorrt as trt
 from packaging import version

 from .types import TRTDataType
@@ -22,6 +22,7 @@

 COSINE_THRESHOLD = 0.99
 DYNAMIC_DIM = -1
+TRT_TOPK_MAX_ELEMENT = 3840


 class Frameworks(Enum):
diff --git a/tests/py/dynamo/conversion/test_any.py b/tests/py/dynamo/conversion/test_any.py
index 29522145da..1d1fc634ef 100644
--- a/tests/py/dynamo/conversion/test_any.py
+++ b/tests/py/dynamo/conversion/test_any.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input

 from .harness import DispatchTestCase

@@ -184,5 +185,152 @@ def forward(self, x):
         )


+class TestAnyConverterDynamic(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (
+                "3d_dynamic_float",
+                (2, 1, 1),
+                (2, 2, 1),
+                (3, 2, 4),
+                torch.float,
+            ),
+            (
+                "2d_dynamic_int32",
+                (2, 2),
+                (2, 2),
+                (3, 2),
+                torch.int32,
+            ),
+            (
+                "4d_dynamic_bool",
+                (1, 2, 1, 1),
+                (2, 2, 2, 2),
+                (2, 2, 4, 3),
+                torch.bool,
+            ),
+        ]
+    )
+    def test_any_dynamic(self, _, min_shape, opt_shape, max_shape, type):
+        class Any(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.any.default(x)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Any(),
+            input_specs,
+        )
+
+    @parameterized.expand(
+        [
+            (
+                "3d_dynamic_dim_float",
+                (2, 1, 1),
+                (2, 2, 1),
+                (3, 2, 4),
+                torch.float,
+                2,
+                True,
+            ),
+            (
+                "4d_dynamic_dim_int32",
+                (1, 1, 4, 1),
+                (2, 2, 4, 2),
+                (2, 4, 4, 3),
+                torch.int32,
+                -2,
+                False,
+            ),
+            (
+                "3d_dynamic_dim_bool",
+                (2, 1, 1),
+                (2, 2, 1),
+                (3, 2, 4),
+                torch.bool,
+                0,
+                True,
+            ),
+        ]
+    )
+    def test_any_dynamic_dim(
+        self, _, min_shape, opt_shape, max_shape, type, dim, keep_dims
+    ):
+        class AnyDim(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.any.dim(x, dim, keep_dims)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            AnyDim(),
+            input_specs,
+        )
+
+    @parameterized.expand(
+        [
+            (
+                "3d_dynamic_dims_float",
+                (2, 1, 1),
+                (2, 2, 1),
+                (3, 2, 4),
+                torch.float,
+                [1, 2],
+                True,
+            ),
+            (
+                "4d_dynamic_dims_int32",
+                (1, 1, 4, 1),
+                (2, 2, 4, 2),
+                (2, 4, 4, 3),
+                torch.int32,
+                [2, -1],
+                False,
+            ),
+            (
+                "3d_dynamic_dims_bool",
+                (1, 4, 1),
+                (2, 4, 2),
+                (4, 4, 3),
+                torch.bool,
+                [0, 1, 2],
+                False,
+            ),
+        ]
+    )
+    def test_any_dynamic_dims(
+        self, _, min_shape, opt_shape, max_shape, type, dims, keep_dims
+    ):
+        class AnyDims(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.any.dims(x, dims, keep_dims)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            AnyDims(),
+            input_specs,
+        )
+
+
 if __name__ == "__main__":
     run_tests()
diff --git a/tests/py/dynamo/conversion/test_sort_aten.py b/tests/py/dynamo/conversion/test_sort_aten.py
index 8382da0047..3369d4c4b8 100644
--- a/tests/py/dynamo/conversion/test_sort_aten.py
+++ b/tests/py/dynamo/conversion/test_sort_aten.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input

 from .harness import DispatchTestCase

@@ -32,5 +33,54 @@ def forward(self, x):
         )


+class TestSortConverterDynamic(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (
+                "3d_dynamic_descending",
+                (2, 2, 1),
+                (2, 2, 1),
+                (3, 2, 4),
+                0,
+                True,
+            ),
+            (
+                "4d_dynamic_ascending",
+                (2, 2, 1, 1),
+                (2, 2, 1, 2),
+                (3, 3, 2, 4),
+                3,
+                False,
+            ),
+            (
+                "4d_dynamic_descending_neg_dim",
+                (2, 2, 1, 1),
+                (2, 2, 1, 2),
+                (3, 3, 2, 4),
+                -3,
+                True,
+            ),
+        ]
+    )
+    def test_sort_dynamic(self, _, min_shape, opt_shape, max_shape, dim, descending):
+        class Sort(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.sort.default(x, dim, descending)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=torch.float,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Sort(),
+            input_specs,
+            output_dtypes=[torch.float, torch.int64],
+        )
+
+
 if __name__ == "__main__":
     run_tests()
diff --git a/tests/py/dynamo/conversion/test_trunc_aten.py b/tests/py/dynamo/conversion/test_trunc_aten.py
index 979ced17e2..211ddbf9d1 100644
--- a/tests/py/dynamo/conversion/test_trunc_aten.py
+++ b/tests/py/dynamo/conversion/test_trunc_aten.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input

 from .harness import DispatchTestCase

@@ -48,5 +49,48 @@ def forward(self, input):
         )


+class TestTruncConverterDynamic(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (
+                "3d_dynamic_int32",
+                (1, 1, 1),
+                (2, 2, 2),
+                (3, 4, 5),
+                torch.int32,
+                False,
+            ),
+            (
+                "3d_dynamic_float32",
+                (2, 1, 1),
+                (2, 2, 2),
+                (2, 4, 5),
+                torch.float32,
+                True,
+            ),
+        ]
+    )
+    def test_trunc_dynamic(
+        self, _, min_shape, opt_shape, max_shape, type, enable_passes
+    ):
+        class Trunc(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.trunc.default(input)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Trunc(),
+            input_specs,
+            enable_passes=enable_passes,
+        )
+
+
 if __name__ == "__main__":
     run_tests()

From e1491b0e9e9cb3b2254358f28c0356323e921008 Mon Sep 17 00:00:00 2001
From: kee hyun an
Date: Mon, 29 Jul 2024 08:54:43 +0900
Subject: [PATCH 2/2] chore: Cannot support dynamic k in topk

---
 .../dynamo/conversion/aten_ops_converters.py | 16 ++++----
 .../dynamo/conversion/impl/topk.py           | 38 +++----------------
 py/torch_tensorrt/dynamo/utils.py            |  3 +-
 tests/py/dynamo/conversion/test_sort_aten.py | 15 ++++----
 4 files changed, 22 insertions(+), 50 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index a135f5c688..78a0e5b43b 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -19,7 +19,6 @@
     get_positive_dim,
     is_only_operator_on_placeholder,
 )
-from torch_tensorrt.dynamo.utils import TRT_TOPK_MAX_ELEMENT
 from torch_tensorrt.fx.types import TRTTensor

 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -2649,21 +2648,22 @@ def topk_validator(node: Node) -> bool:


 def sort_validator(node: Node) -> bool:
-    # if meta data is not available(e.g. dynamic shape), validate k value during runtime.
-    if not node.args[0].meta:
-        return True
-
-    shape = node.args[0].meta.get("tensor_meta").shape
+    meta_data = node.args[0].meta.get("tensor_meta")
+    if meta_data is None:
+        return False
+    shape = meta_data.shape
     dim = node.args[1]
     dim = get_positive_dim(dim, len(shape))
     k = shape[dim]
+    if not isinstance(k, int):
+        return False
     return topk_sort_validator(k)


 def topk_sort_validator(k: int) -> bool:
-    if k > TRT_TOPK_MAX_ELEMENT:
+    if k > 3840:
         _LOGGER.debug(
-            f"Currently only topk values up to {TRT_TOPK_MAX_ELEMENT} are supported, got k={k}."
+            f"Currently only topk values up to 3840 are supported, got k={k}."
         )
         return False
     return True
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/topk.py b/py/torch_tensorrt/dynamo/conversion/impl/topk.py
index a779682419..007f248af1 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/topk.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/topk.py
@@ -12,10 +12,8 @@
     get_positive_dim,
     set_layer_name,
 )
-from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import le
-from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
 from torch_tensorrt.dynamo.types import TRTTensor
-from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, TRT_TOPK_MAX_ELEMENT
+from torch_tensorrt.dynamo.utils import DYNAMIC_DIM


 def argmax_argmin(
@@ -158,40 +156,14 @@ def topk(
         k,
         get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape))),
     )
-    if k == DYNAMIC_DIM:
-        output_shape = get_shape_with_dynamic_shape(
-            ctx, target, source_ir, name, input.shape, input
-        )
-        layer = ctx.net.add_slice(
-            output_shape,
-            start=[dim],
-            shape=[1],
-            stride=[1],
-        )
-        set_layer_name(layer, target, name)
-
-        # Get scalar tensor from 1d tensor
-        shuffle_layer = ctx.net.add_shuffle(layer.get_output(0))
-        shuffle_layer.reshape_dims = trt.Dims()
-        set_layer_name(shuffle_layer, target, name, source_ir)
+    # The topk layer supports a dynamic k value, but we cannot determine the supported topk value at
+    # compile time.
+    assert k != DYNAMIC_DIM, "k value cannot be dynamic!"
 
-        cond = le(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_k_cond",
-            shuffle_layer.get_output(0),
-            TRT_TOPK_MAX_ELEMENT,
-        )
-        ctx.net.add_assertion(
-            cond,
-            message=f"Currently only topk values up to {TRT_TOPK_MAX_ELEMENT} are supported",
-        )
 
-        topk_layer.set_input(1, shuffle_layer.get_output(0))
     # TensorRT ITopKLayer does not have a sorted flag, it is always returning the sorted topk elements
     # so here no matter sorted is True or False the returned the topk Tensor object is always sorted
-    set_layer_name(topk_layer, target, name, source_ir)
+    set_layer_name(topk_layer, target, f"{name}_topk", source_ir)

     if return_indices:
         return topk_layer.get_output(0), topk_layer.get_output(1)
diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index b556dd0756..acfb2b0094 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -6,7 +6,6 @@
 from typing import Any, Callable, Dict, Optional, Sequence, Union

 import numpy as np
-import tensorrt as trt
 import torch
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import dtype
@@ -14,6 +13,7 @@
 from torch_tensorrt.dynamo import _defaults
 from torch_tensorrt.dynamo._settings import CompilationSettings

+import tensorrt as trt
 from packaging import version

 from .types import TRTDataType
@@ -22,7 +22,6 @@

 COSINE_THRESHOLD = 0.99
 DYNAMIC_DIM = -1
-TRT_TOPK_MAX_ELEMENT = 3840


 class Frameworks(Enum):
diff --git a/tests/py/dynamo/conversion/test_sort_aten.py b/tests/py/dynamo/conversion/test_sort_aten.py
index 3369d4c4b8..5f1258c6ac 100644
--- a/tests/py/dynamo/conversion/test_sort_aten.py
+++ b/tests/py/dynamo/conversion/test_sort_aten.py
@@ -38,24 +38,24 @@ class TestSortConverterDynamic(DispatchTestCase):
         [
             (
                 "3d_dynamic_descending",
-                (2, 2, 1),
-                (2, 2, 1),
+                (2, 1, 4),
                 (3, 2, 4),
-                0,
+                (3, 3, 4),
+                2,
                 True,
             ),
             (
                 "4d_dynamic_ascending",
-                (2, 2, 1, 1),
-                (2, 2, 1, 2),
+                (2, 2, 1, 4),
+                (2, 2, 2, 4),
                 (3, 3, 2, 4),
                 3,
                 False,
             ),
             (
                 "4d_dynamic_descending_neg_dim",
-                (2, 2, 1, 1),
-                (2, 2, 1, 2),
+                (1, 3, 1, 1),
+                (2, 3, 2, 2),
                 (3, 3, 2, 4),
                 -3,
                 True,
@@ -79,6 +79,7 @@ def forward(self, x):
             Sort(),
             input_specs,
             output_dtypes=[torch.float, torch.int64],
+            use_dynamo_tracer=True,
         )
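
A rough usage sketch for reviewers (not part of the patch; the module and shape names are illustrative). It mirrors the "3d_dynamic_descending" test case above and shows how a dynamically shaped input reaches the sort converter through the dynamo frontend. The sorted dimension (dim=2) is kept static at size 4, so sort_validator can still read k at compile time; only the leading dimensions are dynamic, and a dynamic k would instead make the op fall back to PyTorch.

    # Sketch only: assumes a CUDA device and a recent torch_tensorrt build with the dynamo frontend.
    import torch
    import torch_tensorrt


    class SortModel(torch.nn.Module):
        def forward(self, x):
            # Traces to torch.ops.aten.sort.default; dim 2 is static (size 4), so k is known.
            return torch.sort(x, dim=2, descending=True)


    model = SortModel().eval().cuda()
    input_spec = torch_tensorrt.Input(
        min_shape=(2, 1, 4),
        opt_shape=(3, 2, 4),
        max_shape=(3, 3, 4),
        dtype=torch.float,
    )
    trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=[input_spec])
    values, indices = trt_model(torch.randn(3, 2, 4, device="cuda"))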