Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIRV] Stop using Register to represent target specific virtual registers. #129362

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#include "SPIRV.h"
#include "SPIRVBaseInfo.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/CodeGen/Register.h"
#include "llvm/MC/MCAsmInfo.h"
#include "llvm/MC/MCExpr.h"
#include "llvm/MC/MCInst.h"
Expand Down Expand Up @@ -97,7 +96,7 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI,
}

void SPIRVInstPrinter::recordOpExtInstImport(const MCInst *MI) {
Register Reg = MI->getOperand(0).getReg();
MCRegister Reg = MI->getOperand(0).getReg();
auto Name = getSPIRVStringOperand(*MI, 1);
auto Set = getExtInstSetFromString(Name);
ExtInstSetIDs.insert({Reg, Set});
Expand Down Expand Up @@ -335,7 +334,7 @@ void SPIRVInstPrinter::printOperand(const MCInst *MI, unsigned OpNo,
if (OpNo < MI->getNumOperands()) {
const MCOperand &Op = MI->getOperand(OpNo);
if (Op.isReg())
O << '%' << (Register(Op.getReg()).virtRegIndex() + 1);
O << '%' << (getIDFromRegister(Op.getReg().id()) + 1);
else if (Op.isImm())
O << formatImm((int64_t)Op.getImm());
else if (Op.isDFPImm())
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
//===----------------------------------------------------------------------===//

#include "MCTargetDesc/SPIRVMCTargetDesc.h"
#include "llvm/CodeGen/Register.h"
#include "llvm/MC/MCCodeEmitter.h"
#include "llvm/MC/MCFixup.h"
#include "llvm/MC/MCInst.h"
Expand Down Expand Up @@ -77,7 +76,8 @@ static void emitOperand(const MCOperand &Op, SmallVectorImpl<char> &CB) {
if (Op.isReg()) {
// Emit the id index starting at 1 (0 is an invalid index).
support::endian::write<uint32_t>(
CB, Register(Op.getReg()).virtRegIndex() + 1, llvm::endianness::little);
CB, SPIRV::getIDFromRegister(Op.getReg().id()) + 1,
llvm::endianness::little);
} else if (Op.isImm()) {
support::endian::write(CB, static_cast<uint32_t>(Op.getImm()),
llvm::endianness::little);
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCTargetDesc.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define LLVM_LIB_TARGET_SPIRV_MCTARGETDESC_SPIRVMCTARGETDESC_H

#include "llvm/Support/DataTypes.h"
#include <cassert>
#include <memory>

namespace llvm {
Expand Down Expand Up @@ -50,4 +51,11 @@ std::unique_ptr<MCObjectTargetWriter> createSPIRVObjectTargetWriter();
#define GET_SUBTARGETINFO_ENUM
#include "SPIRVGenSubtargetInfo.inc"

namespace llvm::SPIRV {
inline unsigned getIDFromRegister(unsigned Reg) {
assert(Reg & (1U << 31));
return Reg & ~(1U << 31);
}
} // namespace llvm::SPIRV

#endif // LLVM_LIB_TARGET_SPIRV_MCTARGETDESC_SPIRVMCTARGETDESC_H
22 changes: 11 additions & 11 deletions llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ class SPIRVAsmPrinter : public AsmPrinter {
void outputOpMemoryModel();
void outputOpFunctionEnd();
void outputExtFuncDecls();
void outputExecutionModeFromMDNode(Register Reg, MDNode *Node,
void outputExecutionModeFromMDNode(MCRegister Reg, MDNode *Node,
SPIRV::ExecutionMode::ExecutionMode EM,
unsigned ExpectMDOps, int64_t DefVal);
void outputExecutionModeFromNumthreadsAttribute(
const Register &Reg, const Attribute &Attr,
const MCRegister &Reg, const Attribute &Attr,
SPIRV::ExecutionMode::ExecutionMode EM);
void outputExecutionMode(const Module &M);
void outputAnnotations(const Module &M);
Expand Down Expand Up @@ -316,7 +316,7 @@ void SPIRVAsmPrinter::outputDebugSourceAndStrings(const Module &M) {
void SPIRVAsmPrinter::outputOpExtInstImports(const Module &M) {
for (auto &CU : MAI->ExtInstSetMap) {
unsigned Set = CU.first;
Register Reg = CU.second;
MCRegister Reg = CU.second;
MCInst Inst;
Inst.setOpcode(SPIRV::OpExtInstImport);
Inst.addOperand(MCOperand::createReg(Reg));
Expand All @@ -341,7 +341,7 @@ void SPIRVAsmPrinter::outputOpMemoryModel() {
// the interface of this entry point.
void SPIRVAsmPrinter::outputEntryPoints() {
// Find all OpVariable IDs with required StorageClass.
DenseSet<Register> InterfaceIDs;
DenseSet<MCRegister> InterfaceIDs;
for (const MachineInstr *MI : MAI->GlobalVarList) {
assert(MI->getOpcode() == SPIRV::OpVariable);
auto SC = static_cast<SPIRV::StorageClass::StorageClass>(
Expand All @@ -353,7 +353,7 @@ void SPIRVAsmPrinter::outputEntryPoints() {
if (ST->isAtLeastSPIRVVer(VersionTuple(1, 4)) ||
SC == SPIRV::StorageClass::Input || SC == SPIRV::StorageClass::Output) {
const MachineFunction *MF = MI->getMF();
Register Reg = MAI->getRegisterAlias(MF, MI->getOperand(0).getReg());
MCRegister Reg = MAI->getRegisterAlias(MF, MI->getOperand(0).getReg());
InterfaceIDs.insert(Reg);
}
}
Expand All @@ -363,7 +363,7 @@ void SPIRVAsmPrinter::outputEntryPoints() {
SPIRVMCInstLower MCInstLowering;
MCInst TmpInst;
MCInstLowering.lower(MI, TmpInst, MAI);
for (Register Reg : InterfaceIDs) {
for (MCRegister Reg : InterfaceIDs) {
assert(Reg.isValid());
TmpInst.addOperand(MCOperand::createReg(Reg));
}
Expand Down Expand Up @@ -444,7 +444,7 @@ static void addOpsFromMDNode(MDNode *MDN, MCInst &Inst,
if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) {
Inst.addOperand(MCOperand::createImm(Const->getZExtValue()));
} else if (auto *CE = dyn_cast<Function>(C)) {
Register FuncReg = MAI->getFuncReg(CE);
MCRegister FuncReg = MAI->getFuncReg(CE);
assert(FuncReg.isValid());
Inst.addOperand(MCOperand::createReg(FuncReg));
}
Expand All @@ -453,7 +453,7 @@ static void addOpsFromMDNode(MDNode *MDN, MCInst &Inst,
}

void SPIRVAsmPrinter::outputExecutionModeFromMDNode(
Register Reg, MDNode *Node, SPIRV::ExecutionMode::ExecutionMode EM,
MCRegister Reg, MDNode *Node, SPIRV::ExecutionMode::ExecutionMode EM,
unsigned ExpectMDOps, int64_t DefVal) {
MCInst Inst;
Inst.setOpcode(SPIRV::OpExecutionMode);
Expand All @@ -470,7 +470,7 @@ void SPIRVAsmPrinter::outputExecutionModeFromMDNode(
}

void SPIRVAsmPrinter::outputExecutionModeFromNumthreadsAttribute(
const Register &Reg, const Attribute &Attr,
const MCRegister &Reg, const Attribute &Attr,
SPIRV::ExecutionMode::ExecutionMode EM) {
assert(Attr.isValid() && "Function called with an invalid attribute.");

Expand Down Expand Up @@ -508,7 +508,7 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
// <Entry Point> operands of OpExecutionMode
if (F.isDeclaration() || !isEntryPoint(F))
continue;
Register FReg = MAI->getFuncReg(&F);
MCRegister FReg = MAI->getFuncReg(&F);
assert(FReg.isValid());
if (MDNode *Node = F.getMetadata("reqd_work_group_size"))
outputExecutionModeFromMDNode(FReg, Node, SPIRV::ExecutionMode::LocalSize,
Expand Down Expand Up @@ -560,7 +560,7 @@ void SPIRVAsmPrinter::outputAnnotations(const Module &M) {
if (!isa<Function>(AnnotatedVar))
report_fatal_error("Unsupported value in llvm.global.annotations");
Function *Func = cast<Function>(AnnotatedVar);
Register Reg = MAI->getFuncReg(Func);
MCRegister Reg = MAI->getFuncReg(Func);
if (!Reg.isValid()) {
std::string DiagMsg;
raw_string_ostream OS(DiagMsg);
Expand Down
9 changes: 5 additions & 4 deletions llvm/lib/Target/SPIRV/SPIRVMCInstLower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ void SPIRVMCInstLower::lower(const MachineInstr *MI, MCInst &OutMI,
default:
llvm_unreachable("unknown operand type");
case MachineOperand::MO_GlobalAddress: {
Register FuncReg = MAI->getFuncReg(dyn_cast<Function>(MO.getGlobal()));
MCRegister FuncReg = MAI->getFuncReg(dyn_cast<Function>(MO.getGlobal()));
if (!FuncReg.isValid()) {
std::string DiagMsg;
raw_string_ostream OS(DiagMsg);
Expand All @@ -49,13 +49,14 @@ void SPIRVMCInstLower::lower(const MachineInstr *MI, MCInst &OutMI,
MCOp = MCOperand::createReg(MAI->getOrCreateMBBRegister(*MO.getMBB()));
break;
case MachineOperand::MO_Register: {
Register NewReg = MAI->getRegisterAlias(MF, MO.getReg());
MCOp = MCOperand::createReg(NewReg.isValid() ? NewReg : MO.getReg());
MCRegister NewReg = MAI->getRegisterAlias(MF, MO.getReg());
MCOp = MCOperand::createReg(NewReg.isValid() ? NewReg
: MO.getReg().asMCReg());
break;
}
case MachineOperand::MO_Immediate:
if (MI->getOpcode() == SPIRV::OpExtInst && i == 2) {
Register Reg = MAI->getExtInstSetReg(MO.getImm());
MCRegister Reg = MAI->getExtInstSetReg(MO.getImm());
MCOp = MCOperand::createReg(Reg);
} else {
MCOp = MCOperand::createImm(MO.getImm());
Expand Down
33 changes: 17 additions & 16 deletions llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,7 @@ void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
if (ST->isOpenCLEnv()) {
// TODO: check if it's required by default.
MAI.ExtInstSetMap[static_cast<unsigned>(
SPIRV::InstructionSet::OpenCL_std)] =
Register::index2VirtReg(MAI.getNextID());
SPIRV::InstructionSet::OpenCL_std)] = MAI.getNextIDRegister();
}
}

Expand Down Expand Up @@ -306,7 +305,8 @@ void SPIRVModuleAnalysis::visitFunPtrUse(
} while (OpDefMI && (OpDefMI->getOpcode() == SPIRV::OpFunction ||
OpDefMI->getOpcode() == SPIRV::OpFunctionParameter));
// associate the function pointer with the newly assigned global number
Register GlobalFunDefReg = MAI.getRegisterAlias(FunDefMF, OpFunDef->getReg());
MCRegister GlobalFunDefReg =
MAI.getRegisterAlias(FunDefMF, OpFunDef->getReg());
assert(GlobalFunDefReg.isValid() &&
"Function definition must refer to a global register");
MAI.setRegisterAlias(MF, OpReg, GlobalFunDefReg);
Expand Down Expand Up @@ -353,10 +353,10 @@ void SPIRVModuleAnalysis::visitDecl(
"No unique definition is found for the virtual register");
}

Register GReg;
MCRegister GReg;
bool IsFunDef = false;
if (TII->isSpecConstantInstr(MI)) {
GReg = Register::index2VirtReg(MAI.getNextID());
GReg = MAI.getNextIDRegister();
MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI);
} else if (Opcode == SPIRV::OpFunction ||
Opcode == SPIRV::OpFunctionParameter) {
Expand All @@ -366,7 +366,7 @@ void SPIRVModuleAnalysis::visitDecl(
const MachineInstr *NextInstr = MI.getNextNode();
while (NextInstr &&
NextInstr->getOpcode() == SPIRV::OpTypeStructContinuedINTEL) {
Register Tmp = handleTypeDeclOrConstant(*NextInstr, SignatureToGReg);
MCRegister Tmp = handleTypeDeclOrConstant(*NextInstr, SignatureToGReg);
MAI.setRegisterAlias(MF, NextInstr->getOperand(0).getReg(), Tmp);
MAI.setSkipEmission(NextInstr);
NextInstr = NextInstr->getNextNode();
Expand All @@ -389,7 +389,7 @@ void SPIRVModuleAnalysis::visitDecl(
MAI.setSkipEmission(&MI);
}

Register SPIRVModuleAnalysis::handleFunctionOrParameter(
MCRegister SPIRVModuleAnalysis::handleFunctionOrParameter(
const MachineFunction *MF, const MachineInstr &MI,
std::map<const Value *, unsigned> &GlobalToGReg, bool &IsFunDef) {
const Value *GObj = GR->getGlobalObject(MF, MI.getOperand(0).getReg());
Expand All @@ -402,27 +402,27 @@ Register SPIRVModuleAnalysis::handleFunctionOrParameter(
auto It = GlobalToGReg.find(GObj);
if (It != GlobalToGReg.end())
return It->second;
Register GReg = Register::index2VirtReg(MAI.getNextID());
MCRegister GReg = MAI.getNextIDRegister();
GlobalToGReg[GObj] = GReg;
if (!IsFunDef)
MAI.MS[SPIRV::MB_ExtFuncDecls].push_back(&MI);
return GReg;
}

Register
MCRegister
SPIRVModuleAnalysis::handleTypeDeclOrConstant(const MachineInstr &MI,
InstrGRegsMap &SignatureToGReg) {
InstrSignature MISign = instrToSignature(MI, MAI, false);
auto It = SignatureToGReg.find(MISign);
if (It != SignatureToGReg.end())
return It->second;
Register GReg = Register::index2VirtReg(MAI.getNextID());
MCRegister GReg = MAI.getNextIDRegister();
SignatureToGReg[MISign] = GReg;
MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI);
return GReg;
}

Register SPIRVModuleAnalysis::handleVariable(
MCRegister SPIRVModuleAnalysis::handleVariable(
const MachineFunction *MF, const MachineInstr &MI,
std::map<const Value *, unsigned> &GlobalToGReg) {
MAI.GlobalVarList.push_back(&MI);
Expand All @@ -431,7 +431,7 @@ Register SPIRVModuleAnalysis::handleVariable(
auto It = GlobalToGReg.find(GObj);
if (It != GlobalToGReg.end())
return It->second;
Register GReg = Register::index2VirtReg(MAI.getNextID());
MCRegister GReg = MAI.getNextIDRegister();
GlobalToGReg[GObj] = GReg;
MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI);
return GReg;
Expand Down Expand Up @@ -507,7 +507,7 @@ void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
} else if (MI.getOpcode() == SPIRV::OpFunction) {
// Record all internal OpFunction declarations.
Register Reg = MI.defs().begin()->getReg();
Register GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
MCRegister GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
assert(GlobalReg.isValid());
MAI.FuncMap[F] = GlobalReg;
}
Expand Down Expand Up @@ -599,14 +599,14 @@ void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
Register Reg = Op.getReg();
if (MAI.hasRegisterAlias(MF, Reg))
continue;
Register NewReg = Register::index2VirtReg(MAI.getNextID());
MCRegister NewReg = MAI.getNextIDRegister();
MAI.setRegisterAlias(MF, Reg, NewReg);
}
if (MI.getOpcode() != SPIRV::OpExtInst)
continue;
auto Set = MI.getOperand(2).getImm();
if (!MAI.ExtInstSetMap.contains(Set))
MAI.ExtInstSetMap[Set] = Register::index2VirtReg(MAI.getNextID());
MAI.ExtInstSetMap[Set] = MAI.getNextIDRegister();
}
}
}
Expand Down Expand Up @@ -1938,7 +1938,7 @@ static void addMBBNames(const Module &M, const SPIRVInstrInfo &TII,
Register Reg = MRI.createGenericVirtualRegister(LLT::scalar(64));
MRI.setRegClass(Reg, &SPIRV::IDRegClass);
buildOpName(Reg, MBB.getName(), *std::prev(MBB.end()), TII);
Register GlobalReg = MAI.getOrCreateMBBRegister(MBB);
MCRegister GlobalReg = MAI.getOrCreateMBBRegister(MBB);
MAI.setRegisterAlias(MF, Reg, GlobalReg);
}
}
Expand Down Expand Up @@ -1992,6 +1992,7 @@ bool SPIRVModuleAnalysis::runOnModule(Module &M) {

// Process type/const/global var/func decl instructions, number their
// destination registers from 0 to N, collect Extensions and Capabilities.
collectReqs(M, MAI, MMI, *ST);
collectDeclarations(M);

// Number rest of registers from N+1 onwards.
Expand Down
Loading
Loading