[MLIR][TORCH] Add support for tanh approximation for Gelu op (llvm#2941)
Fixes nod-ai/SHARK-ModelDev#461

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
vivekkhandelwal1 authored Feb 27, 2024
1 parent: d81747e · commit: d628b5f
Showing 3 changed files with 59 additions and 4 deletions.
39 changes: 35 additions & 4 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -511,11 +511,42 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     }
     // TODO: Take approximation into account.
     std::string approximate;
-    if (!matchPattern(gelu.getApproximate(), m_TorchConstantStr(approximate)) ||
-        approximate != "none")
+    if (!matchPattern(gelu.getApproximate(), m_TorchConstantStr(approximate))) {
+      gelu.emitError(
+          "unimplemented: expected approximate to be a constant str");
       return nullptr;
-    Value cdf = buildUnitNormalCdf(b, loc, payloadArgs[0]);
-    return b.create<arith::MulFOp>(loc, payloadArgs[0], cdf);
+    }
+    if (approximate == "none") {
+      Value multiplier = buildUnitNormalCdf(b, loc, payloadArgs[0]);
+      return b.create<arith::MulFOp>(loc, payloadArgs[0], multiplier);
+    }
+    if (approximate == "tanh") {
+      // GELU(x)=0.5∗x∗(1+Tanh((2/π)^1/2 * (x+0.044715∗x^3)))
+      // Ref: https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
+      Value cstThree = b.create<arith::ConstantOp>(
+          loc, IntegerAttr::get(IntegerType::get(op->getContext(), 64), 3));
+      Value xCube = b.create<math::FPowIOp>(loc, payloadArgs[0], cstThree);
+      Type elementType = payloadArgs[0].getType();
+      Value cstAlpha = b.create<arith::ConstantOp>(
+          loc, FloatAttr::get(elementType, 0.044715));
+      Value xCubeMulAlpha = b.create<arith::MulFOp>(loc, xCube, cstAlpha);
+      Value xPlusXCubeMulAlpha =
+          b.create<arith::AddFOp>(loc, payloadArgs[0], xCubeMulAlpha);
+      Value cstBeta = b.create<arith::ConstantOp>(
+          loc, FloatAttr::get(elementType, 0.7977240352174656));
+      Value betaMulX =
+          b.create<arith::MulFOp>(loc, cstBeta, xPlusXCubeMulAlpha);
+      Value tanh = b.create<math::TanhOp>(loc, betaMulX);
+      Value cstOne =
+          b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1.0));
+      Value onePlusTanh = b.create<arith::AddFOp>(loc, cstOne, tanh);
+      Value cstHalf =
+          b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.5));
+      Value multiplier = b.create<arith::MulFOp>(loc, cstHalf, onePlusTanh);
+      return b.create<arith::MulFOp>(loc, payloadArgs[0], multiplier);
+    }
+    gelu.emitError("unimplemented: approximate value should be none or tanh");
+    return nullptr;
   }
   if (auto geluBackward = dyn_cast<AtenGeluBackwardOp>(op)) {
     if (!geluBackward.getType()
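For reference, the tanh branch above computes GELU(x) = 0.5 · x · (1 + tanh(√(2/π) · (x + 0.044715 · x³))). The following minimal PyTorch sketch (not part of the commit) cross-checks that formula against PyTorch's own tanh-approximated GELU. Note one nuance: the sketch uses the exact math.sqrt(2 / math.pi) ≈ 0.7978845608 for the coefficient, whereas the lowering hard-codes the slightly different constant 0.7977240352174656 for cstBeta, so lowered results can deviate from eager PyTorch in the low decimal places.

import math

import torch

def gelu_tanh_reference(x: torch.Tensor) -> torch.Tensor:
    # GELU(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    beta = math.sqrt(2.0 / math.pi)  # exact value of the cstBeta coefficient
    return 0.5 * x * (1.0 + torch.tanh(beta * (x + 0.044715 * x**3)))

x = torch.randn(5, 3)
expected = torch.nn.functional.gelu(x, approximate="tanh")
assert torch.allclose(gelu_tanh_reference(x), expected, atol=1e-6)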
1 change: 1 addition & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -518,6 +518,7 @@
"ElementwiseFloorIntModule_basic",
"ElementwiseFloorModule_basic",
"ElementwiseGeluModule_basic",
"ElementwiseGeluApproximateTanhModule_basic",
"ElementwiseLeakyReluStaticModule_basic",
"ElementwiseLogModule_basic",
"ElementwiseNanToNumModule_Basic",
23 changes: 23 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -853,6 +853,29 @@ def ElementwiseGeluModule_basic(module, tu: TestUtils):
 # ==============================================================================


+class ElementwiseGeluApproximateTanhModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.gelu = torch.nn.GELU(approximate="tanh")
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.gelu(x)
+
+
+@register_test_case(module_factory=lambda: ElementwiseGeluApproximateTanhModule())
+def ElementwiseGeluApproximateTanhModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 3, low=-0.5, high=0.5))
+
+
+# ==============================================================================
+
+
 class ElementwiseSeluModule(torch.nn.Module):

     def __init__(self):
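Outside the e2e harness, the new test module can also be exercised eagerly. A small sketch (assuming projects/pt1/python is on PYTHONPATH so that torch_mlir_e2e_test is importable) confirming the module matches PyTorch's own tanh-approximated GELU:

import torch

from torch_mlir_e2e_test.test_suite.elementwise import (
    ElementwiseGeluApproximateTanhModule,
)

model = ElementwiseGeluApproximateTanhModule()
x = torch.empty(5, 3).uniform_(-0.5, 0.5)  # same shape/range as the test case
assert torch.equal(model(x), torch.nn.functional.gelu(x, approximate="tanh"))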
