Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LLVM/RUNTIME] Support Parallel for on CPU #54

Merged
merged 1 commit into from
Feb 26, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,12 @@ LoweredFunc MakeAPI(Stmt body,
int num_unpacked_args);

/*!
* \brief Count number of undefined vars in f.
* \param f The function to be checked.
* \return Number of undefined vars.
* \brief Find undefined vars in the statment.
* \param stmt The function to be checked.
* \param defs The vars that is defined.
* \return Array of undefined vars.
*/
Array<Var> UndefinedVars(const LoweredFunc& f);
Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);

/*!
* \brief Split the function into a host function and device functions.
Expand Down
34 changes: 27 additions & 7 deletions include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,18 @@ TVM_DLL int TVMModPreCompile(TVMModuleHandle mod,
const char* func_name,
TVMContext ctx);

/*!
* \brief Free the Module
* \param mod The module to be freed.
*
* \note This may not free up the module's resources.
* If there is active TVMFunctionHandle uses the module
* Or if this module is imported by another active module.
*
* The all functions remains valid until TVMFuncFree is called.
*/
TVM_DLL int TVMModFree(TVMModuleHandle mod);

/*!
* \brief Backend function for modules to get function
* from its environment mod_node (its imports and global function).
Expand All @@ -242,17 +254,25 @@ TVM_DLL int TVMModPreCompile(TVMModuleHandle mod,
TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node,
const char* func_name,
TVMFunctionHandle *out);

/*!
* \brief Free the Module
* \param mod The module to be freed.
* \brief Backend function for running parallel for loop.
*
* \note This may not free up the module's resources.
* If there is active TVMFunctionHandle uses the module
* Or if this module is imported by another active module.
* \note This API is supposed to be used by backend,
* it is not supposed to be used by user.
*
* The all functions remains valid until TVMFuncFree is called.
* \param begin The start of iteration.
* \param end The end of iteration.
* \param lambda The lambda function to be executed.
* \param env The environment of lambda function.
*
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMModFree(TVMModuleHandle mod);
TVM_DLL int TVMBackendParallelFor(
int64_t begin,
int64_t end,
int (*lambda)(int64_t begin, int64_t end, void* env),
void* env);

/*!
* \brief Free the function when it is no longer needed.
Expand Down
9 changes: 8 additions & 1 deletion include/tvm/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ enum AttachType : int {
/*! \brief IterVar type */
enum IterVarType : int {
kUnrolled = 1,
kVectorized = 2
kVectorized = 2,
kParallel = 3
};

