From 7c5875185f3452d97f16ad11af80e50edc0ecf53 Mon Sep 17 00:00:00 2001
From: "Tung D. Le"
Date: Fri, 4 Oct 2024 23:20:55 +0900
Subject: [PATCH] Recompose multiple ops into a single ONNXGelu (#2965)

Recompose multiple ops into a single ONNXGelu (#2965)

Signed-off-by: Tung D. Le

---------

Signed-off-by: Tung D. Le
---
 src/Dialect/ONNX/DialectBuilder.cpp       |   5 +
 src/Dialect/ONNX/DialectBuilder.hpp       |   3 +
 src/Dialect/ONNX/ONNXOps/OpHelper.cpp     |  18 ++
 src/Dialect/ONNX/ONNXOps/OpHelper.hpp     |  43 +++++
 src/Dialect/ONNX/ONNXOps/OpHelper.hpp.inc |  62 +++++++
 src/Dialect/ONNX/Transforms/ConstProp.cpp |  17 --
 src/Dialect/ONNX/Transforms/Recompose.cpp | 217 +++++++++++++++++++++-
 test/mlir/onnx/onnx_recompose.mlir        | 135 ++++++++++++++
 8 files changed, 480 insertions(+), 20 deletions(-)

diff --git a/src/Dialect/ONNX/DialectBuilder.cpp b/src/Dialect/ONNX/DialectBuilder.cpp
index 4faaff6dfb..b9382a06b0 100644
--- a/src/Dialect/ONNX/DialectBuilder.cpp
+++ b/src/Dialect/ONNX/DialectBuilder.cpp
@@ -150,6 +150,11 @@ Value OnnxBuilder::expand(Type outputType, Value input, Value shape) const {
       outputType, toTensor(input), toTensor(shape));
 }
 
+Value OnnxBuilder::gelu(Value input, StringAttr approximateAttr) const {
+  return createOpAndInferShapes<ONNXGeluOp>(
+      toTensor(input.getType()), input, approximateAttr);
+}
+
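+// For example, with a builder created as in Recompose.cpp below:
+//   MultiDialectBuilder<OnnxBuilder> create(rewriter, loc);
+//   Value res = create.onnx.gelu(x, rewriter.getStringAttr("tanh"));
+// this emits "onnx.Gelu"(%x) {approximate = "tanh"} and infers its shape.
+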
 // ONNXLayerNormalizationOp, version with one output only (Y).
 Value OnnxBuilder::layerNorm(Type outputType, Value input, Value scale,
     Value bias, int64_t axis, FloatAttr epsilon) const {
diff --git a/src/Dialect/ONNX/DialectBuilder.hpp b/src/Dialect/ONNX/DialectBuilder.hpp
index 8f6b0931e3..6bc31974f2 100644
--- a/src/Dialect/ONNX/DialectBuilder.hpp
+++ b/src/Dialect/ONNX/DialectBuilder.hpp
@@ -87,6 +87,9 @@ struct OnnxBuilder : DialectBuilder {
   mlir::Value expand(
       mlir::Type outputType, mlir::Value input, mlir::Value shape) const;
 
+  // ONNXGeluOp
+  mlir::Value gelu(mlir::Value input, mlir::StringAttr approximateAttr) const;
+
   // ONNXLayerNormalizationOp, version with one output only (Y).
   mlir::Value layerNorm(mlir::Type outputType, mlir::Value input,
       mlir::Value scale, mlir::Value bias, int64_t axis,
diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp
index 3a6be76ff2..520de56339 100644
--- a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp
+++ b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp
@@ -579,6 +579,24 @@ RESULT_TYPE getScalarValue(ONNXConstantOp constantOp) {
 template double getScalarValue<double>(ONNXConstantOp constantOp);
 template int64_t getScalarValue<int64_t>(ONNXConstantOp constantOp);
 
+/// Widen a double scalar to a WideNum of the given element type.
+WideNum asWideNum(double n, Type elemType) {
+  return wideZeroDispatch(elemType, [n](auto wideZero) {
+    using cpptype = decltype(wideZero);
+    constexpr BType TAG = toBType<cpptype>;
+    return WideNum::widen<TAG>(static_cast<cpptype>(n));
+  });
+}
+
+/// Checks whether a constant tensor's elements are all equal to a given scalar.
+bool isConstOf(Value constValue, double n) {
+  ElementsAttr constElements = getElementAttributeFromONNXValue(constValue);
+  Type elemType = constElements.getElementType();
+  assert(!elemType.isInteger(1) && "booleans are not supported");
+  WideNum w = asWideNum(n, elemType);
+  return ElementsAttrBuilder::allEqual(constElements, w);
+}
+
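+// For example, for a value %half defined by
+//   onnx.Constant dense<5.000000e-01> : tensor<f32>
+// isConstOf(%half, 0.5) returns true and isConstOf(%half, 1.0) returns false;
+// a splat tensor whose elements all equal 0.5 matches as well.
+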
 // Convert type to MLIR type.
 // A complete list of types can be found in:
 // /third_party/onnx/onnx/onnx.pb.h
diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.hpp b/src/Dialect/ONNX/ONNXOps/OpHelper.hpp
index 3d827f85d5..b084ad5cd6 100644
--- a/src/Dialect/ONNX/ONNXOps/OpHelper.hpp
+++ b/src/Dialect/ONNX/ONNXOps/OpHelper.hpp
@@ -244,6 +244,12 @@ RESULT_TYPE getScalarValue(mlir::ElementsAttr denseAttr, mlir::Type type);
 template <typename RESULT_TYPE>
 RESULT_TYPE getScalarValue(mlir::ONNXConstantOp constantOp);
 
+/// Widen a double scalar to a WideNum of the given element type.
+WideNum asWideNum(double n, mlir::Type elemType);
+
+/// Checks whether a constant tensor's elements are all equal to a given scalar.
+bool isConstOf(mlir::Value constValue, double n);
+
 mlir::Type convertONNXTypeToMLIRType(
     mlir::Builder &builder, onnx::TensorProto_DataType onnxType);
 
@@ -277,6 +283,43 @@ bool operandOfOpDefinedBy(mlir::Operation *&matchOp, mlir::Operation *op,
     mlir::Value &matchOperand0, mlir::Value &matchOperand1,
     int64_t matchThisOperandIndex);
 
+// This is to recognize a binary op, e.g. A*B where one of A and B is a
+// constant and the other one is defined by OP.
+// Note: this function can handle the commutative property of the binary op.
+//
+// For example, to recognize this pattern:
+// %x = "onnx.Tanh"()
+// %y = 0.5 * %x // or %x * 0.5
+//
+// we call
+// ```
+// ONNXTanhOp tanhOp;
+// bool found = matchConstAndOp<ONNXTanhOp>(A, B, 0.5, tanhOp);
+// ```
+// where `A` and `B` are operands of the ONNXMul that produces %y.
+template <typename OP>
+bool matchConstAndOp(mlir::Value A, mlir::Value B, double cst, OP &op);
+
+// This is to recognize a binary op, e.g. A*B where one of A and B is the given
+// value and the other one is defined by OP.
+// Note: this function can handle the commutative property of the binary op.
+//
+// For example, to recognize this pattern where %z is one of the inputs of *,
+// and the other input of * is defined by onnx.Tanh:
+// %x = "onnx.Tanh"()
+// %y = %z * %x // or %x * %z
+//
+// we call
+// ```
+// Value z;
+// ONNXTanhOp tanhOp;
+// bool found = matchValueAndOp<ONNXTanhOp>(A, B, z, tanhOp);
+// ```
+// where `A` and `B` are operands of the ONNXMul that produces %y.
+template <typename OP>
+bool matchValueAndOp(
+    mlir::Value A, mlir::Value B, mlir::Value matchValue, OP &matchOp);
+
 /// Check if a value is to store dimensions, meaning it is a tensor of one
 /// element or concatenation of one-element tensors.
 bool areDims(mlir::Value val);
diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.hpp.inc b/src/Dialect/ONNX/ONNXOps/OpHelper.hpp.inc
index e301cb19de..c4e961e755 100644
--- a/src/Dialect/ONNX/ONNXOps/OpHelper.hpp.inc
+++ b/src/Dialect/ONNX/ONNXOps/OpHelper.hpp.inc
@@ -83,3 +83,65 @@ bool operandOfOpDefinedBy(mlir::Operation *&matchOp, mlir::Operation *op,
   }
   return false;
 }
+
+// This is to recognize a binary op, e.g. A*B where one of A and B is a
+// constant and the other one is defined by OP.
+// Note: this function can handle the commutative property of the binary op.
+//
+// For example, to recognize this pattern:
+// %x = "onnx.Tanh"()
+// %y = 0.5 * %x // or %x * 0.5
+//
+// we call
+// ```
+// ONNXTanhOp tanhOp;
+// bool found = matchConstAndOp<ONNXTanhOp>(A, B, 0.5, tanhOp);
+// ```
+// where `A` and `B` are operands of the ONNXMul that produces %y.
+template <typename OP>
+bool matchConstAndOp(mlir::Value A, mlir::Value B, double cst, OP &matchOp) {
+  auto opA = A.getDefiningOp<OP>();
+  auto opB = B.getDefiningOp<OP>();
+  if (onnx_mlir::isDenseONNXConstant(A) && onnx_mlir::isConstOf(A, cst) &&
+      opB) {
+    matchOp = opB;
+    return true;
+  }
+  if (opA && onnx_mlir::isDenseONNXConstant(B) &&
+      onnx_mlir::isConstOf(B, cst)) {
+    matchOp = opA;
+    return true;
+  }
+  return false;
+}
+
+// This is to recognize a binary op, e.g. A*B where one of A and B is the given
+// value and the other one is defined by OP.
+// Note: this function can handle the commutative property of the binary op.
+//
+// For example, to recognize this pattern where %z is one of the inputs of *,
+// and the other input of * is defined by onnx.Tanh:
+// %x = "onnx.Tanh"()
+// %y = %z * %x // or %x * %z
+//
+// we call
+// ```
+// Value z;
+// ONNXTanhOp tanhOp;
+// bool found = matchValueAndOp<ONNXTanhOp>(A, B, z, tanhOp);
+// ```
+// where `A` and `B` are operands of the ONNXMul that produces %y.
+template <typename OP>
+bool matchValueAndOp(
+    mlir::Value A, mlir::Value B, mlir::Value matchValue, OP &matchOp) {
+  auto opA = A.getDefiningOp<OP>();
+  auto opB = B.getDefiningOp<OP>();
+  if ((A == matchValue) && opB) {
+    matchOp = opB;
+    return true;
+  }
+  if (opA && (B == matchValue)) {
+    matchOp = opA;
+    return true;
+  }
+  return false;
+}
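+
+// For example, the Gelu recomposition in Recompose.cpp uses matchValueAndOp
+// to check that one operand of an ONNXAddOp is a given value x while the
+// other operand is defined by an ONNXMulOp:
+//   ONNXMulOp mul2Op;
+//   bool found = matchValueAndOp<ONNXMulOp>(
+//       add2Op.getOperand(0), add2Op.getOperand(1), x, mul2Op);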
"none" : "tanh"); + Value res = create.onnx.gelu(x, approximateAttr); + rewriter.replaceOp(mulOp, res); + return success(); + } + + static bool matchGeluPattern(ONNXMulOp mulOp, Value &x, bool &isExactGelu) { + using namespace onnx_mlir; + // Subgraph to match: + // - for exact gelu + // gelu(x) = 0.5 * x * (1 + erf(x/1.41421354)) + // where 1.41421354 is sqrt(2). + // + // or + // + // - for approximate gelu + // gelu(x) = 0.5 * x * (1 + tanh[0.797884583 * (x + 0.044715 * x^3)]) + // where 0.797884583 is sqrt(2/pi). + // + // Associcative and communitative properties are handled. + + // Helper function. + auto constOf = [](Value v, double n) { + return isDenseONNXConstant(v) && isConstOf(v, n); + }; + + // Match 0.5 * a * b + // Two associative cases depending on which Mul 0.5 belongs to: + // - 0.5 * (a * b) + // - (0.5 * a) * b + // For each case, we have four communitive cases: 2 for the outer Mul and 2 + // for the inner Mul. In total, we handle 8 cases. + Value lhs = mulOp.getOperand(0); + Value rhs = mulOp.getOperand(1); + + Value fstMulVal, sndMulVal; + bool foundHalf = false; + + ONNXMulOp innerMulOp; + if (matchConstAndOp(lhs, rhs, 0.5, innerMulOp)) { + // - 0.5 * (a * b) or (a * b) * 0.5 + fstMulVal = innerMulOp.getOperand(0); + sndMulVal = innerMulOp.getOperand(1); + foundHalf = true; + } + if (!foundHalf && !constOf(lhs, 0.5) && !constOf(rhs, 0.5)) { + if (auto lhsMulOp = lhs.getDefiningOp()) { + // - (0.5 * a) * b + Value l = lhsMulOp.getOperand(0); + Value r = lhsMulOp.getOperand(1); + if (constOf(l, 0.5)) { + fstMulVal = r; + sndMulVal = rhs; + foundHalf = true; + } else if (constOf(r, 0.5)) { + fstMulVal = l; + sndMulVal = rhs; + foundHalf = true; + } + } + if (!foundHalf) { + if (auto rhsMulOp = rhs.getDefiningOp()) { + // - b * (0.5 * a) + Value l = rhsMulOp.getOperand(0); + Value r = rhsMulOp.getOperand(1); + if (constOf(l, 0.5)) { + fstMulVal = lhs; + sndMulVal = r; + foundHalf = true; + } else if (constOf(r, 0.5)) { + fstMulVal = lhs; + sndMulVal = l; + foundHalf = true; + } + } + } + } + if (!foundHalf) + return reportFailure("missing 0.5 * a * b"); + + // Exact gelu. + // Match 1 + erf() + bool foundErf = false; + ONNXErfOp erfOp; + // Try the first operand. + if (auto add1Op = fstMulVal.getDefiningOp()) { + foundErf = matchConstAndOp( + add1Op.getOperand(0), add1Op.getOperand(1), 1.0, erfOp); + if (foundErf) + x = sndMulVal; + } + if (!foundErf) { + // Try the second operand. + if (auto add1Op = sndMulVal.getDefiningOp()) { + foundErf = matchConstAndOp( + add1Op.getOperand(0), add1Op.getOperand(1), 1.0, erfOp); + if (foundErf) + x = fstMulVal; + } + } + if (foundErf) { + // gelu(x) = 0.5 * x * (1 + erf(x/1.41421354)) + Value erfInput = erfOp.getOperand(); + auto divOp = erfInput.getDefiningOp(); + if (!divOp) + return reportFailure("[Exact] missing div op"); + if (divOp.getOperand(0) != x) + return reportFailure("[Exact] missing x in x/1.41421354"); + if (!constOf(divOp.getOperand(1), 1.41421354)) + return reportFailure("[Exact] missing 1.41421354"); + isExactGelu = true; + return true; + } else { + // Do not return here, we still check the approximate case. + reportFailure("[Exact] missing (1 + erf)"); + } + + // Approximate gelu. + // gelu(x) = 0.5 * x * (1 + tanh[0.797884583 * (x + 0.044715 * x^3)]) + // Match 1 + tanh() + bool foundTanh = false; + ONNXTanhOp tanhOp; + // Try the first operand. 
+
+    // Approximate gelu.
+    // gelu(x) = 0.5 * x * (1 + tanh[0.797884583 * (x + 0.044715 * x^3)])
+    // Match 1 + tanh()
+    bool foundTanh = false;
+    ONNXTanhOp tanhOp;
+    // Try the first operand.
+    if (auto add1Op = fstMulVal.getDefiningOp<ONNXAddOp>()) {
+      foundTanh = matchConstAndOp<ONNXTanhOp>(
+          add1Op.getOperand(0), add1Op.getOperand(1), 1.0, tanhOp);
+      if (foundTanh)
+        x = sndMulVal;
+    }
+    if (!foundTanh) {
+      // Try the second operand.
+      if (auto add1Op = sndMulVal.getDefiningOp<ONNXAddOp>()) {
+        foundTanh = matchConstAndOp<ONNXTanhOp>(
+            add1Op.getOperand(0), add1Op.getOperand(1), 1.0, tanhOp);
+        if (foundTanh)
+          x = fstMulVal;
+      }
+    }
+    if (!foundTanh)
+      return reportFailure("[Approximate] missing (1 + tanh)");
+
+    // Match 0.797884583 * (x + 0.044715 * x^3)
+    auto mul1Op = tanhOp.getOperand().getDefiningOp<ONNXMulOp>();
+    if (!mul1Op)
+      return reportFailure("[Approximate] missing mul op for (0.797884583 *)");
+    ONNXAddOp add2Op;
+    if (!matchConstAndOp<ONNXAddOp>(
+            mul1Op.getOperand(0), mul1Op.getOperand(1), 0.797884583, add2Op))
+      return reportFailure(
+          "[Approximate] missing add op for (x + 0.044715*x^3)");
+
+    // Match x + 0.044715 * x^3
+    ONNXMulOp mul2Op;
+    if (!matchValueAndOp<ONNXMulOp>(
+            add2Op.getOperand(0), add2Op.getOperand(1), x, mul2Op))
+      return reportFailure("[Approximate] missing mul op for 0.044715 * x^3");
+
+    // Match 0.044715 * x^3
+    ONNXPowOp powOp;
+    if (!matchConstAndOp<ONNXPowOp>(
+            mul2Op.getOperand(0), mul2Op.getOperand(1), 0.044715, powOp))
+      return reportFailure("[Approximate] missing 0.044715 and/or pow op");
+
+    // Match x^3
+    lhs = powOp.getOperand(0);
+    rhs = powOp.getOperand(1);
+    if (lhs == x && constOf(rhs, 3.0))
+      return true;
+
+    return reportFailure("subgraph not found");
+  }
+
+  static bool reportFailure(std::string msg) {
+    // Can disable the line below if not needed.
+    LLVM_DEBUG(llvm::dbgs() << "Gelu failure: " << msg << "\n");
+    return false;
+  }
+};
+
 struct RecomposeQLinearMatMulFromQuantizeLinearPattern
     : public OpRewritePattern<ONNXQuantizeLinearOp> {
   using OpRewritePattern<ONNXQuantizeLinearOp>::OpRewritePattern;
-
   LogicalResult matchAndRewrite(
       ONNXQuantizeLinearOp qlOp, PatternRewriter &rewriter) const final {
     using namespace onnx_mlir;
@@ -442,8 +645,15 @@ void RecomposeONNXToONNXPass::runOnOperation() {
     FloatAttr epsilon;
     int64_t axis;
     bool isRMSLayerNorm;
-    return !RecomposeLayerNormFromMulPattern::matchLayerNormPattern(
-        op, x, scale, axis, epsilon, isRMSLayerNorm);
+    if (RecomposeLayerNormFromMulPattern::matchLayerNormPattern(
+            op, x, scale, axis, epsilon, isRMSLayerNorm))
+      return false;
+
+    bool isExactGelu;
+    if (RecomposeGeluFromMulPattern::matchGeluPattern(op, x, isExactGelu))
+      return false;
+
+    return true;
   });
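+  // Note: with this dynamic legality, an ONNXMulOp is illegal only when it is
+  // the root of a recognizable LayerNorm or Gelu subgraph; all other Mul ops
+  // stay legal and are left untouched.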
 
   // Recompose QLinearMatMul, starting from QuantizeLinear.
@@ -469,6 +679,7 @@ void RecomposeONNXToONNXPass::runOnOperation() {
 void onnx_mlir::getRecomposeONNXToONNXPatterns(
     mlir::RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
+  patterns.insert<RecomposeGeluFromMulPattern>(context);
   patterns.insert<RecomposeLayerNormFromMulPattern>(context);
   patterns.insert<RecomposeQLinearMatMulFromQuantizeLinearPattern>(context);
 }
diff --git a/test/mlir/onnx/onnx_recompose.mlir b/test/mlir/onnx/onnx_recompose.mlir
index e79d4029c4..99f166111f 100644
--- a/test/mlir/onnx/onnx_recompose.mlir
+++ b/test/mlir/onnx/onnx_recompose.mlir
@@ -261,3 +261,138 @@ func.func @qlinear_matmul(%arg0: tensor, %arg1: tensor, %arg2:
 // CHECK: return [[VAR_0_]] : tensor
 // CHECK: }
 }
+
+// -----
+
+// gelu(x) = [x * (erf(x/1.41421354) + 1)] * 0.5
+func.func @test_gelu_erf_cst_1(%arg0 : tensor<?x?x3072xf32>) -> tensor<?x?x3072xf32> {
+  %sqrt2 = onnx.Constant dense<1.41421354> : tensor<f32>
+  %one = onnx.Constant dense<1.000000e+00> : tensor<f32>
+  %half = onnx.Constant dense<5.000000e-01> : tensor<f32>
+  %0 = "onnx.Div"(%arg0, %sqrt2) : (tensor<?x?x3072xf32>, tensor<f32>) -> tensor<?x?x3072xf32>
+  %1 = "onnx.Erf"(%0) : (tensor<?x?x3072xf32>) -> tensor<?x?x3072xf32>
+  %2 = "onnx.Add"(%1, %one) : (tensor<?x?x3072xf32>, tensor<f32>) -> tensor<?x?x3072xf32>
+  %3 = "onnx.Mul"(%arg0, %2) : (tensor<?x?x3072xf32>, tensor<?x?x3072xf32>) -> tensor<?x?x3072xf32>
+  %4 = "onnx.Mul"(%3, %half) : (tensor<?x?x3072xf32>, tensor<f32>) -> tensor<?x?x3072xf32>
+  "func.return"(%4) : (tensor<?x?x3072xf32>) -> ()
+
+// CHECK-LABEL: func.func @test_gelu_erf_cst_1
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x3072xf32>) -> tensor<?x?x3072xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Gelu"([[PARAM_0_]]) {approximate = "none"} : (tensor<?x?x3072xf32>) -> tensor<?x?x3072xf32>
+// CHECK: return [[VAR_0_]] : tensor<?x?x3072xf32>
+// CHECK: }
+}
+
+// -----
+
+// gelu(x) = [x * (1 + erf(x/1.41421354))] * 0.5
+func.func @test_gelu_erf_cst_change_add_operand_order(%arg0 : tensor<?x?x3072xf32>) -> tensor<?x?x3072xf32> {
+  %sqrt2 = onnx.Constant dense<1.41421354> : tensor<f32>
+  %one = onnx.Constant dense<1.000000e+00> : tensor<f32>
+  %half = onnx.Constant dense<5.000000e-01> : tensor<f32>
+  %0 = "onnx.Div"(%arg0, %sqrt2) : (tensor<?x?x3072xf32>, tensor<f32>) -> tensor<?x?x3072xf32>
+  %1 = "onnx.Erf"(%0) : (tensor<?x?x3072xf32>) -> tensor<?x?x3072xf32>
+  %2 = "onnx.Add"(%one, %1) : (tensor<f32>, tensor<?x?x3072xf32>) -> tensor<?x?x3072xf32>
+  %3 = "onnx.Mul"(%arg0, %2) : (tensor<?x?x3072xf32>, tensor<?x?x3072xf32>) -> tensor<?x?x3072xf32>
+  %4 = "onnx.Mul"(%3, %half) : (tensor<?x?x3072xf32>, tensor<f32>) -> tensor<?x?x3072xf32>
+  "func.return"(%4) : (tensor<?x?x3072xf32>) -> ()
+
+// CHECK-LABEL: func.func @test_gelu_erf_cst_change_add_operand_order
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x3072xf32>) -> tensor<?x?x3072xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Gelu"([[PARAM_0_]]) {approximate = "none"} : (tensor<?x?x3072xf32>) -> tensor<?x?x3072xf32>
+// CHECK: return [[VAR_0_]] : tensor<?x?x3072xf32>
+// CHECK: }
+}
+
+// -----
+
+// gelu(x) = [(erf(x/1.41421354) + 1) * x] * 0.5
+func.func @test_gelu_erf_cst_change_mul_operand_order_1(%arg0 : tensor<?x?x3072xf32>) -> tensor<?x?x3072xf32> {
+  %sqrt2 = onnx.Constant dense<1.41421354> : tensor<f32>
+  %one = onnx.Constant dense<1.000000e+00> : tensor<f32>
+  %half = onnx.Constant dense<5.000000e-01> : tensor<f32>
+  %0 = "onnx.Div"(%arg0, %sqrt2) : (tensor<?x?x3072xf32>, tensor<f32>) -> tensor<?x?x3072xf32>
+  %1 = "onnx.Erf"(%0) : (tensor<?x?x3072xf32>) -> tensor<?x?x3072xf32>
+  %2 = "onnx.Add"(%1, %one) : (tensor<?x?x3072xf32>, tensor<f32>) -> tensor<?x?x3072xf32>
+  %3 = "onnx.Mul"(%2, %arg0) : (tensor<?x?x3072xf32>, tensor<?x?x3072xf32>) -> tensor<?x?x3072xf32>
+  %4 = "onnx.Mul"(%3, %half) : (tensor<?x?x3072xf32>, tensor<f32>) -> tensor<?x?x3072xf32>
+  "func.return"(%4) : (tensor<?x?x3072xf32>) -> ()
+
+// CHECK-LABEL: func.func @test_gelu_erf_cst_change_mul_operand_order_1
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x3072xf32>) -> tensor<?x?x3072xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Gelu"([[PARAM_0_]]) {approximate = "none"} : (tensor<?x?x3072xf32>) -> tensor<?x?x3072xf32>
+// CHECK: return [[VAR_0_]] : tensor<?x?x3072xf32>
+// CHECK: }
+}
+
+// -----
+
+// gelu(x) = 0.5 * [x * (erf(x/1.41421354) + 1)]
+func.func @test_gelu_erf_cst_change_mul_operand_order_2(%arg0 : tensor<?x?x3072xf32>) -> tensor<?x?x3072xf32> {
+  %sqrt2 = onnx.Constant dense<1.41421354> : tensor<f32>
+  %one = onnx.Constant dense<1.000000e+00> : tensor<f32>
+  %half = onnx.Constant dense<5.000000e-01> : tensor<f32>
+  %0 = "onnx.Div"(%arg0, %sqrt2) : (tensor<?x?x3072xf32>, tensor<f32>) -> tensor<?x?x3072xf32>
+  %1 = "onnx.Erf"(%0) : (tensor<?x?x3072xf32>) -> tensor<?x?x3072xf32>
+  %2 = "onnx.Add"(%1, %one) : (tensor<?x?x3072xf32>, tensor<f32>) -> tensor<?x?x3072xf32>
+  %3 = "onnx.Mul"(%arg0, %2) : (tensor<?x?x3072xf32>, tensor<?x?x3072xf32>) -> tensor<?x?x3072xf32>
+  %4 = "onnx.Mul"(%half, %3) : (tensor<f32>, tensor<?x?x3072xf32>) -> tensor<?x?x3072xf32>
+  "func.return"(%4) : (tensor<?x?x3072xf32>) -> ()
+
+// CHECK-LABEL: func.func @test_gelu_erf_cst_change_mul_operand_order_2
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x3072xf32>) -> tensor<?x?x3072xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Gelu"([[PARAM_0_]]) {approximate = "none"} : (tensor<?x?x3072xf32>) -> tensor<?x?x3072xf32>
+// CHECK: return [[VAR_0_]] : tensor<?x?x3072xf32>
+// CHECK: }
+}
+
+// -----
+
+// gelu(x) = x * (0.5 * (1 + tanh[0.797884583 * (x + 0.044715 * x^3)]))
+func.func @test_gelu_tanh(%arg0 : tensor<*xf32>) -> tensor<*xf32> {
+  %one = onnx.Constant dense<1.000000e+00> : tensor<f32>
+  %three = onnx.Constant dense<3.000000e+00> : tensor<f32>
+  %half = onnx.Constant dense<5.000000e-01> : tensor<f32>
+  %sqrt2pi = onnx.Constant dense<0.797884583> : tensor<f32>
+  %cst044715 = onnx.Constant dense<4.471500e-02> : tensor<f32>
+  %0 = "onnx.Pow"(%arg0, %three) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
+  %1 = "onnx.Mul"(%cst044715, %0) : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
+  %2 = "onnx.Add"(%arg0, %1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  %3 = "onnx.Mul"(%sqrt2pi, %2) : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
+  %4 = "onnx.Tanh"(%3) : (tensor<*xf32>) -> tensor<*xf32>
+  %5 = "onnx.Add"(%one, %4) : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
+  %6 = "onnx.Mul"(%half, %5) : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
+  %7 = "onnx.Mul"(%arg0, %6) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return %7 : tensor<*xf32>
+
+// CHECK-LABEL: func.func @test_gelu_tanh
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) -> tensor<*xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Gelu"([[PARAM_0_]]) {approximate = "tanh"} : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK: return [[VAR_0_]] : tensor<*xf32>
+// CHECK: }
+}
+
+// -----
+
+func.func @test_gelu_erf_two_adds(%arg0: tensor<?x?x3072xf32>, %arg1: tensor<3072x768xf32>) -> tensor<?x?x768xf32> {
+  %0 = onnx.Constant dense<5.000000e-01> : tensor<f32>
+  %1 = onnx.Constant dense<1.000000e+00> : tensor<f32>
+  %2 = onnx.Constant dense<1.41421354> : tensor<f32>
+  %3 = onnx.Constant dense<3.000000e-01> : tensor<3072xf32>
+  %4 = "onnx.Add"(%arg0, %3) : (tensor<?x?x3072xf32>, tensor<3072xf32>) -> tensor<?x?x3072xf32>
+  %5 = "onnx.Div"(%4, %2) : (tensor<?x?x3072xf32>, tensor<f32>) -> tensor<?x?x3072xf32>
+  %6 = "onnx.Erf"(%5) : (tensor<?x?x3072xf32>) -> tensor<?x?x3072xf32>
+  %7 = "onnx.Add"(%6, %1) : (tensor<?x?x3072xf32>, tensor<f32>) -> tensor<?x?x3072xf32>
+  %8 = "onnx.Mul"(%4, %7) : (tensor<?x?x3072xf32>, tensor<?x?x3072xf32>) -> tensor<?x?x3072xf32>
+  %9 = "onnx.Mul"(%8, %0) : (tensor<?x?x3072xf32>, tensor<f32>) -> tensor<?x?x3072xf32>
+  %10 = "onnx.MatMul"(%9, %arg1) : (tensor<?x?x3072xf32>, tensor<3072x768xf32>) -> tensor<?x?x768xf32>
+  return %10 : tensor<?x?x768xf32>
+}
+// CHECK-LABEL: func.func @test_gelu_erf_two_adds
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x3072xf32>, [[PARAM_1_:%.+]]: tensor<3072x768xf32>) -> tensor<?x?x768xf32> {
+// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<3.000000e-01> : tensor<3072xf32>
+// CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[PARAM_0_]], [[VAR_0_]]) : (tensor<?x?x3072xf32>, tensor<3072xf32>) -> tensor<?x?x3072xf32>
+// CHECK: [[VAR_2_:%.+]] = "onnx.Gelu"([[VAR_1_]]) {approximate = "none"} : (tensor<?x?x3072xf32>) -> tensor<?x?x3072xf32>
+// CHECK: [[VAR_3_:%.+]] = "onnx.MatMul"([[VAR_2_]], [[PARAM_1_]]) : (tensor<?x?x3072xf32>, tensor<3072x768xf32>) -> tensor<?x?x768xf32>
+// CHECK: return [[VAR_3_]] : tensor<?x?x768xf32>
+// CHECK: }