Skip to content

Commit

Permalink
Lower GetElement on arm64 to the correct access sequence (#104288)
Browse files Browse the repository at this point in the history
* Lower GetElement on arm64 to the correct access sequence

* Use constant offset where possible

* Ensure that lvaSIMDInitTempVarNum is marked as being used by LclAddrNode

* Fix assert

* Create a valid addr mode for Arm64

* Don't lower unnecessarily

* Account for index 0 and scale 1

* Remove the offset constant node when it's unused
  • Loading branch information
tannergooding authored Jul 5, 2024
1 parent 06e0076 commit b9673cb
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 133 deletions.
105 changes: 11 additions & 94 deletions src/coreclr/jit/hwintrinsiccodegenarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1654,6 +1654,10 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
case NI_Vector128_GetElement:
{
assert(intrin.numOperands == 2);
assert(!intrin.op1->isContained());

assert(intrin.op2->OperIsConst());
assert(intrin.op2->isContained());

var_types simdType = Compiler::getSIMDTypeForSize(node->GetSimdSize());

Expand All @@ -1663,109 +1667,22 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
simdType = TYP_SIMD16;
}

if (!intrin.op2->OperIsConst())
{
assert(!intrin.op2->isContained());

emitAttr baseTypeSize = emitTypeSize(intrin.baseType);
unsigned baseTypeScale = genLog2(EA_SIZE_IN_BYTES(baseTypeSize));

regNumber baseReg;
regNumber indexReg = op2Reg;

// Optimize the case of op1 is in memory and trying to access i'th element.
if (!intrin.op1->isUsedFromReg())
{
assert(intrin.op1->isContained());

if (intrin.op1->OperIsLocal())
{
unsigned varNum = intrin.op1->AsLclVarCommon()->GetLclNum();
baseReg = internalRegisters.Extract(node);

// Load the address of varNum
GetEmitter()->emitIns_R_S(INS_lea, EA_PTRSIZE, baseReg, varNum, 0);
}
else
{
// Require GT_IND addr to be not contained.
assert(intrin.op1->OperIs(GT_IND));

GenTree* addr = intrin.op1->AsIndir()->Addr();
assert(!addr->isContained());
baseReg = addr->GetRegNum();
}
}
else
{
unsigned simdInitTempVarNum = compiler->lvaSIMDInitTempVarNum;
noway_assert(simdInitTempVarNum != BAD_VAR_NUM);

baseReg = internalRegisters.Extract(node);

// Load the address of simdInitTempVarNum
GetEmitter()->emitIns_R_S(INS_lea, EA_PTRSIZE, baseReg, simdInitTempVarNum, 0);

// Store the vector to simdInitTempVarNum
GetEmitter()->emitIns_R_R(INS_str, emitTypeSize(simdType), op1Reg, baseReg);
}

assert(genIsValidIntReg(indexReg));
assert(genIsValidIntReg(baseReg));
assert(baseReg != indexReg);
ssize_t ival = intrin.op2->AsIntCon()->IconValue();

// Load item at baseReg[index]
GetEmitter()->emitIns_R_R_R_Ext(ins_Load(intrin.baseType), baseTypeSize, targetReg, baseReg,
indexReg, INS_OPTS_LSL, baseTypeScale);
}
else if (!GetEmitter()->isValidVectorIndex(emitTypeSize(simdType), emitTypeSize(intrin.baseType),
intrin.op2->AsIntCon()->IconValue()))
if (!GetEmitter()->isValidVectorIndex(emitTypeSize(simdType), emitTypeSize(intrin.baseType), ival))
{
// We only need to generate code for the get if the index is valid
// If the index is invalid, the code previously generated for the range check will throw
break;
}
else if (!intrin.op1->isUsedFromReg())
{
assert(intrin.op1->isContained());
assert(intrin.op2->IsCnsIntOrI());

int offset = (int)intrin.op2->AsIntCon()->IconValue() * genTypeSize(intrin.baseType);
instruction ins = ins_Load(intrin.baseType);

assert(!intrin.op1->isUsedFromReg());

if (intrin.op1->OperIsLocal())
{
unsigned varNum = intrin.op1->AsLclVarCommon()->GetLclNum();
GetEmitter()->emitIns_R_S(ins, emitActualTypeSize(intrin.baseType), targetReg, varNum, offset);
}
else
{
assert(intrin.op1->OperIs(GT_IND));

GenTree* addr = intrin.op1->AsIndir()->Addr();
assert(!addr->isContained());
regNumber baseReg = addr->GetRegNum();

// ldr targetReg, [baseReg, #offset]
GetEmitter()->emitIns_R_R_I(ins, emitActualTypeSize(intrin.baseType), targetReg, baseReg,
offset);
}
}
else
if ((varTypeIsFloating(intrin.baseType) && (targetReg == op1Reg) && (ival == 0)))
{
assert(intrin.op2->IsCnsIntOrI());
ssize_t indexValue = intrin.op2->AsIntCon()->IconValue();

// no-op if vector is float/double, targetReg == op1Reg and fetching for 0th index.
if ((varTypeIsFloating(intrin.baseType) && (targetReg == op1Reg) && (indexValue == 0)))
{
break;
}

GetEmitter()->emitIns_R_R_I(ins, emitTypeSize(intrin.baseType), targetReg, op1Reg, indexValue,
INS_OPTS_NONE);
break;
}

GetEmitter()->emitIns_R_R_I(ins, emitTypeSize(intrin.baseType), targetReg, op1Reg, ival, INS_OPTS_NONE);
break;
}

