Skip to content

Commit

Permalink
feat: register state argument
Browse files Browse the repository at this point in the history
  • Loading branch information
ajlekcahdp4 committed Nov 17, 2024
1 parent 1a5011d commit 00d27be
Show file tree
Hide file tree
Showing 11 changed files with 579 additions and 156 deletions.
6 changes: 4 additions & 2 deletions include/bleach/lifter/block-ir-builder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class mbb2bb final : private DenseMap<const MachineBasicBlock *, BasicBlock *> {
}
};

void fill_ir_for_bb(MachineBasicBlock &mbb, Function &func, reg2vals &rmap,
void fill_ir_for_bb(MachineBasicBlock &mbb, reg2vals &rmap,
const instr_impl &instrs, const LLVMTargetMachine &tm,
const target &tgt, const mbb2bb &m2b);
struct basic_block {
Expand All @@ -67,6 +67,8 @@ struct basic_block {

void copy_instructions(const MachineBasicBlock &src, MachineBasicBlock &dst);

basic_block clone_basic_block(MachineBasicBlock &src);
basic_block clone_basic_block(MachineBasicBlock &src, MachineFunction &dst);

StructType &create_state_type(LLVMContext &ctx);

} // namespace bleach::lifter
143 changes: 82 additions & 61 deletions lib/lifter/block-ir-builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,8 @@ void copy_instructions(const MachineBasicBlock &src, MachineBasicBlock &dst) {
}
}

basic_block clone_basic_block(MachineBasicBlock &src) {
auto *mf = src.getParent();
assert(mf);
auto &func = mf->getFunction();
basic_block clone_basic_block(MachineBasicBlock &src, MachineFunction &dst) {
auto &func = dst.getFunction();
auto *new_block = BasicBlock::Create(func.getContext(), "", &func);
assert(new_block);
// Dummy basic block is necessary due to bug in
Expand All @@ -49,49 +47,40 @@ basic_block clone_basic_block(MachineBasicBlock &src) {
builder.SetInsertPoint(new_block);
builder.CreateBr(dummy_bb);

auto *new_mblock = mf->CreateMachineBasicBlock(new_block);
auto *new_mblock = dst.CreateMachineBasicBlock(new_block);
assert(new_mblock);
mf->push_back(new_mblock);
dst.push_back(new_mblock);
copy_instructions(src, *new_mblock);

for (auto it = src.succ_begin(); it != src.succ_end(); ++it)
new_mblock->copySuccessor(&src, it);
// remove dummy instruction
assert(!new_block->empty());
new_block->begin()->eraseFromParent();
dummy_bb->eraseFromParent();
return {new_mblock, new_block};
}

void create_basic_blocks_for_mfunc(MachineFunction &mfunc, mbb2bb &m2b) {
std::vector<MachineBasicBlock *> mblocks_to_erase;
std::vector<BasicBlock *> blocks_to_erase;
transform(mfunc, std::back_inserter(mblocks_to_erase),
[](auto &mbb) { return &mbb; });
auto &func = mfunc.getFunction();
transform(func, std::back_inserter(blocks_to_erase),
[](auto &bb) { return &bb; });

void create_basic_blocks_for_mfunc(MachineFunction &src, MachineFunction &dst,
mbb2bb &m2b) {
std::unordered_map<MachineBasicBlock *, MachineBasicBlock *> block_map;
// One cannot simply loop over mfunc as we insert new blocks into it
for (auto *mbb : mblocks_to_erase) {
auto [new_mblock, new_block] = clone_basic_block(*mbb);
for (auto &mbb : src) {
auto [new_mblock, new_block] = clone_basic_block(mbb, dst);
m2b.insert({new_mblock, new_block});
block_map.insert({mbb, new_mblock});
block_map.insert({&mbb, new_mblock});
}
for (auto [old_block, new_block] : block_map) {
for (auto &mbb : mfunc) {
for (auto &mbb : dst) {
if (&mbb == old_block)
continue;
if (!is_contained(mbb.successors(), old_block))
continue;
mbb.ReplaceUsesOfBlockWith(old_block, new_block);
}
}
for (auto *mbb : mblocks_to_erase) {
m2b.erase(mbb);
mbb->eraseFromParent();
}
for (auto *bb : blocks_to_erase)
bb->eraseFromParent();
for (auto &mbb : src)
m2b.erase(&mbb);
}

void fill_module_with_instrs(Module &m, const instr_impl &instrs) {
Expand All @@ -105,66 +94,98 @@ void fill_module_with_instrs(Module &m, const instr_impl &instrs) {
Linker::linkModules(m, std::move(first));
}

auto *generate_function_object(Module &m, MachineFunction &mf, reg2vals &rmap) {
void materialize_registers(MachineFunction &mf, Function &func, reg2vals &rmap,
const LLVMTargetMachine &target_machine,
StructType &state) {
if (mf.empty())
return;
auto &ctx = func.getContext();
constexpr auto gpr_array_idx = 0u;
Value *state_arg = func.getArg(gpr_array_idx);
auto *array_type = *state.element_begin();
assert(!func.empty());
auto &first_block = func.front();
auto builder = IRBuilder(ctx);
builder.SetInsertPoint(&first_block);
// pointer to a single state
// GPR array is the first field
auto *const_zero = ConstantInt::get(ctx, APInt(64, 0));
auto *array_ptr = builder.CreateGEP(
&state, state_arg, ArrayRef<Value *>{const_zero, const_zero}, "GPRS");
auto *reg_info = target_machine.getMCRegisterInfo();
assert(reg_info);
// TODO: get GPR reg class index propperly
constexpr auto gpr_class_idx = 3u;
auto &gpr_class = reg_info->getRegClass(gpr_class_idx);
auto sorted_regs = std::set<unsigned>();
ranges::copy(gpr_class, std::inserter(sorted_regs, sorted_regs.end()));
for (auto [idx, reg] : ranges::views::enumerate(sorted_regs)) {
auto *array_idx = ConstantInt::get(ctx, APInt(64, idx));
auto *reg_addr = builder.CreateInBoundsGEP(
array_type, array_ptr, ArrayRef<Value *>{const_zero, array_idx});
auto *reg_value = builder.CreateLoad(Type::getIntNTy(ctx, 64), reg_addr,
reg_info->getName(reg));
rmap.try_emplace(reg, reg_value);
}
}

auto *generate_function_object(Module &m, MachineFunction &mf, reg2vals &rmap,
MachineModuleInfo &mmi, StructType &state) {
auto *ret_type = Type::getIntNTy(m.getContext(), 64);
auto &reg_info = mf.getRegInfo();
std::vector<Type *> args;
for ([[maybe_unused]] auto &&_ : reg_info.liveins())
args.push_back(Type::getIntNTy(m.getContext(), 64));
auto *func_type = FunctionType::get(ret_type, args, /* is var arg */ false);
auto *func_type = FunctionType::get(
ret_type, ArrayRef<Type *>{PointerType::getUnqual(m.getContext())},
/* is var arg */ false);
auto *func =
Function::Create(func_type, Function::ExternalLinkage, mf.getName(), m);
for (auto &&[livein, arg] : zip(reg_info.liveins(), func->args()))
rmap.try_emplace(livein.first, &arg);
return func;
}

auto generate_function(Module &m, MachineFunction &mf, const instr_impl &instrs,
const LLVMTargetMachine &target_machine,
const target &tgt, const mbb2bb &m2b) {
MachineModuleInfo &mmi, const target &tgt,
StructType &state) {
reg2vals rmap;
auto *func = generate_function_object(m, mf, rmap);
auto &f = mf.getFunction();
for (auto &bb : make_early_inc_range(f)) {
auto known_blocks = make_second_range(m2b);
bb.erase(bb.begin(), bb.end());
if (!is_contained(known_blocks, &bb))
bb.eraseFromParent();
}
for (auto &mbb : mf)
fill_ir_for_bb(mbb, *func, rmap, instrs, target_machine, tgt, m2b);
mbb2bb m2b;
auto *func = generate_function_object(m, mf, rmap, mmi, state);
auto &dst = mmi.getOrCreateMachineFunction(*func);
create_basic_blocks_for_mfunc(mf, dst, m2b);
materialize_registers(mf, *func, rmap, mmi.getTarget(), state);
for (auto &mbb : dst)
fill_ir_for_bb(mbb, rmap, instrs, mmi.getTarget(), tgt, m2b);
}

std::string get_instruction_name(const MachineInstr &minst,
const MCInstrInfo &instr_info) {
return instr_info.getName(minst.getOpcode()).str();
}

void create_basic_blocks(Module &m, MachineModuleInfo &mmi, mbb2bb &m2b,
const std::set<std::string> &target_functions) {
for (auto &f : m | ranges::views::filter([&target_functions](auto &f) {
return target_functions.contains(f.getName().str());
})) {
auto &mf = mmi.getOrCreateMachineFunction(f);
create_basic_blocks_for_mfunc(mf, m2b);
}
StructType &create_state_type(LLVMContext &ctx) {
// FIXME: register number and width should not be hardcoded. They can be
// determined from input MIR.
// GPR registers
auto *array_type =
ArrayType::get(Type::getIntNTy(ctx, 64), /*register number*/ 32);
auto *struct_type =
StructType::create(ctx, ArrayRef<Type *>{array_type}, "register_state");
assert(struct_type);
return *struct_type;
}

PreservedAnalyses block_ir_builder_pass::run(Module &m,
ModuleAnalysisManager &mam) {
std::set<std::string> target_functions;
std::set<Function *> target_functions;
transform(m, std::inserter(target_functions, target_functions.end()),
[](auto &f) { return f.getName().str(); });
[](auto &f) { return &f; });
fill_module_with_instrs(m, instrs);
auto m2b = mbb2bb{};
auto &mmi = mam.getResult<MachineModuleAnalysis>(m).getMMI();
create_basic_blocks(m, mmi, m2b, target_functions);
for (auto &f : m | ranges::views::filter([&target_functions](auto &f) {
return target_functions.contains(f.getName().str());
})) {
auto &mf = mmi.getOrCreateMachineFunction(f);
generate_function(m, mf, instrs, mmi.getTarget(), tgt, m2b);
auto &state = create_state_type(m.getContext());
for (auto *f : target_functions) {
auto &mf = mmi.getOrCreateMachineFunction(*f);
generate_function(m, mf, instrs, mmi, tgt, state);
}
for (auto *f : target_functions)
f->eraseFromParent();
return PreservedAnalyses::none();
}

Expand Down Expand Up @@ -256,7 +277,7 @@ auto generate_instruction(const MachineInstr &minst, BasicBlock &bb,
return call;
}

void fill_ir_for_bb(MachineBasicBlock &mbb, Function &func, reg2vals &rmap,
void fill_ir_for_bb(MachineBasicBlock &mbb, reg2vals &rmap,
const instr_impl &instrs,
const LLVMTargetMachine &target_machine, const target &tgt,
const mbb2bb &m2b) {
Expand Down
1 change: 1 addition & 0 deletions lib/target/riscv/riscv-target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#define GET_INSTRINFO_ENUM
#include "RISCVGenInstrInfo.inc"
#include "RISCVGenRegisterInfo.inc"
#undef GET_INSTRINFO_ENUM

namespace bleach {
Expand Down
2 changes: 1 addition & 1 deletion test/lib/lifter/lifter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ TEST(lifter_basic, clone_basic_block) {
for (auto &f : *m) {
auto &mf = machine_module_info->getOrCreateMachineFunction(f);
for (auto &mbb : make_early_inc_range(mf)) {
auto [copy_mbb, copy_bb] = bleach::lifter::clone_basic_block(mbb);
auto [copy_mbb, copy_bb] = bleach::lifter::clone_basic_block(mbb, mf);
for (auto &&[inst, copy_inst] : zip(mbb, *copy_mbb)) {
std::string str1;
raw_string_ostream os1(str1);
Expand Down
80 changes: 76 additions & 4 deletions test/tools/llvm-bleach/ir-addsub.test
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,81 @@
# RUN: --instructions=%S/inputs/addsub.yaml | filecheck %s


# CHECK: define dso_local signext void @foo(i64 noundef signext %0, i64 noundef signext %1) local_unnamed_addr {
# CHECK-NEXT: %3 = call i64 @ADDW(i64 %1, i64 %0)
# CHECK-NEXT: %4 = call i64 @ADDW(i64 %3, i64 %3)
# CHECK-NEXT: %5 = call i64 @SUBW(i64 %4, i64 %3)
; ModuleID = '/home/alexander/projects/llvm-bleach/test/tools/llvm-bleach/inputs/foo.mir'
source_filename = "foo.c"
target datalayout = "e-m:e-p:64:64-i64:64-i128:128-n32:64-S128"
target triple = "riscv64-unknown-linux-gnu"

# CHECK: %register_state = type { [32 x i64] }

# CHECK: define i64 @foo.1(ptr %0) {
# CHECK-NEXT: %GPRS = getelementptr %register_state, ptr %0, i64 0, i64 0
# CHECK-NEXT: %2 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 0
# CHECK-NEXT: %X0 = load i64, ptr %2, align 8
# CHECK-NEXT: %3 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 1
# CHECK-NEXT: %X1 = load i64, ptr %3, align 8
# CHECK-NEXT: %4 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 2
# CHECK-NEXT: %X2 = load i64, ptr %4, align 8
# CHECK-NEXT: %5 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 3
# CHECK-NEXT: %X3 = load i64, ptr %5, align 8
# CHECK-NEXT: %6 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 4
# CHECK-NEXT: %X4 = load i64, ptr %6, align 8
# CHECK-NEXT: %7 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 5
# CHECK-NEXT: %X5 = load i64, ptr %7, align 8
# CHECK-NEXT: %8 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 6
# CHECK-NEXT: %X6 = load i64, ptr %8, align 8
# CHECK-NEXT: %9 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 7
# CHECK-NEXT: %X7 = load i64, ptr %9, align 8
# CHECK-NEXT: %10 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 8
# CHECK-NEXT: %X8 = load i64, ptr %10, align 8
# CHECK-NEXT: %11 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 9
# CHECK-NEXT: %X9 = load i64, ptr %11, align 8
# CHECK-NEXT: %12 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 10
# CHECK-NEXT: %X10 = load i64, ptr %12, align 8
# CHECK-NEXT: %13 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 11
# CHECK-NEXT: %X11 = load i64, ptr %13, align 8
# CHECK-NEXT: %14 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 12
# CHECK-NEXT: %X12 = load i64, ptr %14, align 8
# CHECK-NEXT: %15 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 13
# CHECK-NEXT: %X13 = load i64, ptr %15, align 8
# CHECK-NEXT: %16 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 14
# CHECK-NEXT: %X14 = load i64, ptr %16, align 8
# CHECK-NEXT: %17 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 15
# CHECK-NEXT: %X15 = load i64, ptr %17, align 8
# CHECK-NEXT: %18 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 16
# CHECK-NEXT: %X16 = load i64, ptr %18, align 8
# CHECK-NEXT: %19 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 17
# CHECK-NEXT: %X17 = load i64, ptr %19, align 8
# CHECK-NEXT: %20 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 18
# CHECK-NEXT: %X18 = load i64, ptr %20, align 8
# CHECK-NEXT: %21 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 19
# CHECK-NEXT: %X19 = load i64, ptr %21, align 8
# CHECK-NEXT: %22 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 20
# CHECK-NEXT: %X20 = load i64, ptr %22, align 8
# CHECK-NEXT: %23 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 21
# CHECK-NEXT: %X21 = load i64, ptr %23, align 8
# CHECK-NEXT: %24 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 22
# CHECK-NEXT: %X22 = load i64, ptr %24, align 8
# CHECK-NEXT: %25 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 23
# CHECK-NEXT: %X23 = load i64, ptr %25, align 8
# CHECK-NEXT: %26 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 24
# CHECK-NEXT: %X24 = load i64, ptr %26, align 8
# CHECK-NEXT: %27 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 25
# CHECK-NEXT: %X25 = load i64, ptr %27, align 8
# CHECK-NEXT: %28 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 26
# CHECK-NEXT: %X26 = load i64, ptr %28, align 8
# CHECK-NEXT: %29 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 27
# CHECK-NEXT: %X27 = load i64, ptr %29, align 8
# CHECK-NEXT: %30 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 28
# CHECK-NEXT: %X28 = load i64, ptr %30, align 8
# CHECK-NEXT: %31 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 29
# CHECK-NEXT: %X29 = load i64, ptr %31, align 8
# CHECK-NEXT: %32 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 30
# CHECK-NEXT: %X30 = load i64, ptr %32, align 8
# CHECK-NEXT: %33 = getelementptr inbounds [32 x i64], ptr %GPRS, i64 0, i64 31
# CHECK-NEXT: %X31 = load i64, ptr %33, align 8
# CHECK-NEXT: %34 = call i64 @ADDW(i64 %X11, i64 %X10)
# CHECK-NEXT: %35 = call i64 @ADDW(i64 %34, i64 %34)
# CHECK-NEXT: %36 = call i64 @SUBW(i64 %35, i64 %34)
# CHECK-NEXT: }

Loading

0 comments on commit 00d27be

Please sign in to comment.