/*! \brief Stage, contains scheduling for a stage of computation. */
Expand Down Expand Up @@ -152,6 +153,12 @@ class Stage : public NodeRef {
* \return reference to self.
*/
Stage& unroll(IterVar var); // NOLINT(*)
/*!
* \brief Parallelize iteration.
* \param var The axis to be parallelized.
* \return reference to self.
*/
Stage& parallel(IterVar var); // NOLINT(*)
/*!
* \brief whether the stage has been scheduled.
* \return whether the stage has been scheduled.
Expand Down
10 changes: 10 additions & 0 deletions python/tvm/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,13 @@ def unroll(self, var):
The iteration to be unrolled.
"""
_api_internal._StageUnroll(self, var)

def parallel(self, var):
"""Parallelize the iteration.

Parameters
----------
var : IterVar
The iteration to be parallelized.
"""
_api_internal._StageParallel(self, var)
6 changes: 6 additions & 0 deletions src/api/api_lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,12 @@ TVM_REGISTER_API(_StageVectorize)
.vectorize(args[1]);
});

TVM_REGISTER_API(_StageParallel)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.parallel(args[1]);
});

TVM_REGISTER_API(_ScheduleNormalize)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Schedule()
Expand Down
143 changes: 111 additions & 32 deletions src/codegen/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#ifdef TVM_LLVM_VERSION

#include <tvm/runtime/c_runtime_api.h>
#include <tvm/ir_pass.h>
#include "./codegen_llvm.h"
#include "../../arithmetic/compute_expr.h"

Expand All @@ -30,6 +31,7 @@ void CodeGenLLVM::Init(const std::string& module_name,
t_int8_ = llvm::Type::getInt8Ty(*ctx);
t_int16_ = llvm::Type::getInt16Ty(*ctx);
t_int32_ = llvm::Type::getInt32Ty(*ctx);
t_int64_ = llvm::Type::getInt64Ty(*ctx);
t_float64_ = llvm::Type::getDoubleTy(*ctx);
t_tvm_index_ = llvm::Type::getIntNTy(*ctx, sizeof(tvm_index_t) * 8);
t_tvm_context_ = llvm::StructType::create({t_int_, t_int_});
Expand All @@ -43,6 +45,8 @@ void CodeGenLLVM::Init(const std::string& module_name,
t_tvm_type_,
t_tvm_context_});
t_tvm_value_ = llvm::StructType::create({t_float64_});
t_f_tvm_par_for_lambda_ = llvm::FunctionType::get(
t_int_, {t_int64_, t_int64_, t_void_p_}, false);
md_builder_.reset(new llvm::MDBuilder(*ctx));
md_very_likely_branch_ =
md_builder_->createBranchWeights(1 << 30, 0);
Expand Down Expand Up @@ -70,7 +74,11 @@ void CodeGenLLVM::Init(const std::string& module_name,
f_tvm_api_set_last_error_ = llvm::Function::Create(
llvm::FunctionType::get(t_void_, {t_char_->getPointerTo()}, false),
llvm::Function::ExternalLinkage, "TVMAPISetLastError", module_.get());

f_tvm_parallel_for_ = llvm::Function::Create(
llvm::FunctionType::get(t_int_, {
t_int64_, t_int64_, t_f_tvm_par_for_lambda_->getPointerTo(), t_void_p_}
, false),
llvm::Function::ExternalLinkage, "TVMBackendParallelFor", module_.get());
this->InitTarget(target_triple);
// initialize builder
builder_.reset(new IRBuilder(*ctx));
Expand Down Expand Up @@ -141,7 +149,9 @@ void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) {
}
llvm::BasicBlock* block = llvm::BasicBlock::Create(*ctx_, "entry", function_);
builder_->SetInsertPoint(block);
builder_->CreateRet(builder_->CreateCall(f, args));
llvm::CallInst* call = builder_->CreateCall(f, args);
call->setTailCall(true);
builder_->CreateRet(call);
}

class FPassManager : public llvm::legacy::FunctionPassManager {
Expand Down Expand Up @@ -545,7 +555,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
return nullptr;
}

llvm::BasicBlock* CodeGenLLVM::CheckPackedCallSuccess(llvm::Value* retcode) {
llvm::BasicBlock* CodeGenLLVM::CheckCallSuccess(llvm::Value* retcode) {
// create emit codes that checks and load the function.
using llvm::BasicBlock;
BasicBlock* fail_block = BasicBlock::Create(
Expand All @@ -563,34 +573,15 @@ llvm::BasicBlock* CodeGenLLVM::CheckPackedCallSuccess(llvm::Value* retcode) {
return end_block;
}
void CodeGenLLVM::Visit_(const For* op) {
using llvm::BasicBlock;
BasicBlock* for_head = BasicBlock::Create(
*ctx_, "for_head", function_);
BasicBlock* for_body = BasicBlock::Create(
*ctx_, "for_body", function_);
BasicBlock* for_end = BasicBlock::Create(
*ctx_, "for_end", function_);
BasicBlock* pre_block = builder_->GetInsertBlock();
CHECK(is_zero(op->min));
Type t = op->min.type();
llvm::Value* init = ConstInt32(0);
llvm::Value* extent = MakeValue(op->extent);
builder_->CreateBr(for_head);

builder_->SetInsertPoint(for_head);
llvm::PHINode* index = builder_->CreatePHI(LLVMType(t), 2);
index->addIncoming(init, pre_block);
llvm::Value* cond = CreateLT(t, index, extent);
builder_->CreateCondBr(cond, for_body, for_end, md_very_likely_branch_);
// body of for
builder_->SetInsertPoint(for_body);
var_map_[op->loop_var.get()] = index;
this->Visit(op->body);
llvm::Value* next_index = CreateAdd(t, index, ConstInt32(1));
index->addIncoming(next_index, builder_->GetInsertBlock());
builder_->CreateBr(for_head);
// end of for
builder_->SetInsertPoint(for_end);
if (op->for_type == ForType::Serial) {
CreateSerialFor(ConstInt32(0), MakeValue(op->extent),
op->loop_var, op->body);
} else if (op->for_type == ForType::Parallel) {
CreateParallelFor(op);
} else {
LOG(FATAL) << "cannot handle for type " << op->for_type;
}
}

void CodeGenLLVM::Visit_(const IfThenElse* op) {
Expand Down Expand Up @@ -807,7 +798,7 @@ llvm::Value* CodeGenLLVM::GetPackedFuncHandle(const std::string& fname) {
llvm::Value* ctx = builder_->CreateLoad(gv_mod_ctx_);
llvm::Value* retcode = builder_->CreateCall(
f_tvm_get_func_from_env_, {ctx, GetConstString(fname), out});
init_block = CheckPackedCallSuccess(retcode);
init_block = CheckCallSuccess(retcode);
llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align);
builder_->CreateBr(end_block);
// end block
Expand Down Expand Up @@ -846,7 +837,7 @@ llvm::Value* CodeGenLLVM::CreateCallPacked(const Call* op) {
}
llvm::Value* ret_value = builder_->CreateAlloca(t_tvm_value_);
llvm::Value* ret_tcode = builder_->CreateAlloca(t_int_);
CheckPackedCallSuccess(
CheckCallSuccess(
builder_->CreateCall(
f_tvm_func_call_,
{handle, targs, tcodes, ConstInt32(nargs), ret_value, ret_tcode}));
Expand Down Expand Up @@ -934,6 +925,94 @@ llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) {
}
}

void CodeGenLLVM::CreateParallelFor(const For* op) {
using llvm::BasicBlock;
llvm::Value* min = MakeValue(op->min);
llvm::Value* extent = MakeValue(op->extent);
min = builder_->CreateIntCast(min, t_int64_, op->min.type().is_int());
extent = builder_->CreateIntCast(extent, t_int64_, op->min.type().is_int());
// fields to be packed into closure.
Var loop_var(op->loop_var.node_);
Array<Var> vfields = ir::UndefinedVars(op->body, {loop_var});
std::vector<llvm::Type*> fields;
for (Var v : vfields) {
auto it = var_map_.find(v.get());
CHECK(it != var_map_.end());
fields.push_back(it->second->getType());
}
// closure data
llvm::StructType* tcdata = llvm::StructType::create(fields);
llvm::Function* f = llvm::Function::Create(
t_f_tvm_par_for_lambda_,
llvm::Function::PrivateLinkage,
"__tvm_par_for_lambda", module_.get());
// allocate and setup the closure, call the closure.
llvm::Value* cdata = builder_->CreateAlloca(tcdata, ConstInt32(1));
llvm::Value* zero = ConstInt32(0);

for (size_t i = 0; i < vfields.size(); ++i) {
builder_->CreateStore(
var_map_.at(vfields[i].get()),
builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)}));
}
BasicBlock* par_for_end = CheckCallSuccess(
builder_->CreateCall(
f_tvm_parallel_for_,
{min, extent, f, builder_->CreatePointerCast(cdata, t_void_p_)}));
// Setup the closure function.
BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
builder_->SetInsertPoint(lambda_entry);
auto it = f->arg_begin();
llvm::Value* begin = &(*it++);
llvm::Value* end = &(*it++);
cdata = &(*it++);
begin = CreateCast(Int(64), op->loop_var.type(), begin);
end = CreateCast(Int(64), op->loop_var.type(), end);
cdata = builder_->CreatePointerCast(cdata, tcdata->getPointerTo());
// setup new variable map, swap it with current var context.
std::unordered_map<const Variable*, llvm::Value*> new_vmap;
for (size_t i = 0; i < vfields.size(); ++i) {
new_vmap[vfields[i].get()] =
builder_->CreateLoad(builder_->CreateInBoundsGEP(
cdata, {zero, ConstInt32(i)}));
}
std::swap(function_, f);
std::swap(new_vmap, var_map_);
CreateSerialFor(begin, end, op->loop_var, op->body);
builder_->CreateRet(ConstInt32(0));
// swap the var map back, now we are back on track.
std::swap(new_vmap, var_map_);
std::swap(function_, f);
builder_->SetInsertPoint(par_for_end);
}

void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end,
const VarExpr& loop_var, const Stmt& body) {
using llvm::BasicBlock;
Type t = loop_var.type();
BasicBlock* for_head = BasicBlock::Create(
*ctx_, "for_head", function_);
BasicBlock* for_body = BasicBlock::Create(
*ctx_, "for_body", function_);
BasicBlock* for_end = BasicBlock::Create(
*ctx_, "for_end", function_);
BasicBlock* pre_block = builder_->GetInsertBlock();
builder_->CreateBr(for_head);
builder_->SetInsertPoint(for_head);
llvm::PHINode* index = builder_->CreatePHI(begin->getType(), 2);
index->addIncoming(begin, pre_block);
llvm::Value* cond = CreateLT(t, index, end);
builder_->CreateCondBr(cond, for_body, for_end, md_very_likely_branch_);
// body of for
builder_->SetInsertPoint(for_body);
var_map_[loop_var.get()] = index;
this->Visit(body);
llvm::Value* next_index = CreateAdd(t, index, ConstInt32(1));
index->addIncoming(next_index, builder_->GetInsertBlock());
builder_->CreateBr(for_head);
// end of for
builder_->SetInsertPoint(for_end);
}
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
9 changes: 8 additions & 1 deletion src/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,12 @@ class CodeGenLLVM : public IRVisitor {
llvm::StructType* t_tvm_type_{nullptr};
llvm::StructType* t_tvm_array_{nullptr};
llvm::StructType* t_tvm_value_{nullptr};
llvm::FunctionType* t_f_tvm_par_for_lambda_{nullptr};
// tvm api functions
llvm::Function* f_tvm_func_call_{nullptr};
llvm::Function* f_tvm_get_func_from_env_{nullptr};
llvm::Function* f_tvm_api_set_last_error_{nullptr};
llvm::Function* f_tvm_parallel_for_{nullptr};
// The acting body
llvm::BasicBlock* block_{nullptr};
// Last value returned codegen call.
Expand All @@ -176,10 +178,15 @@ class CodeGenLLVM : public IRVisitor {
llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index);
llvm::Value* CreateCast(Type from, Type to, llvm::Value* value);
llvm::Value* GetPackedFuncHandle(const std::string& str);
// Create parallel for.
void CreateParallelFor(const For* op);
// Create serial for
void CreateSerialFor(llvm::Value* begin, llvm::Value* end,
const VarExpr& loop_var, const Stmt& body);
// Check if the call to packed function is successful
// if not directly finalize function and pass on return code.
// return the end block after the check
llvm::BasicBlock* CheckPackedCallSuccess(llvm::Value* retcode);
llvm::BasicBlock* CheckCallSuccess(llvm::Value* retcode);
// Initialize target
void InitTarget(const std::string& target);
// Add a function to set global module context
Expand Down
16 changes: 16 additions & 0 deletions src/lang/lowered_func.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/*!
* Copyright (c) 2017 by Contributors
* \file lowered_func.cc
*/
#include <tvm/lowered_func.h>

namespace tvm {

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<LoweredFuncNode>([](const LoweredFuncNode *op, IRPrinter *p) {
p->stream << "LoweredFunc(" << op->name << ", " << op << ")";
});

TVM_REGISTER_NODE_TYPE(LoweredFuncNode);

} // namespace tvm
2 changes: 1 addition & 1 deletion src/pass/make_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ LoweredFunc MakeAPI(Stmt body,
n->is_packed_func = num_unpacked_args == 0;
n->body = MergeNest({seq_init, seq_check}, body);
LoweredFunc f(n);
Array<Var> undefined = UndefinedVars(f);
Array<Var> undefined = UndefinedVars(f->body, f->args);
if (undefined.size() != 0) {
std::ostringstream os;
for (Var v : undefined) {
Expand Down
Loading