Skip to content

Commit

Permalink
Implement decomposition from onnx.Sum to sequence of onnx.Add (#2964)
Browse files Browse the repository at this point in the history

Signed-off-by: Sam <srcarroll314@gmail.com>
Co-authored-by: Alexandre Eichenberger <alexe@us.ibm.com>
  • Loading branch information
srcarroll and AlexandreEichenberger authored Oct 3, 2024
1 parent 56a610c commit 265ee60
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 0 deletions.
26 changes: 26 additions & 0 deletions src/Dialect/ONNX/Transforms/Decompose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -985,6 +985,30 @@ struct GroupNormIntoLayerNormPattern2
}
};

/// Decompose `onnx.Sum` (variadic) into a left-to-right chain of binary
/// `onnx.Add` ops: Sum(a, b, c) -> Add(Add(a, b), c).
struct SumToAddPattern : public OpRewritePattern<ONNXSumOp> {
  using OpRewritePattern<ONNXSumOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(
      ONNXSumOp sumOp, PatternRewriter &rewriter) const final {
    // The ONNX Sum op is variadic with a minimum of one operand, so an
    // empty operand list is a verifier-level invariant violation.
    ValueRange inputs = sumOp.getData_0();
    assert(!inputs.empty() && "expected at least one input");
    // Fold operands left-to-right. A single-operand Sum is the identity.
    // Iterating the operand range directly avoids the copy-and-erase of a
    // SmallVector (erase at begin() shifts every remaining element).
    Value result = inputs.front();
    for (Value input : inputs.drop_front())
      result = rewriter.create<ONNXAddOp>(sumOp.getLoc(), result, input);
    // The Adds' inferred type may be more refined (e.g. ranked) than the
    // original Sum result type; insert a Cast so the replacement value's
    // type matches what the op's users expect. saturate = 1, to = the
    // original element type (an identity element-type cast).
    auto resultType = mlir::cast<ShapedType>(sumOp.getResult().getType());
    if (resultType != result.getType())
      result = rewriter.create<ONNXCastOp>(
          sumOp.getLoc(), resultType, result, 1, resultType.getElementType());
    rewriter.replaceOp(sumOp, result);
    return success();
  }
};

// =============================================================================
// Pattern for replacing CastLikeOp by CastOp.
// =============================================================================
Expand Down Expand Up @@ -1093,6 +1117,7 @@ void DecomposeONNXToONNXPass::runOnOperation() {
target.addIllegalOp<ONNXSplitV11Op>();
target.addIllegalOp<ONNXSplitV13Op>();
target.addIllegalOp<ONNXSqueezeV11Op>();
target.addIllegalOp<ONNXSumOp>();
target.addIllegalOp<ONNXUnsqueezeV11Op>();
target.addIllegalOp<ONNXUpsampleOp>();
target.addIllegalOp<ONNXUpsampleV7Op>();
Expand Down Expand Up @@ -1165,6 +1190,7 @@ void onnx_mlir::getDecomposeONNXToONNXPatterns(
patterns.insert<InstanceNormIntoLayerNormPattern>(context);
patterns.insert<GroupNormIntoLayerNormPattern1>(context);
patterns.insert<GroupNormIntoLayerNormPattern2>(context);
patterns.insert<SumToAddPattern>(context);

// TODO: consider whether to include SoftmaxPattern here
}
Expand Down
47 changes: 47 additions & 0 deletions test/mlir/onnx/onnx_decompose.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -698,3 +698,50 @@ func.func @test_castlike(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf16>) -> tensor
// CHECK: onnx.Return [[RES]] : tensor<*xf16>
}

// -----

