Skip to content

Commit

Permalink
fix dim bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
zewenli98 committed Oct 5, 2023
1 parent 016f266 commit d4f1451
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 4 deletions.
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ def aten_ops_amax(
SourceIR.ATEN,
name,
args[0],
args[1],
args_bounds_check(args, 1, replacement=[]),
args_bounds_check(args, 2, replacement=False),
)

Expand Down
7 changes: 5 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/impl/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def amax(
):
input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)

if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
dim = tuple(range(len(input_val.shape)))

layer = ctx.net.add_reduce(
input_val,
trt.ReduceOperation.MAX,
Expand All @@ -51,7 +54,7 @@ def sum(
):
input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)

if dim is None:
if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
dim = tuple(range(len(input_val.shape)))

layer = ctx.net.add_reduce(
Expand Down Expand Up @@ -169,7 +172,7 @@ def mean(
):
input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)

if dim is None:
if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
dim = tuple(range(len(input_val.shape)))

layer = ctx.net.add_reduce(
Expand Down
2 changes: 2 additions & 0 deletions tests/py/dynamo/conversion/test_amax_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def forward(self, x):

@parameterized.expand(
[
((1, 2, 4), [], True),
((3, 2, 4), [1], True),
((2, 1, 4, 5), [0, 3], True),
((2, 3, 4, 5), [0, 1, 2, 3], False),
Expand Down Expand Up @@ -72,6 +73,7 @@ def forward(self, x):

@parameterized.expand(
[
((1, 2, 4), [], True, torch.int, 0, 5),
((3, 2, 4), [1], True, torch.int, 0, 5),
((2, 1, 4, 5), [0, 3], True, torch.int, -10, 10),
((2, 3, 4, 5), [0, 1, 2, 3], False, torch.int32, -5, 0),
Expand Down
4 changes: 3 additions & 1 deletion tests/py/dynamo/conversion/test_sum_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
class TestSumConverter(DispatchTestCase):
@parameterized.expand(
[
((1, 2),),
((3, 2, 4),),
((2, 3, 4, 5),),
((2, 3, 4, 5),),
((6, 7, 5, 4, 5),),
]
)
Expand Down Expand Up @@ -51,6 +51,7 @@ def forward(self, x):

@parameterized.expand(
[
((1, 2, 4), [], True),
((3, 2, 4), [1], True),
((2, 1, 4, 5), None, True),
((2, 3, 4, 5), [0, 1, 2, 3], False),
Expand Down Expand Up @@ -93,6 +94,7 @@ def forward(self, x):

@parameterized.expand(
[
((1, 2, 4), [], True, torch.int, 0, 5),
((3, 2, 4), [1], True, torch.int, 0, 5),
((2, 1, 4, 5), [0, 3], True, torch.int, -10, 10),
((2, 3, 4, 5), None, False, torch.int32, -5, 0),
Expand Down

0 comments on commit d4f1451

Please sign in to comment.