Expand Down
143 changes: 126 additions & 17 deletions src/coreclr/jit/lowerarmarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1254,6 +1254,128 @@ GenTree* Lowering::LowerHWIntrinsic(GenTreeHWIntrinsic* node)
return LowerHWIntrinsicDot(node);
}

// Lowering for Vector64/Vector128 GetElement.
//
// When op1 is a containable memory operand, or when the index (op2) is not a
// constant, the intrinsic node is replaced by a scalar indirection of type
// simdBaseType whose address is (base + index * sizeof(element)). Otherwise
// (op1 is not containable memory and the index is constant) the node is left
// untouched and control falls through to the trailing assert/break.
case NI_Vector64_GetElement:
case NI_Vector128_GetElement:
{
GenTree* op1 = node->Op(1);
GenTree* op2 = node->Op(2);

// True when op1 can be read directly from memory at this use.
bool isContainableMemory = IsContainableMemoryOp(op1) && IsSafeToContainMem(node, op1);

if (isContainableMemory || !op2->OperIsConst())
{
unsigned simdSize = node->GetSimdSize();
var_types simdBaseType = node->GetSimdBaseType();
var_types simdType = Compiler::getSIMDTypeForSize(simdSize);

// We're either already loading from memory or we need to since
// we don't know what actual index is going to be retrieved.

// lclNum stays BAD_VAR_NUM when the base address comes from an existing
// indirection rather than from a local.
unsigned lclNum = BAD_VAR_NUM;
unsigned lclOffs = 0;

if (!isContainableMemory)
{
// We aren't already in memory, so we need to spill there

// NOTE(review): the return value is ignored and the temp number is read from
// lvaSIMDInitTempVarNum afterwards; presumably this call sets up that temp
// for simdType - confirm against Compiler::getSIMDInitTempVarNum.
comp->getSIMDInitTempVarNum(simdType);
lclNum = comp->lvaSIMDInitTempVarNum;

// Store the vector value (op1) into the temp ahead of the intrinsic node.
GenTree* storeLclVar = comp->gtNewStoreLclVarNode(lclNum, op1);
BlockRange().InsertBefore(node, storeLclVar);
LowerNode(storeLclVar);
}
else if (op1->IsLocal())
{
// We're an existing local that is loaded from memory
GenTreeLclVarCommon* lclVar = op1->AsLclVarCommon();

lclNum = lclVar->GetLclNum();
lclOffs = lclVar->GetLclOffs();

// Only the local's number and offset are needed from here on; the local
// node itself is dropped and replaced by an address node below.
BlockRange().Remove(op1);
}

if (lclNum != BAD_VAR_NUM)
{
// We need to get the address of the local
op1 = comp->gtNewLclAddrNode(lclNum, lclOffs, TYP_BYREF);
BlockRange().InsertBefore(node, op1);
LowerNode(op1);
}
else
{
assert(op1->isIndir());

// We need to get the underlying address
// The indirection node is removed and its address operand becomes the base.
GenTree* addr = op1->AsIndir()->Addr();
BlockRange().Remove(op1);
op1 = addr;
}

// From here on op1 is the base address and offset is turned into a byte offset.
GenTree* offset = op2;
unsigned baseTypeSize = genTypeSize(simdBaseType);

if (offset->OperIsConst())
{
// We have a constant index, so scale it up directly
// The existing constant node is updated in place: element index -> byte offset.
GenTreeIntConCommon* index = offset->AsIntCon();
index->SetIconValue(index->IconValue() * baseTypeSize);
}
else
{
// We have a non-constant index, so scale it up via mul but
// don't lower the GT_MUL node since the indir will try to
// create an addressing mode and will do folding itself. We
// do, however, skip the multiply for scale == 1

if (baseTypeSize != 1)
{
GenTreeIntConCommon* scale = comp->gtNewIconNode(baseTypeSize);
BlockRange().InsertBefore(node, scale);

offset = comp->gtNewOperNode(GT_MUL, offset->TypeGet(), offset, scale);
BlockRange().InsertBefore(node, offset);
}
}

// Add the offset, don't lower the GT_ADD node since the indir will
// try to create an addressing mode and will do folding itself. We
// do, however, skip the add for offset == 0
GenTree* addr = op1;

if (!offset->IsIntegralConst(0))
{
addr = comp->gtNewOperNode(GT_ADD, addr->TypeGet(), addr, offset);
BlockRange().InsertBefore(node, addr);
}
else
{
// The zero offset is not consumed by any node, so remove it from the range.
BlockRange().Remove(offset);
}

// Finally we can indirect the memory address to get the actual value
GenTreeIndir* indir = comp->gtNewIndir(simdBaseType, addr);
BlockRange().InsertBefore(node, indir);

// Redirect the user of the intrinsic node to the new indir; if nothing uses
// the value, mark the indir's value as unused instead.
LIR::Use use;
if (BlockRange().TryGetUse(node, &use))
{
use.ReplaceWith(indir);
}
else
{
indir->SetUnusedValue();
}

// The intrinsic node is removed; lowering continues with the new indir.
BlockRange().Remove(node);
return LowerNode(indir);
}

// Not rewritten: op1 is not containable memory and the index is constant, so
// the node is kept as a GetElement intrinsic.
assert(op2->OperIsConst());
break;
}

