feat: support flatten and reshape via shuffle_layer
zewenli98 committed Sep 29, 2023
1 parent 0d402fb commit b414928
Showing 4 changed files with 91 additions and 30 deletions.
38 changes: 38 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1413,3 +1413,41 @@ def aten_ops_linear(
weight=args[1],
bias=args_bounds_check(args, 2, None),
)


@dynamo_tensorrt_converter(torch.ops.aten.reshape.default) # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.view.default) # type: ignore[misc]
def aten_ops_reshape(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.shuffle.reshape(
network,
target,
SourceIR.ATEN,
name,
input=args[0],
shape=args[1],
)


@dynamo_tensorrt_converter(torch.ops.aten.flatten.using_ints) # type: ignore[misc]
def aten_ops_flatten(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.shuffle.flatten(
network,
target,
SourceIR.ATEN,
name,
input=args[0],
start_dim=args_bounds_check(args, 1, 0),
end_dim=args_bounds_check(args, 2, -1),
)
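
For context, a minimal usage sketch (not part of this commit): a module whose `aten.view` / `aten.reshape` and `aten.flatten.using_ints` calls would be routed through the converters registered above when compiled via the Torch-TensorRT Dynamo path. The `ir="dynamo"` argument and the exact compile entry point are assumptions about the API at this commit.

```python
import torch
import torch_tensorrt

class ReshapeFlatten(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Lowered to aten.view.default / aten.reshape.default -> aten_ops_reshape
        y = x.reshape(-1, 6)
        # Lowered to aten.flatten.using_ints -> aten_ops_flatten
        return torch.flatten(y, start_dim=0, end_dim=-1)

model = ReshapeFlatten().eval().cuda()
inputs = [torch.randn(2, 3, 4).cuda()]
# ir="dynamo" (assumed) selects the Dynamo front end that dispatches these converters.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs).shape)  # torch.Size([24])
```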
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -15,6 +15,7 @@
reduce,
select,
shape,
shuffle,
slice,
split,
squeeze,
52 changes: 52 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
@@ -0,0 +1,52 @@
from typing import List, Optional, Union

from torch.fx.node import Target
from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
from torch_tensorrt.fx.converters.converter_utils import (
get_positive_dim,
set_layer_name,
)
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor


def reshape(
network: TRTNetwork,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
shape: List[int],
) -> TRTTensor:
layer = network.add_shuffle(input)
layer.reshape_dims = tuple(shape)
set_layer_name(layer, target, f"{name}_reshape", source_ir)
return layer.get_output(0)


def flatten(
network: TRTNetwork,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
start_dim: int,
end_dim: int,
) -> TRTTensor:
shape = input.shape
dim_size = len(shape)
start_dim = get_positive_dim(start_dim, dim_size)
end_dim = get_positive_dim(end_dim, dim_size)

num_elements = 1
for i in range(start_dim, end_dim + 1):
num_elements *= shape[i]

new_shape = (
tuple(shape[:start_dim])
+ (num_elements,)
+ (tuple(shape[end_dim + 1 :]) if end_dim + 1 < dim_size else tuple())
)
layer = network.add_shuffle(input)
layer.reshape_dims = new_shape
set_layer_name(layer, target, f"{name}_flatten", source_ir)
return layer.get_output(0)
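
A worked example of the shape arithmetic in `flatten` above (plain Python plus a PyTorch cross-check; not part of the commit): flattening dims 1..2 of a `(2, 3, 4, 5)` tensor collapses 3 * 4 = 12 elements into a single dimension.

```python
import torch

shape = (2, 3, 4, 5)
start_dim, end_dim = 1, 2  # already non-negative, so get_positive_dim is a no-op here

# Product of the extents being collapsed: 3 * 4 = 12.
num_elements = 1
for i in range(start_dim, end_dim + 1):
    num_elements *= shape[i]

new_shape = shape[:start_dim] + (num_elements,) + shape[end_dim + 1:]
assert new_shape == (2, 12, 5)
# Cross-check against PyTorch's own flatten semantics.
assert new_shape == tuple(torch.flatten(torch.empty(shape), 1, 2).shape)
```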
30 changes: 0 additions & 30 deletions tests/py/dynamo/conversion/test_reshape_aten.py
@@ -69,36 +69,6 @@ def forward(self, x):
expected_ops={torch.ops.aten.view.default},
)

@unittest.skipIf(
trt.__version__ < "8.5",
"Shape tensor supported well in TensorRT 8.5 and later",
)
def test_reshape_with_dynamic_shape_size(self):
class TestModule(torch.nn.Module):
def forward(self, x, y):
shape_y = y.shape
t = shape_y[1]
return torch.reshape(x, [-1, t, 3])

input_specs = [
Input(
shape=(-1, 5, 6),
dtype=torch.float32,
shape_ranges=[((1, 5, 6), (3, 5, 6), (3, 5, 6))],
),
Input(
shape=(-1, 5),
dtype=torch.float32,
shape_ranges=[((1, 5), (3, 5), (3, 5))],
),
]

self.run_test_with_dynamic_shape(
TestModule(),
input_specs,
expected_ops={torch.ops.aten.view.default},
)


if __name__ == "__main__":
run_tests()
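
The diff above only removes a dynamic-shape reshape test; a companion static-shape test for the new flatten converter might look like the sketch below, written in the style of this suite. The `.harness` import path is an assumption, though the `run_test(..., expected_ops=...)` signature mirrors the `run_test_with_dynamic_shape` call in the removed test.

```python
import torch
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase  # assumed harness location


class TestFlattenConverter(DispatchTestCase):
    def test_flatten(self):
        class Flatten(torch.nn.Module):
            def forward(self, x):
                # Should dispatch to aten_ops_flatten via aten.flatten.using_ints
                return torch.flatten(x, 1, 2)

        inputs = [torch.randn(1, 2, 3, 4)]
        self.run_test(
            Flatten(),
            inputs,
            expected_ops={torch.ops.aten.flatten.using_ints},
        )


if __name__ == "__main__":
    run_tests()
```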
