Skip to content

Commit

Permalink
feat: correct instructions, incorrect order
Browse files Browse the repository at this point in the history
  • Loading branch information
ajlekcahdp4 committed Oct 19, 2024
1 parent ec8e873 commit ebb7f7b
Show file tree
Hide file tree
Showing 3 changed files with 827 additions and 54 deletions.
196 changes: 142 additions & 54 deletions lib/llvm-api-gen/llvm-api-gen.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "llvm-api-gen/llvm-api-gen.h"

#include <llvm/IR/DerivedTypes.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/Metadata.h>
#include <llvm/IR/PassManager.h>
#include <llvm/Passes/PassBuilder.h>
#include <llvm/Passes/PassPlugin.h>
Expand All @@ -15,9 +17,19 @@ using namespace std::string_literals;

constexpr auto builder = "builder";

std::string get_type_str(const Type &type, StringRef ctx_name) {
struct generation_context final {
std::unordered_set<const Type *> defined_types;
std::unordered_map<const Value *, unsigned> defined_values;
std::unordered_map<const PHINode *, const Instruction *> phis;
};

std::string get_type_str(const Type &type, StringRef ctx_name,
generation_context &ctx) {
std::string ret;
raw_string_ostream os(ret);
auto [It, Inserted] = ctx.defined_types.insert(&type);
if (!Inserted)
return "";
switch (type.getTypeID()) {
case Type::IntegerTyID: {
auto num = cast<IntegerType>(&type)->getBitWidth();
Expand All @@ -41,7 +53,7 @@ std::string get_type_str(const Type &type, StringRef ctx_name) {
auto *elem_type = array_type->getElementType();
assert(elem_type);
auto num = array_type->getNumElements();
os << get_type_str(*elem_type, ctx_name);
os << get_type_str(*elem_type, ctx_name, ctx);
os << formatv("auto *type_{0} = ArrayType::get(type_{1}, {2});\n", &type,
elem_type, num);
return ret;
Expand All @@ -54,37 +66,37 @@ std::string get_type_str(const Type &type, StringRef ctx_name) {
}
}

std::string get_ret_type(const Function &f) {
std::string get_ret_type(const Function &f, generation_context &ctx) {
auto *func_type = f.getFunctionType();
assert(func_type);
auto *ret_type = func_type->getReturnType();
assert(ret_type);
std::string tp;
raw_string_ostream os(tp);
os << get_type_str(*ret_type, "Ctx");
os << get_type_str(*ret_type, "Ctx", ctx);
os << formatv("auto *ret_type_{0} = type_{1};\n", &f, ret_type);
return tp;
}

std::string get_args_types(const Function &f) {
std::string get_args_types(const Function &f, generation_context &ctx) {
std::string create_args;
raw_string_ostream os(create_args);
os << formatv("std::vector<Type*> args_{0};\n", &f);
auto *func_type = f.getFunctionType();
assert(func_type);
for (auto *t : func_type->params()) {
assert(t);
os << get_type_str(*t, "Ctx");
os << get_type_str(*t, "Ctx", ctx);
os << formatv("args_{0}.push_back(type_{1});\n", &f, t);
}
return create_args;
}

std::string get_func_type(const Function &f) {
std::string get_func_type(const Function &f, generation_context &ctx) {
std::string create_func_type;
raw_string_ostream os(create_func_type);
os << get_ret_type(f);
os << get_args_types(f);
os << get_ret_type(f, ctx);
os << get_args_types(f, ctx);
os << formatv("auto *func_type_{0} = FunctionType::get(ret_type_{0}, "
"args_{0}, false);\n",
&f)
Expand Down Expand Up @@ -118,39 +130,41 @@ std::string get_instr_create_name(const Instruction &instr) {
CASE_INSTR(Br)
CASE_INSTR(ICmp)
CASE_INSTR(Select)
CASE_INSTR(GetElementPtr)
CASE_INSTR(Load)
CASE_INSTR(Store)
CASE_INSTR(Alloca)
CASE_INSTR(Call)
CASE_INSTR(Switch)
CASE_INSTR(PHI)
CASE_INSTR(Unreachable)
case Instruction::GetElementPtr:
return "GEP";
default:
return "UNSUPPORTED";
// llvm_unreachable("Unsupported instruction");
}
}
#undef CASE_INSTR

void create_operand(const Value &v, unsigned idx, raw_ostream &os) {
void create_operand(const Value &v, const Value &parent, unsigned idx,
raw_ostream &os, generation_context &ctx) {
if (auto *int_constant = dyn_cast<ConstantInt>(&v)) {
os << formatv(
"auto *op_{0}_{1} = ConstantInt::get(Ctx, APInt({2}, {3}));\n", idx, &v,
int_constant->getBitWidth(), int_constant->getZExtValue());
"auto *op_{0}_{1} = ConstantInt::get(Ctx, APInt({2}, {3}));\n", idx,
&parent, int_constant->getBitWidth(), int_constant->getZExtValue());
} else if (auto *bb = dyn_cast<BasicBlock>(&v)) {
os << formatv("auto *op_{0}_{1} = bb_{2};\n", idx, &v, bb);
os << formatv("auto *op_{0}_{1} = bb_{2};\n", idx, &parent, bb);
} else if (auto *func = dyn_cast<Function>(&v)) {
os << formatv("auto *op_{0}_{1} = func_{2};\n", idx, &v, func);
os << formatv("auto *op_{0}_{1} = func_{2};\n", idx, &parent, func);
} else if (auto *instr = dyn_cast<Instruction>(&v)) {
os << formatv("auto *op_{0}_{1} = instr_{2};\n", idx, &v, instr);
os << formatv("auto *op_{0}_{1} = instr_{2};\n", idx, &parent, instr);
} else if (auto *gv = dyn_cast<GlobalVariable>(&v)) {
auto *constant = gv->getInitializer();
if (auto *const_str = dyn_cast<ConstantDataSequential>(constant)) {
if (const_str->isString())
os << formatv("auto *op_{0}_{1} = ConstantDataArray::getString(Ctx, "
"\"{2}\", true);\n",
idx, &v, const_str->getAsString().drop_back());
idx, &parent, const_str->getAsString().drop_back());
else
os << "UNKNOWN\n";
}
Expand All @@ -161,35 +175,46 @@ void create_operand(const Value &v, unsigned idx, raw_ostream &os) {
os << "\n";
}
}

void create_phi_node(const PHINode &phi, raw_ostream &os) {
// Handle PHI nodes last
void create_phi_node(const PHINode &phi, raw_ostream &os,
generation_context &ctx) {
auto num_incoming = phi.getNumIncomingValues();
auto *type = phi.getType();
os << get_type_str(*type, "Ctx");
os << get_type_str(*type, "Ctx", ctx);
os << formatv("auto *phi_ty_{0} = type_{1};\n", &phi, type);
os << formatv("auto *phi_{0} = {1}.CreatePHI(phi_ty_{0}, {2}, \"\");\n", &phi,
builder, num_incoming);
for (auto &&[idx, pair] :
enumerate(zip(phi.incoming_values(), phi.blocks()))) {
auto &&[val, bb] = pair;
create_operand(*val.get(), idx, os);
create_operand(*val.get(), *static_cast<const Value *>(&phi), idx, os, ctx);
os << formatv("phi_{0}->addIncoming(op_{1}_{2}, bb_{3});\n", &phi, idx,
val.get(), &bb);
static_cast<const Value *>(&phi), &bb);
}
os << formatv("auto *instr_{0} = phi_{1};\n", dyn_cast<Instruction>(&phi),
&phi);
}

bool requires_special_handling(const Instruction &instr) {
return instr.getOpcode() == Instruction::Alloca;
switch (instr.getOpcode()) {
case Instruction::Alloca:
return true;
case Instruction::Call:
return true;
default:
return false;
}
}

void create_pre_args(const Instruction &instr, raw_ostream &os) {
void create_pre_args(const Instruction &instr, raw_ostream &os,
generation_context &ctx) {
switch (instr.getOpcode()) {
case Instruction::Alloca: {
auto *alloca = dyn_cast<AllocaInst>(&instr);
assert(alloca);
auto *allocated_type = alloca->getAllocatedType();
assert(allocated_type);
os << get_type_str(*allocated_type, "Ctx");
os << get_type_str(*allocated_type, "Ctx", ctx);
os << formatv("auto *add_arg_{0} = type_{1};\n", &instr, allocated_type);
break;
}
Expand All @@ -198,42 +223,98 @@ void create_pre_args(const Instruction &instr, raw_ostream &os) {
}
}

std::string create_instr(const Instruction &instr) {
void generate_call_create_instr(const Instruction &instr, raw_ostream &os,
generation_context &ctx,
std::optional<unsigned> num = std::nullopt) {
os << formatv("auto *instr_{0} = {1}.Create{2}(", &instr, builder,
get_instr_create_name(instr))
.str();
if (auto *_ = dyn_cast<AllocaInst>(&instr))
os << formatv("add_arg_{0}, ", &instr);

interleaveComma(
map_range(llvm::index_range(0, num.value_or(instr.getNumOperands())),
[&instr](auto idx) {
return formatv("op_{0}_{1}", idx,
static_cast<const Value *>(&instr))
.str();
}),
os);
os << formatv(");\n");
}

void generate_operands(const Instruction &instr, raw_ostream &os,
generation_context &ctx) {
for (auto &&[idx, op] : enumerate(instr.operands())) {
// if (idx == 0) continue; // drop_begin segfaults
auto *val = op.get();
assert(val);
create_operand(*val, *static_cast<const Value *>(&instr), idx, os, ctx);
}
}

void generate_special_instr(const Instruction &instr, raw_ostream &os,
generation_context &ctx) {
switch (instr.getOpcode()) {
case Instruction::Alloca: {
create_pre_args(instr, os, ctx);
generate_operands(instr, os, ctx);
generate_call_create_instr(instr, os, ctx);
return;
}
case Instruction::Call: {
auto is_function = [](auto &op) -> bool {
return dyn_cast<Function>(op.get());
};
auto func_it = llvm::find_if(instr.operands(), is_function);
assert(func_it != instr.operands().end());
auto *func = dyn_cast<Function>(func_it->get());
assert(func);
unsigned idx = 0;
os << formatv("auto *op_{0}_{1} = func_type_{2};\n", idx++,
static_cast<const Value *>(&instr), func);
os << formatv("auto *op_{0}_{1} = func_{2};\n", idx++,
static_cast<const Value *>(&instr), func);
os << formatv("std::vector<Value *> op_{0}_{1};\n", idx++,
static_cast<const Value *>(&instr));
auto first_arg_idx = idx;
auto non_functions =
llvm::make_filter_range(instr.operands(), std::not_fn(is_function));
for (auto &&op : non_functions) {
auto *val = op.get();
assert(val);
create_operand(*val, *static_cast<const Value *>(&instr), idx++, os, ctx);
}
for (auto i = first_arg_idx; i < idx; ++i) {
os << formatv("op_{0}_{1}.push_back(op_{2}_{1});\n", first_arg_idx - 1,
static_cast<const Value *>(&instr), i);
}
generate_call_create_instr(instr, os, ctx, first_arg_idx);
}
}
}

std::string create_instr(const Instruction &instr, generation_context &ctx) {
std::string instr_str;
raw_string_ostream os(instr_str);
// os << "\n\nINSTR:\n";
// instr.print(os);
// os << "\n";
if (auto *phi = dyn_cast<PHINode>(&instr)) {
create_phi_node(*phi, os);
auto [it, inserted] = ctx.phis.try_emplace(phi, instr.getNextNode());
assert(inserted);
return instr_str;
}

for (auto &&[idx, op] : enumerate(instr.operands())) {
// if (idx == 0) continue; // drop_begin segfaults
auto *val = op.get();
assert(val);
create_operand(*val, idx, os);
}
if (requires_special_handling(instr)) {
create_pre_args(instr, os);
generate_special_instr(instr, os, ctx);
} else {
generate_operands(instr, os, ctx);
generate_call_create_instr(instr, os, ctx);
}
os << formatv("auto *instr_{0} = {1}.Create{2}(", &instr, builder,
get_instr_create_name(instr))
.str();
if (requires_special_handling(instr))
os << formatv("add_arg_{0}, ", &instr);
interleaveComma(map_range(enumerate(instr.operands()),
[](auto &&pair) {
auto &&[idx, op] = pair;
return formatv("op_{0}_{1}", idx, op.get()).str();
}),
os);
os << formatv(");\n");
return instr_str;
}

std::string create_bb(const BasicBlock &bb) {
std::string create_bb(const BasicBlock &bb, generation_context &ctx) {
auto *f = bb.getParent();
assert(f);
std::string bb_str;
Expand All @@ -243,34 +324,41 @@ std::string create_bb(const BasicBlock &bb) {
os << formatv("{0}.SetInsertPoint(bb_{1});\n", builder, &bb);

for (auto &instr : bb) {
os << create_instr(instr);
os << create_instr(instr, ctx);
}
return bb_str;
}

std::string create_func(const Function &f) {
std::string create_func(const Function &f, generation_context &ctx) {
std::string func;
raw_string_ostream os(func);
os << get_func_type(f);
os << get_func_type(f, ctx);
os << formatv("auto *func_{0} = Function::Create(func_type_{0}, "
"Function::ExternalLinkage, \"{1}\", M);\n",
&f, f.getName())
.str();
for (auto &bb : f)
os << create_bb(bb);
os << create_bb(bb, ctx);
return func;
}

PreservedAnalyses api_gen_pass::run(Function &f, FunctionAnalysisManager &) {
std::unordered_set<std::string> visited;
generation_context ctx;
visited.insert(f.getName().str());
auto *m = f.getParent();
assert(m);
for (auto &ff : m->getFunctionList()) {
if (ff.getName() != f.getName() && !visited.contains(ff.getName().str()))
os << create_func(ff);
os << create_func(ff, ctx);
}
os << create_func(f, ctx);
for (auto [phi, ins] : ctx.phis) {
assert(phi);
assert(ins);
os << formatv("{0}.SetInsertPoint(instr_{1});\n", builder, ins);
create_phi_node(*phi, os, ctx);
}
os << create_func(f);
return PreservedAnalyses::all();
}

Expand Down
4 changes: 4 additions & 0 deletions meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,7 @@ project(
)

subdir('lib')

llvm_dep = dependency('llvm')

#executable('test-generated', 'tests/generated.cpp', dependencies: [llvm_dep])
Loading

0 comments on commit ebb7f7b

Please sign in to comment.