// Sum with four broadcast-compatible inputs decomposes into three chained Adds.
func.func @test_sum(%arg0: tensor<128x10xf32>, %arg1: tensor<64x128x10xf32>, %arg2: tensor<10xf32>, %arg3: tensor<64x1x1xf32>) -> tensor<64x128x10xf32> {
%0 = "onnx.Sum"(%arg0, %arg1, %arg2, %arg3) : (tensor<128x10xf32>, tensor<64x128x10xf32>, tensor<10xf32>, tensor<64x1x1xf32>) -> tensor<64x128x10xf32>
onnx.Return %0 : tensor<64x128x10xf32>
// CHECK-LABEL: func @test_sum
// CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}})
// CHECK-NEXT: %[[SUM0:.*]] = "onnx.Add"(%[[ARG0]], %[[ARG1]])
// CHECK-NEXT: %[[SUM1:.*]] = "onnx.Add"(%[[SUM0]], %[[ARG2]])
// CHECK-NEXT: %[[SUM2:.*]] = "onnx.Add"(%[[SUM1]], %[[ARG3]])
// CHECK-NEXT: onnx.Return %[[SUM2]]
}

// -----

// Unranked Sum result: the Adds infer a ranked type, so the decomposition
// must insert an identity Cast back to the original unranked result type.
func.func @test_sum_to_unranked(%arg0: tensor<128x10xf32>, %arg1: tensor<64x128x10xf32>, %arg2: tensor<10xf32>, %arg3: tensor<64x1x1xf32>) -> tensor<*xf32> {
%0 = "onnx.Sum"(%arg0, %arg1, %arg2, %arg3) : (tensor<128x10xf32>, tensor<64x128x10xf32>, tensor<10xf32>, tensor<64x1x1xf32>) -> tensor<*xf32>
onnx.Return %0 : tensor<*xf32>
// Label must name this function exactly; the bare `@test_sum` prefix copied
// from the previous test also matches this function's name.
// CHECK-LABEL: func @test_sum_to_unranked
// CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}})
// CHECK-NEXT: %[[SUM0:.*]] = "onnx.Add"(%[[ARG0]], %[[ARG1]])
// CHECK-NEXT: %[[SUM1:.*]] = "onnx.Add"(%[[SUM0]], %[[ARG2]])
// CHECK-NEXT: %[[SUM2:.*]] = "onnx.Add"(%[[SUM1]], %[[ARG3]])
// CHECK-NEXT: %[[CAST:.*]] = "onnx.Cast"(%[[SUM2]]) {saturate = 1 : si64, to = f32} : (tensor<64x128x10xf32>) -> tensor<*xf32>
// CHECK-NEXT: onnx.Return %[[CAST]]
}

// -----

// Single-input Sum is the identity: it folds away to its operand, no Add.
func.func @test_sum_single_input(%arg0: tensor<64x128x10xf32>) -> tensor<64x128x10xf32> {
%0 = "onnx.Sum"(%arg0) : (tensor<64x128x10xf32>) -> tensor<64x128x10xf32>
onnx.Return %0 : tensor<64x128x10xf32>
// CHECK-LABEL: func @test_sum_single_input
// CHECK-SAME: (%[[ARG0:.*]]: {{.*}})
// CHECK-NEXT: onnx.Return %[[ARG0]]
}

// -----

// Single-input Sum with an unranked result: identity, plus a Cast to restore
// the original unranked result type (operand is ranked).
func.func @test_sum_single_input_to_unranked(%arg0: tensor<64x128x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Sum"(%arg0) : (tensor<64x128x10xf32>) -> tensor<*xf32>
onnx.Return %0 : tensor<*xf32>
// CHECK-LABEL: func @test_sum_single_input_to_unranked
// CHECK-SAME: (%[[ARG0:.*]]: {{.*}})
// CHECK-NEXT: %[[CAST:.*]] = "onnx.Cast"(%[[ARG0]]) {saturate = 1 : si64, to = f32} : (tensor<64x128x10xf32>) -> tensor<*xf32>
// CHECK-NEXT: onnx.Return %[[CAST]]
}

0 comments on commit 265ee60

Please sign in to comment.