chore: dynamic shape support for any/sort/trunc ops #3026

Merged · 2 commits · Jul 31, 2024
19 changes: 13 additions & 6 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2648,10 +2648,15 @@ def topk_validator(node: Node) -> bool:


def sort_validator(node: Node) -> bool:
-    shape = node.args[0].meta.get("tensor_meta").shape
+    meta_data = node.args[0].meta.get("tensor_meta")
+    if meta_data is None:
+        return False
+    shape = meta_data.shape
    dim = node.args[1]
    dim = get_positive_dim(dim, len(shape))
    k = shape[dim]
+    if not isinstance(k, int):
Collaborator (Author): If dim k is static and other dims are dynamic, we can support sort ops. This is to validate if k is static or not.
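As an illustration of this case (a minimal sketch with assumed shapes and a made-up module name, not taken from this PR), a sort along a static dimension can stay in TensorRT even when the other dimensions are dynamic:

import torch
import torch_tensorrt

class SortLast(torch.nn.Module):
    def forward(self, x):
        # sort along the last dim, whose size is fixed by the Input spec below
        return torch.ops.aten.sort.default(x, -1, True)

inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 1, 4),
        opt_shape=(2, 2, 4),  # dims 0 and 1 vary, dim -1 stays 4
        max_shape=(4, 3, 4),
        dtype=torch.float,
    )
]

# sort_validator sees k == 4 (a plain int), so the TensorRT converter is used;
# if dim -1 were also dynamic, k would be symbolic and the op would fall back to PyTorch.
trt_mod = torch_tensorrt.compile(SortLast(), ir="dynamo", inputs=inputs)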

Collaborator: Thanks for the comment. Is it possible that k is -1?

Collaborator (Author): Thanks for bringing this to my attention. I checked that only dynamic dims are passed to export() in torch_tensorrt.dynamo.trace():
https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/_tracer.py#L81
So I think a static dim value in meta_data will be a valid dim value (> 0). If you think a static shape dim value from the meta data could be -1, I will update to validate this case as well.

Suggested addition from the discussion (reject a negative static k):

    elif isinstance(k, int) and k < 0:
        return False
+        return False
    return topk_sort_validator(k)
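A quick way to check the claim above (a rough sketch that calls torch.export directly with a made-up module, outside of this PR) is to inspect the placeholder metadata: only dims declared dynamic become symbolic, while static dims keep their real positive size rather than -1:

import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def forward(self, x):
        return torch.sort(x, dim=-1)

# dim 0 is declared dynamic; dim 1 is left static
ep = export(M(), (torch.randn(4, 16),), dynamic_shapes={"x": {0: Dim("batch", min=2, max=8)}})
for node in ep.graph.nodes:
    if node.op == "placeholder":
        print(node.meta["val"].shape)  # torch.Size([s0, 16]): symbolic batch, static 16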


@@ -3103,7 +3108,9 @@ def aten_ops_topk(


@dynamo_tensorrt_converter(
-    torch.ops.aten.sort.default, capability_validator=sort_validator
+    torch.ops.aten.sort.default,
+    capability_validator=sort_validator,
+    supports_dynamic_shapes=True,
)
@enforce_tensor_types(
    {
@@ -3128,7 +3135,7 @@ def aten_ops_sort(
)


-@dynamo_tensorrt_converter(torch.ops.aten.trunc.default)
+@dynamo_tensorrt_converter(torch.ops.aten.trunc.default, supports_dynamic_shapes=True)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
@@ -3204,9 +3211,9 @@ def aten_ops_remainder(
)


-@dynamo_tensorrt_converter(torch.ops.aten.any.default)
-@dynamo_tensorrt_converter(torch.ops.aten.any.dim)
-@dynamo_tensorrt_converter(torch.ops.aten.any.dims)
+@dynamo_tensorrt_converter(torch.ops.aten.any.default, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.any.dim, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.any.dims, supports_dynamic_shapes=True)
def aten_ops_any(
    ctx: ConversionContext,
    target: Target,
12 changes: 9 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/impl/topk.py
@@ -10,9 +10,10 @@
    flatten_dims,
    get_axes_for_reduce_op,
    get_positive_dim,
+    set_layer_name,
)
-from torch_tensorrt.fx.converters.converter_utils import set_layer_name
-from torch_tensorrt.fx.types import TRTTensor
+from torch_tensorrt.dynamo.types import TRTTensor
+from torch_tensorrt.dynamo.utils import DYNAMIC_DIM


def argmax_argmin(
@@ -155,9 +156,14 @@ def topk(
        k,
        get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape))),
    )
+
+    # topk layer supports dynamic k value but we cannot determine the supported dynamic topk value at
+    # compile time.
+    assert k != DYNAMIC_DIM, "k value cannot be dynamic!"
+
    # TensorRT ITopKLayer does not have a sorted flag, it is always returning the sorted topk elements
    # so here no matter sorted is True or False the returned topk Tensor object is always sorted
-    set_layer_name(topk_layer, target, name, source_ir)
+    set_layer_name(topk_layer, target, f"{name}_topk", source_ir)

    if return_indices:
        return topk_layer.get_output(0), topk_layer.get_output(1)
148 changes: 148 additions & 0 deletions tests/py/dynamo/conversion/test_any.py
@@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input

from .harness import DispatchTestCase

@@ -184,5 +185,152 @@ def forward(self, x):
)


class TestAnyConverterDynamic(DispatchTestCase):
    @parameterized.expand(
        [
            (
                "3d_dynamic_float",
                (2, 1, 1),
                (2, 2, 1),
                (3, 2, 4),
                torch.float,
            ),
            (
                "2d_dynamic_int32",
                (2, 2),
                (2, 2),
                (3, 2),
                torch.int32,
            ),
            (
                "4d_dynamic_bool",
                (1, 2, 1, 1),
                (2, 2, 2, 2),
                (2, 2, 4, 3),
                torch.bool,
            ),
        ]
    )
    def test_any_dynamic(self, _, min_shape, opt_shape, max_shape, type):
        class Any(nn.Module):
            def forward(self, x):
                return torch.ops.aten.any.default(x)

        input_specs = [
            Input(
                min_shape=min_shape,
                opt_shape=opt_shape,
                max_shape=max_shape,
                dtype=type,
            ),
        ]
        self.run_test_with_dynamic_shape(
            Any(),
            input_specs,
        )

    @parameterized.expand(
        [
            (
                "3d_dynamic_dim_float",
                (2, 1, 1),
                (2, 2, 1),
                (3, 2, 4),
                torch.float,
                2,
                True,
            ),
            (
                "4d_dynamic_dim_int32",
                (1, 1, 4, 1),
                (2, 2, 4, 2),
                (2, 4, 4, 3),
                torch.int32,
                -2,
                False,
            ),
            (
                "3d_dynamic_dim_bool",
                (2, 1, 1),
                (2, 2, 1),
                (3, 2, 4),
                torch.bool,
                0,
                True,
            ),
        ]
    )
    def test_any_dynamic_dim(
        self, _, min_shape, opt_shape, max_shape, type, dim, keep_dims
    ):
        class AnyDim(nn.Module):
            def forward(self, x):
                return torch.ops.aten.any.dim(x, dim, keep_dims)

        input_specs = [
            Input(
                min_shape=min_shape,
                opt_shape=opt_shape,
                max_shape=max_shape,
                dtype=type,
            ),
        ]
        self.run_test_with_dynamic_shape(
            AnyDim(),
            input_specs,
        )

    @parameterized.expand(
        [
            (
                "3d_dynamic_dims_float",
                (2, 1, 1),
                (2, 2, 1),
                (3, 2, 4),
                torch.float,
                [1, 2],
                True,
            ),
            (
                "4d_dynamic_dims_int32",
                (1, 1, 4, 1),
                (2, 2, 4, 2),
                (2, 4, 4, 3),
                torch.int32,
                [2, -1],
                False,
            ),
            (
                "3d_dynamic_dims_bool",
                (1, 4, 1),
                (2, 4, 2),
                (4, 4, 3),
                torch.bool,
                [0, 1, 2],
                False,
            ),
        ]
    )
    def test_any_dynamic_dims(
        self, _, min_shape, opt_shape, max_shape, type, dims, keep_dims
    ):
        class AnyDims(nn.Module):
            def forward(self, x):
                return torch.ops.aten.any.dims(x, dims, keep_dims)

        input_specs = [
            Input(
                min_shape=min_shape,
                opt_shape=opt_shape,
                max_shape=max_shape,
                dtype=type,
            ),
        ]
        self.run_test_with_dynamic_shape(
            AnyDims(),
            input_specs,
        )


if __name__ == "__main__":
run_tests()
51 changes: 51 additions & 0 deletions tests/py/dynamo/conversion/test_sort_aten.py
@@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input

from .harness import DispatchTestCase

@@ -32,5 +33,55 @@ def forward(self, x):
)


class TestSortConverterDynamic(DispatchTestCase):
    @parameterized.expand(
        [
            (
                "3d_dynamic_descending",
                (2, 1, 4),
                (3, 2, 4),
                (3, 3, 4),
                2,
                True,
            ),
            (
                "4d_dynamic_ascending",
                (2, 2, 1, 4),
                (2, 2, 2, 4),
                (3, 3, 2, 4),
                3,
                False,
            ),
            (
                "4d_dynamic_descending_neg_dim",
                (1, 3, 1, 1),
                (2, 3, 2, 2),
                (3, 3, 2, 4),
                -3,
                True,
            ),
        ]
    )
    def test_sort_dynamic(self, _, min_shape, opt_shape, max_shape, dim, descending):
        class Sort(nn.Module):
            def forward(self, x):
                return torch.ops.aten.sort.default(x, dim, descending)

        input_specs = [
            Input(
                min_shape=min_shape,
                opt_shape=opt_shape,
                max_shape=max_shape,
                dtype=torch.float,
            ),
        ]
        self.run_test_with_dynamic_shape(
            Sort(),
            input_specs,
            output_dtypes=[torch.float, torch.int64],
            use_dynamo_tracer=True,
Collaborator (Author): Tensor meta data is available when use_dynamo_tracer=True.

        )
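To illustrate why the tracer choice matters (a rough sketch with a made-up module, not the harness internals): a plain torch.fx.symbolic_trace leaves tensor_meta unset on the graph nodes, so a capability validator such as sort_validator has nothing to inspect, whereas the dynamo tracer records shape metadata as noted above:

import torch
from torch.fx import symbolic_trace

class M(torch.nn.Module):
    def forward(self, x):
        return torch.sort(x, dim=-1)

gm = symbolic_trace(M())
for n in gm.graph.nodes:
    # prints None for every node: symbolic_trace runs no shape propagation
    print(n.name, n.meta.get("tensor_meta"))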


if __name__ == "__main__":
run_tests()
44 changes: 44 additions & 0 deletions tests/py/dynamo/conversion/test_trunc_aten.py
@@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input

from .harness import DispatchTestCase

@@ -48,5 +49,48 @@ def forward(self, input):
)


class TestTruncConverterDynamic(DispatchTestCase):
    @parameterized.expand(
        [
            (
                "3d_dynamic_int32",
                (1, 1, 1),
                (2, 2, 2),
                (3, 4, 5),
                torch.int32,
                False,
            ),
            (
                "3d_dynamic_float32",
                (2, 1, 1),
                (2, 2, 2),
                (2, 4, 5),
                torch.float32,
                True,
            ),
        ]
    )
    def test_trunc_dynamic(
        self, _, min_shape, opt_shape, max_shape, type, enable_passes
    ):
        class Trunc(nn.Module):
            def forward(self, input):
                return torch.ops.aten.trunc.default(input)

        input_specs = [
            Input(
                min_shape=min_shape,
                opt_shape=opt_shape,
                max_shape=max_shape,
                dtype=type,
            ),
        ]
        self.run_test_with_dynamic_shape(
            Trunc(),
            input_specs,
            enable_passes=enable_passes,
        )


if __name__ == "__main__":
run_tests()