Skip to content

Commit

Permalink
dynamic shape for slice converter (pytorch#2901)
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose authored Jul 11, 2024
1 parent c0a2bea commit 0ef880d
Show file tree
Hide file tree
Showing 2 changed files with 315 additions and 15 deletions.
170 changes: 161 additions & 9 deletions py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
import sys
from typing import Optional, Sequence

import numpy as np
Expand All @@ -14,6 +15,11 @@
get_trt_tensor,
)
from torch_tensorrt.dynamo.conversion.impl.cat import cat
from torch_tensorrt.dynamo.conversion.impl.elementwise import floor_divide
from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import (
convert_binary_elementwise,
)
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
Expand All @@ -36,29 +42,175 @@ def slice_op( # TODO: This should be slice not whatever is in base
stop: Optional[int],
step: int,
) -> TRTTensor:
# check if dim is same as dynamic shape dimension
# this is required when stop is ITensor
dynamic_input_dim_equal = False
for i in range(len(input.shape)):
if input.shape[i] == DYNAMIC_DIM and i == dim:
dynamic_input_dim_equal = True

# Special case for start being None
if start is None:
start = 0

# Special case for stop being None
stop_dynamic_None = False
if stop is None:
stop = input.shape[dim]
stop_dynamic_None = True if input.shape[dim] == -1 else False
stop = 0 if input.shape[dim] == -1 else input.shape[dim]

dim = get_positive_dim(dim, len(input.shape))
start = get_positive_dim(start, input.shape[dim])
stop = get_positive_dim(stop, input.shape[dim])

if has_dynamic_shape(input.shape):
# Check whether slice target dim is dynamic shape dim
assert input.shape[dim] != -1, "Can't slice on dynamic shape dimension!"
# Assign the initial start tensor
start_slice = []
# the add_slice will take care of dynamic input shape cases here
if isinstance(start, int):
start_slice = [0] * len(input.shape)
start_slice[dim] = start
else:
for i in range(len(input.shape)):
start_slice.append(0) if i != dim else start_slice.append(start)

# Assign the initial stop tensor
stop_slice = []
if isinstance(stop, int) and dynamic_input_dim_equal:
stop_slice = input.shape
stop_slice[dim] = stop
else:
# required for cases where stop is ITensor and dim != dynamic dim of input
# not required for cases where stop is negative and dim != dynamic dim of inpu
for i in range(len(input.shape)):
if input.shape[i] == DYNAMIC_DIM and i != dim:
stop_slice.append(
get_shape(
ctx, target, source_ir, name + f"_shape_dim_stop_{i}", input, i
)
)
elif i == dim:
stop_slice.append(stop)
else:
stop_slice.append(input.shape[i])

start_slice = [0] * len(input.shape)
start_slice[dim] = start
stride_slice = [1] * len(input.shape)
stride_slice[dim] = step
output_shape = list(input.shape)
output_shape[dim] = math.ceil((stop - start) / step)

if input.shape[dim] != -1 and isinstance(start, int) and isinstance(stop, int):
start = get_positive_dim(start, input.shape[dim])
stop = get_positive_dim(stop, input.shape[dim])
start_slice[dim] = start
else:
# the start and stop or None is dynamic along dim or or start or stop is an ITensor
if (
not (isinstance(start, int))
or not (isinstance(stop, int))
or start < 0
or stop < 0
or stop_dynamic_None
or stop == sys.maxsize
):
# special assignments for dynamic cases
if isinstance(start, int) and start < 0:
start_slice = input.shape
start_slice[dim] = -1 * start
if (isinstance(stop, int) and stop < 0) or stop_dynamic_None:
stop_slice = [0] * len(input.shape)
stop_slice[dim] = -1 * stop
if stop == sys.maxsize:
stop_slice = [0] * len(input.shape)
start_slice_tensor = cat(
ctx,
target,
source_ir,
name + "_start_slice_concat",
tuple(start_slice),
0,
cast_dtype=trt.int32,
)
stop_slice_tensor = cat(
ctx,
target,
source_ir,
name + "_stop_slice_concat",
tuple(stop_slice),
0,
cast_dtype=trt.int32,
)
stride_slice_tensor = cat(
ctx,
target,
source_ir,
name + "_stride_slice_concat",
tuple(stride_slice),
0,
cast_dtype=trt.int32,
)

if isinstance(start, int) and start < 0:
shape = get_shape_with_dynamic_shape(
ctx, target, source_ir, name, output_shape, input
)
start_slice_tensor = convert_binary_elementwise(
ctx,
target,
source_ir,
name + "_sub_start",
trt.ElementWiseOperation.SUB,
shape,
start_slice_tensor,
)
if isinstance(stop, int) and (
(stop < 0) or stop_dynamic_None or stop == sys.maxsize
):
shape = get_shape_with_dynamic_shape(
ctx, target, source_ir, name, output_shape, input
)
stop_slice_tensor = convert_binary_elementwise(
ctx,
target,
source_ir,
name + "_sub_stop",
trt.ElementWiseOperation.SUB,
shape,
stop_slice_tensor,
)

