Skip to content

Commit

Permalink
[InstCombine] Transform high latency, dependent FSQRT/FDIV into FMUL (#…
Browse files Browse the repository at this point in the history
…87474)

The proposed patch, in general, tries to transform the below code
sequence:
x = 1.0 / sqrt (a);
r1 = x * x;  // same as 1.0 / a
r2 = a / sqrt(a); // same as sqrt (a)

TO

(If x, r1 and r2 are all used further in the code) 
r1 = 1.0 / a
r2 = sqrt (a)
x = r1 * r2

The transform makes the high-latency sqrt and div operations independent of
each other and also saves one multiplication.

The patch was tested with the SPEC17 suite with cpu=neoverse-v2. The
performance uplift achieved was:
544.nab_r   ~4%

No other regressions were observed. Also, no compile time differences
were observed with the patch.

Closes #54652
  • Loading branch information
sushgokh authored Jan 17, 2025
1 parent 263fed7 commit 7253c6f
Show file tree
Hide file tree
Showing 2 changed files with 807 additions and 0 deletions.
176 changes: 176 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "InstCombineInternal.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/ValueTracking.h"
Expand Down Expand Up @@ -657,6 +658,94 @@ Instruction *InstCombinerImpl::foldPowiReassoc(BinaryOperator &I) {
return nullptr;
}

// Collects the instructions that take part in the pattern
//   X  = 1.0/sqrt(a)      (or -1.0/sqrt(a))
//   R1 = X * X
//   R2 = a/sqrt(a)
// where \p Div plays the role of X.
//
// Every user of \p Div of the form `Div * Div` is added to \p R1, and every
// user of the sqrt call of the form `a / sqrt(a)` is added to \p R2. Nothing is
// collected when \p Div is not a (negated) reciprocal square root.
//
// Returns true when, after collection, both \p R1 and \p R2 are non-empty.
static bool getFSqrtDivOptPattern(Instruction *Div,
                                  SmallPtrSetImpl<Instruction *> &R1,
                                  SmallPtrSetImpl<Instruction *> &R2) {
  Value *SqrtArg;
  const bool IsRecipSqrt =
      match(Div, m_FDiv(m_FPOne(), m_Sqrt(m_Value(SqrtArg)))) ||
      match(Div, m_FDiv(m_SpecificFP(-1.0), m_Sqrt(m_Value(SqrtArg))));

  if (IsRecipSqrt) {
    // R1 candidates: users of the division that square it.
    for (User *DivUser : Div->users())
      if (auto *Square = cast<Instruction>(DivUser);
          match(Square, m_FMul(m_Specific(Div), m_Specific(Div))))
        R1.insert(Square);

    // R2 candidates: users of the sqrt call that divide its argument by it.
    auto *SqrtCall = cast<CallInst>(Div->getOperand(1));
    for (User *SqrtUser : SqrtCall->users())
      if (auto *Quotient = cast<Instruction>(SqrtUser);
          match(Quotient,
                m_FDiv(m_Specific(SqrtArg), m_Sqrt(m_Specific(SqrtArg)))))
        R2.insert(Quotient);
  }

  return !(R1.empty() || R2.empty());
}

