diff --git a/src/Dialect/ONNX/Transforms/Decompose.cpp b/src/Dialect/ONNX/Transforms/Decompose.cpp
index 91a4f137f5..0714523fb5 100644
--- a/src/Dialect/ONNX/Transforms/Decompose.cpp
+++ b/src/Dialect/ONNX/Transforms/Decompose.cpp
@@ -985,6 +985,30 @@ struct GroupNormIntoLayerNormPattern2
   }
 };
 
+/// Decompose `onnx.Sum` to a sequence of `onnx.Add`
+struct SumToAddPattern : public OpRewritePattern<ONNXSumOp> {
+  using OpRewritePattern<ONNXSumOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(
+      ONNXSumOp sumOp, PatternRewriter &rewriter) const final {
+    SmallVector<Value> inputs(sumOp.getData_0());
+    assert(inputs.size() > 0 && "expected at least one input");
+    Value result = inputs[0];
+    if (inputs.size() > 1) {
+      inputs.erase(inputs.begin());
+      for (auto input : inputs) {
+        result = rewriter.create<ONNXAddOp>(sumOp.getLoc(), result, input);
+      }
+    }
+    auto resultType = mlir::cast<TensorType>(sumOp.getResult().getType());
+    if (resultType != result.getType())
+      result = rewriter.create<ONNXCastOp>(
+          sumOp.getLoc(), resultType, result, 1, resultType.getElementType());
+    rewriter.replaceOp(sumOp, result);
+    return success();
+  }
+};
+
 // =============================================================================
 // Pattern for replacing CastLikeOp by CastOp.
 // =============================================================================
@@ -1093,6 +1117,7 @@ void DecomposeONNXToONNXPass::runOnOperation() {
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<ONNXSumOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
@@ -1165,6 +1190,7 @@ void onnx_mlir::getDecomposeONNXToONNXPatterns(
   patterns.insert(context);
   patterns.insert(context);
   patterns.insert(context);
+  patterns.insert<SumToAddPattern>(context);
 
   // TODO: consider whether to include SoftmaxPattern here
 }
diff --git a/test/mlir/onnx/onnx_decompose.mlir b/test/mlir/onnx/onnx_decompose.mlir
index f4de9145b7..2fe2f9e374 100644
--- a/test/mlir/onnx/onnx_decompose.mlir
+++ b/test/mlir/onnx/onnx_decompose.mlir
@@ -698,3 +698,50 @@ func.func @test_castlike(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf16>) -> tensor
   // CHECK: onnx.Return [[RES]] : tensor<*xf16>
 }
 
+// -----
+
+func.func @test_sum(%arg0: tensor<128x10xf32>, %arg1: tensor<64x128x10xf32>, %arg2: tensor<10xf32>, %arg3: tensor<64x1x1xf32>) -> tensor<64x128x10xf32> {
+  %0 = "onnx.Sum"(%arg0, %arg1, %arg2, %arg3) : (tensor<128x10xf32>, tensor<64x128x10xf32>, tensor<10xf32>, tensor<64x1x1xf32>) -> tensor<64x128x10xf32>
+  onnx.Return %0 : tensor<64x128x10xf32>
+  // CHECK-LABEL: func @test_sum
+  // CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}})
+  // CHECK-NEXT: %[[SUM0:.*]] = "onnx.Add"(%[[ARG0]], %[[ARG1]])
+  // CHECK-NEXT: %[[SUM1:.*]] = "onnx.Add"(%[[SUM0]], %[[ARG2]])
+  // CHECK-NEXT: %[[SUM2:.*]] = "onnx.Add"(%[[SUM1]], %[[ARG3]])
+  // CHECK-NEXT: onnx.Return %[[SUM2]]
+}
+
+// -----
+
+func.func @test_sum_to_unranked(%arg0: tensor<128x10xf32>, %arg1: tensor<64x128x10xf32>, %arg2: tensor<10xf32>, %arg3: tensor<64x1x1xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Sum"(%arg0, %arg1, %arg2, %arg3) : (tensor<128x10xf32>, tensor<64x128x10xf32>, tensor<10xf32>, tensor<64x1x1xf32>) -> tensor<*xf32>
+  onnx.Return %0 : tensor<*xf32>
+  // CHECK-LABEL: func @test_sum_to_unranked
+  // CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}})
+  // CHECK-NEXT: %[[SUM0:.*]] = "onnx.Add"(%[[ARG0]], %[[ARG1]])
+  // CHECK-NEXT: %[[SUM1:.*]] = "onnx.Add"(%[[SUM0]], %[[ARG2]])
+  // CHECK-NEXT: %[[SUM2:.*]] = "onnx.Add"(%[[SUM1]], %[[ARG3]])
+  // CHECK-NEXT: %[[CAST:.*]] = "onnx.Cast"(%[[SUM2]]) {saturate = 1 : si64, to = f32} : (tensor<64x128x10xf32>) -> tensor<*xf32>
+  // CHECK-NEXT: onnx.Return %[[CAST]]
+}
+
+// -----
+
+func.func @test_sum_single_input(%arg0: tensor<64x128x10xf32>) -> tensor<64x128x10xf32> {
+  %0 = "onnx.Sum"(%arg0) : (tensor<64x128x10xf32>) -> tensor<64x128x10xf32>
+  onnx.Return %0 : tensor<64x128x10xf32>
+  // CHECK-LABEL: func @test_sum_single_input
+  // CHECK-SAME: (%[[ARG0:.*]]: {{.*}})
+  // CHECK-NEXT: onnx.Return %[[ARG0]]
+}
+
+// -----
+
+func.func @test_sum_single_input_to_unranked(%arg0: tensor<64x128x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Sum"(%arg0) : (tensor<64x128x10xf32>) -> tensor<*xf32>
+  onnx.Return %0 : tensor<*xf32>
+  // CHECK-LABEL: func @test_sum_single_input_to_unranked
+  // CHECK-SAME: (%[[ARG0:.*]]: {{.*}})
+  // CHECK-NEXT: %[[CAST:.*]] = "onnx.Cast"(%[[ARG0]]) {saturate = 1 : si64, to = f32} : (tensor<64x128x10xf32>) -> tensor<*xf32>
+  // CHECK-NEXT: onnx.Return %[[CAST]]
+}