case NI_Vector64_op_Equality:
case NI_Vector128_op_Equality:
{
Expand Down Expand Up @@ -3318,24 +3440,11 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node)
case NI_Vector64_GetElement:
case NI_Vector128_GetElement:
{
assert(varTypeIsIntegral(intrin.op2));
assert(!IsContainableMemoryOp(intrin.op1) || !IsSafeToContainMem(node, intrin.op1));
assert(intrin.op2->OperIsConst());

if (intrin.op2->IsCnsIntOrI())
{
MakeSrcContained(node, intrin.op2);
}

// TODO: Codegen isn't currently handling this correctly
//
// if (IsContainableMemoryOp(intrin.op1) && IsSafeToContainMem(node, intrin.op1))
// {
// MakeSrcContained(node, intrin.op1);
//
// if (intrin.op1->OperIs(GT_IND))
// {
// intrin.op1->AsIndir()->Addr()->ClearContained();
// }
// }
// Loading a constant index from register
MakeSrcContained(node, intrin.op2);
break;
}

Expand Down
23 changes: 1 addition & 22 deletions src/coreclr/jit/lsraarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1944,7 +1944,6 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou
srcCount += BuildDelayFreeUses(intrin.op3, embOp2Node->Op(1));
}
}

else if (intrin.op2 != nullptr)
{
// RMW intrinsic operands don't have to be delayFree when they can be assigned the same register as op1Reg
Expand All @@ -1955,28 +1954,8 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou
bool forceOp2DelayFree = false;
SingleTypeRegSet lowVectorCandidates = RBM_NONE;
size_t lowVectorOperandNum = 0;
if ((intrin.id == NI_Vector64_GetElement) || (intrin.id == NI_Vector128_GetElement))
{
if (!intrin.op2->IsCnsIntOrI() && (!intrin.op1->isContained() || intrin.op1->OperIsLocal()))
{
// If the index is not a constant and the object is not contained or is a local
// we will need a general purpose register to calculate the address
// internal register must not clobber input index
// TODO-Cleanup: An internal register will never clobber a source; this code actually
// ensures that the index (op2) doesn't interfere with the target.
buildInternalIntRegisterDefForNode(intrinsicTree);
forceOp2DelayFree = true;
}

if (!intrin.op2->IsCnsIntOrI() && !intrin.op1->isContained())
{
// If the index is not a constant and op1 is in a register,
// we will use the SIMD temp location to store the vector.
var_types requiredSimdTempType = (intrin.id == NI_Vector64_GetElement) ? TYP_SIMD8 : TYP_SIMD16;
compiler->getSIMDInitTempVarNum(requiredSimdTempType);
}
}
else if (HWIntrinsicInfo::IsLowVectorOperation(intrin.id))
if (HWIntrinsicInfo::IsLowVectorOperation(intrin.id))
{
getLowVectorOperandAndCandidates(intrin, &lowVectorOperandNum, &lowVectorCandidates);
}
Expand Down

0 comments on commit b9673cb

Please sign in to comment.