[MLIR][TORCH] Fix aten.unsqueeze op (llvm#1578)
The valid range of the unsqueeze dim is [-input.dim() - 1, input.dim() + 1); the previous bounds check forgot to add 1 to the rank.
AmosLewis authored Nov 14, 2022
1 parent 6909eaf commit dfe7513
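
To see why the bound needs the extra 1: unsqueeze inserts a new axis, so the result has input.dim() + 1 dimensions, and the new axis may sit after the last existing one. A minimal standalone sketch follows; the helper bodies are re-implementations assumed from their use in the patch below, not the torch-mlir sources.

#include <cassert>
#include <cstdint>

// Assumed semantics of toPositiveDim: wrap a negative dim into [0, numDims).
static int64_t toPositiveDim(int64_t dim, int64_t numDims) {
  return dim >= 0 ? dim : dim + numDims;
}

// Assumed semantics of isValidDim: dim must lie in [0, numDims).
static bool isValidDim(int64_t dim, int64_t numDims) {
  return dim >= 0 && dim < numDims;
}

int main() {
  const int64_t selfRank = 2; // e.g. the [4,3] test tensor below
  // The lowest legal dim, -selfRank - 1, wraps to 0 over the new rank.
  assert(toPositiveDim(-selfRank - 1, selfRank + 1) == 0);
  // dim == selfRank appends a trailing unit axis: [4,3] -> [4,3,1].
  assert(!isValidDim(selfRank, selfRank));    // old check wrongly rejects it
  assert(isValidDim(selfRank, selfRank + 1)); // fixed check accepts it
  return 0;
}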
Showing 3 changed files with 11 additions and 10 deletions.
1 change: 1 addition & 0 deletions e2e_testing/xfail_sets.py
@@ -455,6 +455,7 @@
"ArgmaxModule_with_dim",
"_LogSoftmaxModuleStable_basic",
"ElementwiseAtenWhereSelfModule_basic",
"ElementwiseUnsqueezeBroadcastModule_basic",
"LiftFreshCopyModule_basic",
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
"ReduceSumDimIntListFloatModule_basic",
2 changes: 1 addition & 1 deletion lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -2435,7 +2435,7 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(op, "dim must be a Scalar constant");
 
   dim = toPositiveDim(dim, selfRank);
-  if (!isValidDim(dim, selfRank))
+  if (!isValidDim(dim, selfRank + 1))
     return rewriter.notifyMatchFailure(op, "dim is statically invalid");
 
   SmallVector<int64_t> outShape;
18 changes: 9 additions & 9 deletions test/Conversion/TorchToTosa/basic.mlir
@@ -690,18 +690,18 @@ func.func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> {
 // -----
 
 // CHECK-LABEL: func.func @torch.aten.unsqueeze$basic(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],si32>) -> !torch.vtensor<[4,1,3],si32> {
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],si32>) -> !torch.vtensor<[4,3,1],si32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],si32> -> tensor<4x3xi32>
-// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
-// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [4, 1, 3]} : (tensor<4x3xi32>) -> tensor<4x1x3xi32>
-// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<4x1x3xi32> -> !torch.vtensor<[4,1,3],si32>
-// CHECK: return %[[VAL_4]] : !torch.vtensor<[4,1,3],si32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [4, 3, 1]} : (tensor<4x3xi32>) -> tensor<4x3x1xi32>
+// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<4x3x1xi32> -> !torch.vtensor<[4,3,1],si32>
+// CHECK: return %[[VAL_4]] : !torch.vtensor<[4,3,1],si32>
 // CHECK: }
 
-func.func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[4,3],si32> ) -> !torch.vtensor<[4,1,3],si32> {
-  %int1 = torch.constant.int 1
-  %0 = torch.aten.unsqueeze %arg0, %int1 : !torch.vtensor<[4,3],si32>, !torch.int -> !torch.vtensor<[4,1,3],si32>
-  return %0 : !torch.vtensor<[4,1,3],si32>
+func.func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[4,3],si32> ) -> !torch.vtensor<[4,3,1],si32> {
+  %int2 = torch.constant.int 2
+  %0 = torch.aten.unsqueeze %arg0, %int2 : !torch.vtensor<[4,3],si32>, !torch.int -> !torch.vtensor<[4,3,1],si32>
+  return %0 : !torch.vtensor<[4,3,1],si32>
 }
 
 // -----
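For reference, the shape computation the test above exercises is simply "insert a 1 at index dim". A hypothetical standalone sketch (unsqueezeShape is an illustrative helper, not a torch-mlir function):

#include <cassert>
#include <cstdint>
#include <vector>

// Unsqueeze inserts a unit dimension at `dim`, so a rank-N input yields a
// rank-(N+1) result; `dim` may equal in.size(), appending a trailing axis.
static std::vector<int64_t> unsqueezeShape(const std::vector<int64_t> &in,
                                           int64_t dim) {
  std::vector<int64_t> out(in);
  out.insert(out.begin() + dim, 1);
  return out;
}

int main() {
  std::vector<int64_t> out = unsqueezeShape({4, 3}, 2);
  assert((out == std::vector<int64_t>{4, 3, 1})); // matches the test's reshape
  return 0;
}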
