Skip to content

Commit

Permalink
For loops: Allow sharing variables with main program
Browse files Browse the repository at this point in the history
1. Determine which variables need to be shared with the loop callback
2. Pack pointers to them into a context struct
3. Pass pointer to the context struct to the callback function
4. In the callback, rebind the shared variables so that reads and writes
   go through the context pointers instead of directly to their
   original addresses

See the comment in semantic_analyser.cpp for pseudocode of this
transformation.
  • Loading branch information
ajor committed Jul 9, 2024
1 parent fe9c5b4 commit 32ca92e
Show file tree
Hide file tree
Showing 14 changed files with 601 additions and 92 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ and this project adheres to
- [#3268](https://github.com/bpftrace/bpftrace/pull/3268)
- Enable for-loops in multiple probes
- [#3285](https://github.com/bpftrace/bpftrace/pull/3285)
- For-loops: Allow sharing variables between the main probe and the loop's body
- [#3014](https://github.com/bpftrace/bpftrace/pull/3014)
#### Changed
- Stream output when printing maps
- [#3264](https://github.com/bpftrace/bpftrace/pull/3264)
Expand Down
2 changes: 2 additions & 0 deletions src/ast/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,8 @@ class For : public Statement {
Expression *expr = nullptr;
StatementList *stmts = nullptr;

SizedType ctx_type;

private:
For(const For &other);
};
Expand Down
117 changes: 93 additions & 24 deletions src/ast/passes/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1426,10 +1426,10 @@ void CodegenLLVM::visit(Variable &var)
// Arrays and structs are not memcopied for local variables
if (needMemcpy(var.type) &&
!(var.type.IsArrayTy() || var.type.IsRecordTy())) {
expr_ = variables_[var.ident];
expr_ = variables_[var.ident].value;
} else {
auto *var_alloca = variables_[var.ident];
expr_ = b_.CreateLoad(var_alloca->getAllocatedType(), var_alloca);
auto &var_llvm = variables_[var.ident];
expr_ = b_.CreateLoad(var_llvm.type, var_llvm.value);
}
}

Expand Down Expand Up @@ -2289,22 +2289,22 @@ void CodegenLLVM::visit(AssignVarStatement &assignment)
}

AllocaInst *val = b_.CreateAllocaBPFInit(alloca_type, var.ident);
variables_[var.ident] = val;
variables_[var.ident] = VariableLLVM{ val, val->getAllocatedType() };
}

if (var.type.IsArrayTy() || var.type.IsRecordTy()) {
// For arrays and structs, only the pointer is stored
b_.CreateStore(b_.CreatePtrToInt(expr_, b_.getInt64Ty()),
variables_[var.ident]);
variables_[var.ident].value);
// Extend lifetime of RHS up to the end of probe
scoped_del.disarm();
} else if (needMemcpy(var.type)) {
auto *val = variables_[var.ident];
auto *val = variables_[var.ident].value;
if (assignment.expr->type.GetSize() != var.type.GetSize())
b_.CreateMemsetBPF(val, b_.getInt8(0), var.type.GetSize());
b_.CREATE_MEMCPY(val, expr_, assignment.expr->type.GetSize(), 1);
} else {
b_.CreateStore(expr_, variables_[var.ident]);
b_.CreateStore(expr_, variables_[var.ident].value);
}
}

Expand Down Expand Up @@ -2464,8 +2464,41 @@ void CodegenLLVM::visit(For &f)
auto &map = static_cast<Map &>(*f.expr);

Value *ctx = b_.getInt64(0);
llvm::Type *ctx_t = nullptr;

const auto &ctx_fields = f.ctx_type.GetFields();
if (!ctx_fields.empty()) {
// Pack pointers to variables into context struct for use in the callback

#if LLVM_VERSION_MAJOR < 15
std::vector<llvm::Type *> ctx_field_types;
ctx_field_types.reserve(ctx_fields.size());
for (const auto &field : ctx_fields) {
ctx_field_types.push_back(b_.GetType(field.type)->getPointerTo());
}
#else
std::vector<llvm::Type *> ctx_field_types(ctx_fields.size(),
b_.GET_PTR_TY());
#endif
ctx_t = b_.GetStructType("ctx_t", ctx_field_types);
ctx = b_.CreateAllocaBPF(ctx_t, "ctx");

for (size_t i = 0; i < ctx_fields.size(); i++) {
const auto &field = ctx_fields[i];
auto *field_expr = variables_[field.name].value;
auto *ctx_field_ptr = b_.CreateGEP(
ctx_t, ctx, { b_.getInt64(0), b_.getInt32(i) }, "ctx." + field.name);
#if LLVM_VERSION_MAJOR < 15
// An extra cast is required for older LLVM versions, pre-opaque-pointers
ctx_field_ptr = b_.CreatePointerCast(
ctx_field_ptr, field_expr->getType()->getPointerTo());
#endif
b_.CreateStore(field_expr, ctx_field_ptr);
}
}

