Skip to content

Commit

Permalink
[AArch64] Eliminate Common Subexpression of CSEL by Reassociation (ll…
Browse files Browse the repository at this point in the history
…vm#121350)

If we have a CSEL instruction that depends on the flags set by a
(SUBS x c) instruction and the true and/or false expression is
(add (add x y) -c), we can reassociate the latter expression to
(add (SUBS x c) y) and save one instruction.

Proof for the basic transformation: https://alive2.llvm.org/ce/z/-337Pb

We can extend this transformation for slightly different constants. For
example, if we have (add (add x y) -(c-1)) and a the comparison x <u c,
we can transform the comparison to x <=u c-1 to eliminate the comparison
instruction, too. Similarly, we can transform (x == 0) to (x <u 1).

Proofs for the transformations that alter the constants:
https://alive2.llvm.org/ce/z/3nVqgR

Fixes llvm#119606.
  • Loading branch information
mskamp authored and BaiXilin committed Jan 12, 2025
1 parent 6b6f370 commit b792567
Show file tree
Hide file tree
Showing 2 changed files with 894 additions and 0 deletions.
121 changes: 121 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24867,6 +24867,122 @@ static SDValue foldCSELOfCSEL(SDNode *Op, SelectionDAG &DAG) {
return DAG.getNode(AArch64ISD::CSEL, DL, VT, L, R, CCValue, Cond);
}

// Reassociate the true/false expressions of a CSEL instruction to obtain a
// common subexpression with the comparison instruction. For example, change
// (CSEL (ADD (ADD x y) -c) f LO (SUBS x c)) to
// (CSEL (ADD (SUBS x c) y) f LO (SUBS x c)) such that (SUBS x c) is a common
// subexpression.
static SDValue reassociateCSELOperandsForCSE(SDNode *N, SelectionDAG &DAG) {
SDValue SubsNode = N->getOperand(3);
if (SubsNode.getOpcode() != AArch64ISD::SUBS || !SubsNode.hasOneUse())
return SDValue();
auto *CmpOpConst = dyn_cast<ConstantSDNode>(SubsNode.getOperand(1));
if (!CmpOpConst)
return SDValue();

SDValue CmpOpOther = SubsNode.getOperand(0);
EVT VT = N->getValueType(0);

// Get the operand that can be reassociated with the SUBS instruction.
auto GetReassociationOp = [&](SDValue Op, APInt ExpectedConst) {
if (Op.getOpcode() != ISD::ADD)
return SDValue();
if (Op.getOperand(0).getOpcode() != ISD::ADD ||
!Op.getOperand(0).hasOneUse())
return SDValue();
SDValue X = Op.getOperand(0).getOperand(0);
SDValue Y = Op.getOperand(0).getOperand(1);
if (X != CmpOpOther)
std::swap(X, Y);
if (X != CmpOpOther)
return SDValue();
auto *AddOpConst = dyn_cast<ConstantSDNode>(Op.getOperand(1));
if (!AddOpConst || AddOpConst->getAPIntValue() != ExpectedConst)
return SDValue();
return Y;
};

// Try the reassociation using the given constant and condition code.
auto Fold = [&](APInt NewCmpConst, AArch64CC::CondCode NewCC) {
APInt ExpectedConst = -NewCmpConst;
SDValue TReassocOp = GetReassociationOp(N->getOperand(0), ExpectedConst);
SDValue FReassocOp = GetReassociationOp(N->getOperand(1), ExpectedConst);
if (!TReassocOp && !FReassocOp)
return SDValue();

SDValue NewCmp = DAG.getNode(AArch64ISD::SUBS, SDLoc(SubsNode),
DAG.getVTList(VT, MVT_CC), CmpOpOther,
DAG.getConstant(NewCmpConst, SDLoc(CmpOpConst),
CmpOpConst->getValueType(0)));

auto Reassociate = [&](SDValue ReassocOp, unsigned OpNum) {
if (!ReassocOp)
return N->getOperand(OpNum);
SDValue Res = DAG.getNode(ISD::ADD, SDLoc(N->getOperand(OpNum)), VT,
NewCmp.getValue(0), ReassocOp);
DAG.ReplaceAllUsesWith(N->getOperand(OpNum), Res);
return Res;
};

SDValue TValReassoc = Reassociate(TReassocOp, 0);
SDValue FValReassoc = Reassociate(FReassocOp, 1);
return DAG.getNode(AArch64ISD::CSEL, SDLoc(N), VT, TValReassoc, FValReassoc,
DAG.getConstant(NewCC, SDLoc(N->getOperand(2)), MVT_CC),
NewCmp.getValue(1));
};

auto CC = static_cast<AArch64CC::CondCode>(N->getConstantOperandVal(2));

// First, try to eliminate the compare instruction by searching for a
// subtraction with the same constant.
if (SDValue R = Fold(CmpOpConst->getAPIntValue(), CC))
return R;

if ((CC == AArch64CC::EQ || CC == AArch64CC::NE) && !CmpOpConst->isZero())
return SDValue();

// Next, search for a subtraction with a slightly different constant. By
// adjusting the condition code, we can still eliminate the compare
// instruction. Adjusting the constant is only valid if it does not result
// in signed/unsigned wrap for signed/unsigned comparisons, respectively.
// Since such comparisons are trivially true/false, we should not encounter
// them here but check for them nevertheless to be on the safe side.
auto CheckedFold = [&](bool Check, APInt NewCmpConst,
AArch64CC::CondCode NewCC) {
return Check ? Fold(NewCmpConst, NewCC) : SDValue();
};
switch (CC) {
case AArch64CC::EQ:
case AArch64CC::LS:
return CheckedFold(!CmpOpConst->getAPIntValue().isMaxValue(),
CmpOpConst->getAPIntValue() + 1, AArch64CC::LO);
case AArch64CC::NE:
case AArch64CC::HI:
return CheckedFold(!CmpOpConst->getAPIntValue().isMaxValue(),
CmpOpConst->getAPIntValue() + 1, AArch64CC::HS);
case AArch64CC::LO:
return CheckedFold(!CmpOpConst->getAPIntValue().isZero(),
CmpOpConst->getAPIntValue() - 1, AArch64CC::LS);
case AArch64CC::HS:
return CheckedFold(!CmpOpConst->getAPIntValue().isZero(),
CmpOpConst->getAPIntValue() - 1, AArch64CC::HI);
case AArch64CC::LT:
return CheckedFold(!CmpOpConst->getAPIntValue().isMinSignedValue(),
CmpOpConst->getAPIntValue() - 1, AArch64CC::LE);
case AArch64CC::LE:
return CheckedFold(!CmpOpConst->getAPIntValue().isMaxSignedValue(),
CmpOpConst->getAPIntValue() + 1, AArch64CC::LT);
case AArch64CC::GT:
return CheckedFold(!CmpOpConst->getAPIntValue().isMaxSignedValue(),
CmpOpConst->getAPIntValue() + 1, AArch64CC::GE);
case AArch64CC::GE:
return CheckedFold(!CmpOpConst->getAPIntValue().isMinSignedValue(),
CmpOpConst->getAPIntValue() - 1, AArch64CC::GT);
default:
return SDValue();
}
}

// Optimize CSEL instructions
static SDValue performCSELCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
Expand All @@ -24878,6 +24994,11 @@ static SDValue performCSELCombine(SDNode *N,
if (SDValue R = foldCSELOfCSEL(N, DAG))
return R;

// Try to reassociate the true/false expressions so that we can do CSE with
// a SUBS instruction used to perform the comparison.
if (SDValue R = reassociateCSELOperandsForCSE(N, DAG))
return R;

// CSEL 0, cttz(X), eq(X, 0) -> AND cttz bitwidth-1
// CSEL cttz(X), 0, ne(X, 0) -> AND cttz bitwidth-1
if (SDValue Folded = foldCSELofCTTZ(N, DAG))
Expand Down
Loading

0 comments on commit b792567

Please sign in to comment.