# this is required for the ceil operation
output_shape_tensor_num = convert_binary_elementwise(
ctx,
target,
source_ir,
name + "_sub_num",
trt.ElementWiseOperation.SUB,
start_slice_tensor,
stop_slice_tensor,
)
output_shape_tensor_neg = floor_divide(
ctx,
target,
source_ir,
name + "_div",
output_shape_tensor_num,
stride_slice_tensor,
)
output_shape_tensor = convert_binary_elementwise(
ctx,
target,
source_ir,
name + "_prod",
trt.ElementWiseOperation.PROD,
output_shape_tensor_neg,
-1,
)
layer = ctx.net.add_slice(
input, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
)
layer.set_input(1, start_slice_tensor)
layer.set_input(2, output_shape_tensor)
layer.set_input(3, stride_slice_tensor)
return layer.get_output(0)

output_shape[dim] = math.ceil((stop - start) / step)
return slice(
ctx, target, source_ir, name, input, start_slice, output_shape, stride_slice
)
Expand Down
160 changes: 154 additions & 6 deletions tests/py/dynamo/conversion/test_slice_aten.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from torch_tensorrt import Input

from .harness import DispatchTestCase
Expand Down Expand Up @@ -53,11 +52,159 @@ def forward(self, input):
class TestSliceConverterDynamicShape(DispatchTestCase):
@parameterized.expand(
[
("slice_dim_start_stop_step", 1, 0, 7, 2),
("slice_dim_start_stop_step", 1, 0, 10, 2),
(
"slice_dynamic_dim_start_stop_step_offset",
(1, 10, 1),
(1, 10, 10),
(1, 10, 10),
1,
0,
7,
2,
),
(
"slice_dynamic_dim_start_stop_step",
(1, 10, 1),
(1, 10, 10),
(1, 10, 10),
1,
0,
10,
2,
),
(
"slice_dynamic_dim_start_stop_step_negatives",
(1, 10, 10),
(10, 10, 10),
(10, 10, 10),
-2,
-2,
-1,
1,
),
(
"slice_dim_start_stop_step_max_int",
(1, 10, 10),
(10, 10, 10),
(10, 10, 10),
2,
0,
2**63 - 1,
1,
),
(
"slice_dim_start_stop_step_past_end",
(1, 10, 10),
(10, 10, 10),
(10, 10, 10),
2,
0,
2048,
1,
),
(
"slice_dim_start_stop_step_none",
(1, 10, 10),
(10, 10, 10),
(10, 10, 10),
2,
None,
None,
1,
),
(
"slice_dynamic_dim_start_stop_step_offset_4D",
(1, 10, 1, 3),
(1, 10, 10, 3),
(1, 10, 10, 3),
1,
0,
7,
2,
),
(
"slice_dynamic_dim_start_stop_step_4D",
(1, 10, 1, 3),
(1, 10, 10, 3),
(1, 10, 10, 3),
1,
0,
10,
2,
),
(
"slice_dynamic_dim_dyn_start_dyn_stop_step",
(1, 10, 1),
(1, 10, 10),
(1, 10, 10),
2,
-2,
10,
2,
),
(
"slice_dynamic_dim_dyn_start_stop_dyn_step",
(1, 10, 1),
(1, 10, 10),
(1, 10, 10),
2,
0,
-2,
2,
),
(
"slice_dynamic_dim_dyn_start_stop_None_step",
(1, 10, 1),
(1, 10, 10),
(1, 10, 10),
2,
0,
None,
2,
),
(
"slice_dynamic_dim_dyn_start_dyn_stop_dyn_step",
(1, 10, 1),
(1, 10, 10),
(1, 10, 10),
2,
-8,
-2,
2,
),
(
"slice_dynamic_dim_dyn_start_dyn_stop_dyn_step_ceil",
(1, 10, 1),
(1, 10, 10),
(1, 10, 10),
2,
-9,
-2,
2,
),
(
"slice_dynamic_dim_dyn_start_dyn_stop_dyn_step_diff_dim",
(1, 10, 1),
(1, 10, 10),
(1, 10, 10),
0,
-8,
-2,
2,
),
(
"slice_dynamic_dim_dyn_start_dyn_stop_dyn_step_diff_dim_ceil",
(1, 10, 1),
(1, 10, 10),
(1, 10, 10),
0,
-9,
-2,
2,
),
]
)
def test_slice(self, _, dim, start, stop, step):
def test_slice(self, _, min_shape, opt_shape, max_shape, dim, start, stop, step):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand All @@ -68,9 +215,10 @@ def forward(self, input):

input_specs = [
Input(
shape=(1, 10, -1),
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=torch.float32,
shape_ranges=[((1, 10, 1), (1, 10, 10), (1, 10, 10))],
),
]
self.run_test_with_dynamic_shape(
Expand Down

0 comments on commit 0ef880d

Please sign in to comment.