b_.CreateForEachMapElem(
ctx_, map, createForEachMapCallback(map, *f.decl, *f.stmts), ctx, f.loc);
ctx_, map, createForEachMapCallback(f, ctx_t), ctx, f.loc);
}

void CodegenLLVM::visit(Predicate &pred)
Expand Down Expand Up @@ -2617,7 +2650,8 @@ void CodegenLLVM::visit(Subprog &subprog)
for (SubprogArg *arg : *subprog.args) {
auto alloca = b_.CreateAllocaBPF(b_.GetType(arg->type), arg->name());
b_.CreateStore(func->getArg(arg_index + 1), alloca);
variables_.insert({ arg->name(), alloca });
variables_[arg->name()] = VariableLLVM{ alloca,
alloca->getAllocatedType() };
++arg_index;
}

Expand Down Expand Up @@ -3917,14 +3951,14 @@ void CodegenLLVM::createIncDec(Unop &unop)
b_.CreateLifetimeEnd(newval);
} else if (unop.expr->is_variable) {
Variable &var = static_cast<Variable &>(*unop.expr);
Value *oldval = b_.CreateLoad(variables_[var.ident]->getAllocatedType(),
variables_[var.ident]);
Value *oldval = b_.CreateLoad(variables_[var.ident].type,
variables_[var.ident].value);
Value *newval;
if (is_increment)
newval = b_.CreateAdd(oldval, b_.GetIntSameSize(step, oldval));
else
newval = b_.CreateSub(oldval, b_.GetIntSameSize(step, oldval));
b_.CreateStore(newval, variables_[var.ident]);
b_.CreateStore(newval, variables_[var.ident].value);

if (unop.is_post_op)
expr_ = oldval;
Expand Down Expand Up @@ -4139,10 +4173,7 @@ Function *CodegenLLVM::createMapLenCallback()
return callback;
}