// Decides whether it is legal (and likely profitable) to rewrite
//   x  = 1.0/sqrt(a)
//   r1 = x * x;
//   r2 = a/sqrt(a);
//
// INTO
//
//   r1 = 1/a
//   r2 = sqrt(a)
//   x  = r1 * r2
//
// On return \p R1 and \p R2 hold the instructions matching r1 and r2 that were
// found by getFSqrtDivOptPattern(). The rewrite is only valid when 'a' is known
// positive; the fast-math flags required below stand in for that guarantee.
static bool isFSqrtDivToFMulLegal(Instruction *X,
                                  SmallPtrSetImpl<Instruction *> &R1,
                                  SmallPtrSetImpl<Instruction *> &R2) {
  // Without the full pattern there is nothing to transform.
  if (!getFSqrtDivOptPattern(X, R1, R2))
    return false;

  // The sqrt call must carry reassoc, nnan, nsz and ninf.
  auto *SqrtCall = cast<CallInst>(X->getOperand(1));
  const bool SqrtFlagsOk =
      SqrtCall->hasAllowReassoc() && SqrtCall->hasNoNaNs() &&
      SqrtCall->hasNoSignedZeros() && SqrtCall->hasNoInfs();
  if (!SqrtFlagsOk)
    return false;

  // Turning x = 1/sqrt(a) into x = sqrt(a) * 1/a is not something 'arcp' alone
  // permits (it only covers a/b -> a * 1/b). It is an algebraic rewrite, for
  // which 'reassoc' has historically been used, so both are required here in
  // addition to 'ninf'.
  const bool DivFlagsOk =
      X->hasAllowReassoc() && X->hasAllowReciprocal() && X->hasNoInfs();
  if (!DivFlagsOk)
    return false;

  BasicBlock *DivBB = X->getParent();
  BasicBlock *SquareBB = (*R1.begin())->getParent();
  BasicBlock *QuotientBB = (*R2.begin())->getParent();

  // The fdiv has to share a block with at least one group of its dependants;
  // otherwise the rewritten code could execute more operations than before.
  if (DivBB != SquareBB && DivBB != QuotientBB)
    return false;

  // With several instructions in R1 and R2 it is hard to enumerate (R1, R2)
  // combinations and validate each, so stay conservative: every member of a
  // set must live in that set's block and allow reassociation.
  auto LivesInWithReassoc = [](BasicBlock *BB) {
    return [BB](Instruction *Member) {
      return Member->getParent() == BB && Member->hasAllowReassoc();
    };
  };

  if (!all_of(R1, LivesInWithReassoc(SquareBB)))
    return false;
  return all_of(R2, LivesInWithReassoc(QuotientBB));
}

Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) {
Value *Op0 = I.getOperand(0);
Value *Op1 = I.getOperand(1);
Expand Down Expand Up @@ -1913,6 +2002,75 @@ static Instruction *foldFDivSqrtDivisor(BinaryOperator &I,
return BinaryOperator::CreateFMulFMF(Op0, NewSqrt, &I);
}

// Change
//   X  = 1/sqrt(a)          (or -1/sqrt(a))
//   R1 = X * X
//   R2 = a / sqrt(a)
//
// TO
//
//   FDiv  = 1/a
//   FSqrt = sqrt(a)
//   FMul  = FDiv * FSqrt    (negated when X was -1/sqrt(a))
//   Replace Uses Of R1 With FDiv
//   Replace Uses Of R2 With FSqrt
//   Replace Uses Of X With FMul
//
// \param CI  the sqrt call feeding X.
// \param X   the fdiv currently being visited by InstCombine.
// \param R1  all instructions matching `X * X`; each is replaced and erased.
// \param R2  all instructions matching `a / sqrt(a)`; each is replaced and
//            erased.
// \param B   builder used to create the new instructions before X.
// \param IC  the combiner, used for worklist-aware replacement/erasure.
//
// \returns X itself (via replaceInstUsesWith) so that the calling visit
// function reports a change. X is left in place with no uses; the combiner
// removes it afterwards as trivially dead.
static Instruction *
convertFSqrtDivIntoFMul(CallInst *CI, Instruction *X,
                        const SmallPtrSetImpl<Instruction *> &R1,
                        const SmallPtrSetImpl<Instruction *> &R2,
                        InstCombiner::BuilderTy &B, InstCombinerImpl *IC) {

  B.SetInsertPoint(X);

  // Have an instruction that is representative of all of instructions in R1 and
  // get the most common fpmath metadata and fast-math flags on it.
  Value *SqrtOp = CI->getArgOperand(0);
  auto *FDiv = cast<Instruction>(
      B.CreateFDiv(ConstantFP::get(X->getType(), 1.0), SqrtOp));
  auto *R1FPMathMDNode = (*R1.begin())->getMetadata(LLVMContext::MD_fpmath);
  FastMathFlags R1FMF = (*R1.begin())->getFastMathFlags(); // Common FMF
  for (Instruction *I : R1) {
    R1FPMathMDNode = MDNode::getMostGenericFPMath(
        R1FPMathMDNode, I->getMetadata(LLVMContext::MD_fpmath));
    R1FMF &= I->getFastMathFlags();
    IC->replaceInstUsesWith(*I, FDiv);
    IC->eraseInstFromFunction(*I);
  }
  FDiv->setMetadata(LLVMContext::MD_fpmath, R1FPMathMDNode);
  FDiv->copyFastMathFlags(R1FMF);

  // Have a single sqrt call instruction that is representative of all of
  // instructions in R2 and get the most common fpmath metadata and fast-math
  // flags on it.
  auto *FSqrt = cast<CallInst>(CI->clone());
  FSqrt->insertBefore(CI);
  auto *R2FPMathMDNode = (*R2.begin())->getMetadata(LLVMContext::MD_fpmath);
  FastMathFlags R2FMF = (*R2.begin())->getFastMathFlags(); // Common FMF
  for (Instruction *I : R2) {
    R2FPMathMDNode = MDNode::getMostGenericFPMath(
        R2FPMathMDNode, I->getMetadata(LLVMContext::MD_fpmath));
    R2FMF &= I->getFastMathFlags();
    IC->replaceInstUsesWith(*I, FSqrt);
    IC->eraseInstFromFunction(*I);
  }
  FSqrt->setMetadata(LLVMContext::MD_fpmath, R2FPMathMDNode);
  FSqrt->copyFastMathFlags(R2FMF);

  Instruction *FMul;
  // If X = -1/sqrt(a) initially,then FMul = -(FDiv * FSqrt)
  if (match(X, m_FDiv(m_SpecificFP(-1.0), m_Specific(CI)))) {
    Value *Mul = B.CreateFMul(FDiv, FSqrt);
    FMul = cast<Instruction>(B.CreateFNeg(Mul));
  } else
    FMul = cast<Instruction>(B.CreateFMul(FDiv, FSqrt));
  FMul->copyMetadata(*X);
  FMul->copyFastMathFlags(FastMathFlags::intersectRewrite(R1FMF, R2FMF) |
                          FastMathFlags::unionValue(R1FMF, R2FMF));

  // X is the instruction the caller is in the middle of visiting, so it must
  // not be erased here. eraseInstFromFunction() deletes it and returns
  // nullptr, which would make the caller believe nothing changed and keep
  // operating on freed memory. replaceInstUsesWith() rewrites the uses and
  // returns X, signalling the modification to the combiner.
  return IC->replaceInstUsesWith(*X, FMul);
}

Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
Module *M = I.getModule();

Expand All @@ -1937,6 +2095,24 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
return R;

Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);

// Convert
// x = 1.0/sqrt(a)
// r1 = x * x;
// r2 = a/sqrt(a);
//
// TO
//
// r1 = 1/a
// r2 = sqrt(a)
// x = r1 * r2
SmallPtrSet<Instruction *, 2> R1, R2;
if (isFSqrtDivToFMulLegal(&I, R1, R2)) {
CallInst *CI = cast<CallInst>(I.getOperand(1));
if (Instruction *D = convertFSqrtDivIntoFMul(CI, &I, R1, R2, Builder, this))
return D;
}

if (isa<Constant>(Op0))
if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
if (Instruction *R = FoldOpIntoSelect(I, SI))
Expand Down
Loading

0 comments on commit 7253c6f

Please sign in to comment.