Function *CodegenLLVM::createForEachMapCallback(
Map &map,
const Variable &decl,
const std::vector<Statement *> &stmts)
Function *CodegenLLVM::createForEachMapCallback(const For &f, llvm::Type *ctx_t)
{
/*
* Create a callback function suitable for passing to bpf_for_each_map_elem,
Expand All @@ -4157,9 +4188,15 @@ Function *CodegenLLVM::createForEachMapCallback(

auto saved_ip = b_.saveIP();

#if LLVM_VERSION_MAJOR < 15
llvm::Type *ctx_ptr_ty = ctx_t ? ctx_t->getPointerTo() : b_.GET_PTR_TY();
#else
llvm::Type *ctx_ptr_ty = b_.GET_PTR_TY();
#endif
std::array<llvm::Type *, 4> args = {
b_.GET_PTR_TY(), b_.GET_PTR_TY(), b_.GET_PTR_TY(), b_.GET_PTR_TY()
b_.GET_PTR_TY(), b_.GET_PTR_TY(), b_.GET_PTR_TY(), ctx_ptr_ty
};

FunctionType *callback_type = FunctionType::get(b_.getInt64Ty(), args, false);
Function *callback = Function::Create(callback_type,
Function::LinkageTypes::InternalLinkage,
Expand All @@ -4173,18 +4210,19 @@ Function *CodegenLLVM::createForEachMapCallback(
auto *bb = BasicBlock::Create(module_->getContext(), "", callback);
b_.SetInsertPoint(bb);

auto &key_type = decl.type.GetField(0).type;
auto &key_type = f.decl->type.GetField(0).type;
Value *key = callback->getArg(1);
if (!onStack(key_type)) {
key = b_.CreateLoad(b_.GetType(key_type), key, "key");
}

auto &map = static_cast<Map &>(*f.expr);
auto map_info = bpftrace_.resources.maps_info.find(map.ident);
if (map_info == bpftrace_.resources.maps_info.end()) {
LOG(BUG) << "map name: \"" << map.ident << "\" not found";
}

auto &val_type = decl.type.GetField(1).type;
auto &val_type = f.decl->type.GetField(1).type;
Value *val = callback->getArg(2);

const auto &map_val_type = map_info->second.value_type;
Expand All @@ -4203,23 +4241,54 @@ Function *CodegenLLVM::createForEachMapCallback(
// used before. This is a hack to simulate block scoping in the absence of the
// real thing (#3017).
CollectNodes<Variable> new_vars;
for (auto *stmt : stmts) {
for (auto *stmt : *f.stmts) {
new_vars.run(*stmt, [this](const auto &var) {
return variables_.find(var.ident) == variables_.end();
});
}

// Create decl variable for use in this iteration of the loop
variables_[decl.ident] = createTuple(
decl.type, { { key, &decl.loc }, { val, &decl.loc } }, decl.ident);
AllocaInst *tuple = createTuple(f.decl->type,
{ { key, &f.decl->loc },
{ val, &f.decl->loc } },
f.decl->ident);
variables_[f.decl->ident] = VariableLLVM{ tuple, tuple->getAllocatedType() };

// 1. Save original locations of variables which will form part of the
// callback context
// 2. Replace variable expressions with those from the context
Value *ctx = callback->getArg(3);
const auto &ctx_fields = f.ctx_type.GetFields();
std::unordered_map<std::string, Value *> orig_ctx_vars;
for (size_t i = 0; i < ctx_fields.size(); i++) {
const auto &field = ctx_fields[i];
orig_ctx_vars[field.name] = variables_[field.name].value;

auto *ctx_field_ptr = b_.CreateGEP(
ctx_t, ctx, { b_.getInt64(0), b_.getInt32(i) }, "ctx." + field.name);
#if LLVM_VERSION_MAJOR < 15
auto *field_ty = variables_[field.name].value->getType();
#else
auto *field_ty = b_.GET_PTR_TY();
#endif
variables_[field.name].value = b_.CreateLoad(field_ty,
ctx_field_ptr,
field.name);
}

for (Statement *stmt : stmts) {
// Generate code for the loop body
for (Statement *stmt : *f.stmts) {
auto scoped_del = accept(stmt);
}
b_.CreateRet(b_.getInt64(0));

// Restore the shared variables to their original (non-context) locations
for (const auto &[ident, expr] : orig_ctx_vars) {
variables_[ident].value = expr;
}

// Decl variable is not valid beyond this for loop
variables_.erase(decl.ident);
variables_.erase(f.decl->ident);

// Variables declared in a for-loop are not valid beyond it
for (const Variable &var : new_vars.nodes()) {
Expand Down
11 changes: 7 additions & 4 deletions src/ast/passes/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,7 @@ class CodegenLLVM : public Visitor {
void createIncDec(Unop &unop);

Function *createMapLenCallback();
Function *createForEachMapCallback(Map &map,
const Variable &decl,
const std::vector<Statement *> &stmts);
Function *createForEachMapCallback(const For &f, llvm::Type *ctx_t);
Function *createMurmurHash2Func();

Value *createFmtString(int print_id);
Expand Down Expand Up @@ -275,7 +273,12 @@ class CodegenLLVM : public Visitor {
int current_usdt_location_index_{ 0 };
bool inside_subprog_ = false;

std::map<std::string, AllocaInst *> variables_;
// Codegen-side handle for a script variable, stored in variables_.
//
// `value` is the pointer through which the variable is loaded and stored.
// It is normally the variable's own alloca, but while a for-loop callback is
// being generated it is temporarily swapped for a pointer loaded from the
// loop's context struct, hence a generic llvm::Value rather than an
// AllocaInst.
//
// `type` is the LLVM type of the variable's storage, passed to CreateLoad.
// It is kept separately because getAllocatedType() is only available while
// `value` is still an AllocaInst.
//
// Both members default to nullptr so that a default-initialized entry is
// recognisably empty instead of holding indeterminate pointers.
struct VariableLLVM {
  llvm::Value *value = nullptr;
  llvm::Type *type = nullptr;
};
std::unordered_map<std::string, VariableLLVM> variables_;

std::unordered_map<std::string, libbpf::bpf_map_type> map_types_;

Function *linear_func_ = nullptr;
Expand Down
10 changes: 10 additions & 0 deletions src/ast/passes/printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <sstream>

#include "ast/ast.h"
#include "struct.h"

namespace bpftrace {
namespace ast {
Expand Down Expand Up @@ -357,8 +358,17 @@ void Printer::visit(For &for_loop)
out_ << indent << "for" << std::endl;

++depth_;
if (for_loop.ctx_type.IsRecordTy() &&
!for_loop.ctx_type.GetFields().empty()) {
out_ << indent << " ctx\n";
for (const auto &field : for_loop.ctx_type.GetFields()) {
out_ << indent << " " << field.name << type(field.type) << "\n";
}
}

out_ << indent << " decl\n";
print(for_loop.decl);

out_ << indent << " expr\n";
print(for_loop.expr);

Expand Down
Loading

0 comments on commit 32ca92e

Please sign in to comment.