From 8e6c30907c7c5aa020fba42410d91e8eaa5ebfe6 Mon Sep 17 00:00:00 2001
From: tqchen
Date: Fri, 24 Feb 2017 21:00:04 -0800
Subject: [PATCH] [LLVM/RUNTIME] Support Parallel for on CPU

---
 include/tvm/ir_pass.h                      |   9 +-
 include/tvm/runtime/c_runtime_api.h        |  34 ++++-
 include/tvm/schedule.h                     |   9 +-
 python/tvm/schedule.py                     |  10 ++
 src/api/api_lang.cc                        |   6 +
 src/codegen/llvm/codegen_llvm.cc           | 143 ++++++++++++++----
 src/codegen/llvm/codegen_llvm.h            |   9 +-
 src/lang/lowered_func.cc                   |  16 ++
 src/pass/make_api.cc                       |   2 +-
 src/pass/split_host_device.cc              |   6 +-
 src/runtime/c_runtime_api.cc               |  69 ++++++++-
 src/schedule/schedule_lang.cc              |   6 +
 src/schedule/schedule_ops.cc               |   1 +
 tests/python/unittest/test_codegen_llvm.py |  31 ++++
 ...stack_llvm.py => test_codegen_vm_basic.py} |  33 ----
 15 files changed, 298 insertions(+), 86 deletions(-)
 create mode 100644 src/lang/lowered_func.cc
 create mode 100644 tests/python/unittest/test_codegen_llvm.py
 rename tests/python/unittest/{test_codegen_stack_llvm.py => test_codegen_vm_basic.py} (66%)

diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index c3197b5c1bde..542ec34424cd 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -173,11 +173,12 @@ LoweredFunc MakeAPI(Stmt body,
                     int num_unpacked_args);
 
 /*!
- * \brief Count number of undefined vars in f.
- * \param f The function to be checked.
- * \return Number of undefined vars.
+ * \brief Find undefined vars in the statement.
+ * \param stmt The statement to be checked.
+ * \param defs The vars that are already defined.
+ * \return Array of undefined vars.
  */
-Array<Var> UndefinedVars(const LoweredFunc& f);
+Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
 
 /*!
  * \brief Split the function into a host function and device functions.
diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h
index 33c46b40d670..b61d4b76333e 100644
--- a/include/tvm/runtime/c_runtime_api.h
+++ b/include/tvm/runtime/c_runtime_api.h
@@ -225,6 +225,18 @@ TVM_DLL int TVMModPreCompile(TVMModuleHandle mod,
                              const char* func_name,
                              TVMContext ctx);
 
+/*!
+ * \brief Free the Module
+ * \param mod The module to be freed.
+ *
+ * \note This may not free up the module's resources,
+ * if there is an active TVMFunctionHandle that uses the module,
+ * or if this module is imported by another active module.
+ *
+ * All functions remain valid until TVMFuncFree is called.
+ */
+TVM_DLL int TVMModFree(TVMModuleHandle mod);
+
 /*!
  * \brief Backend function for modules to get function
  *  from its environment mod_node (its imports and global function).
@@ -242,17 +254,25 @@ TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node,
 TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node,
                                      const char* func_name,
                                      TVMFunctionHandle *out);
+
 /*!
- * \brief Free the Module
- * \param mod The module to be freed.
  *
- * \note This may not free up the module's resources.
- * If there is active TVMFunctionHandle uses the module
- * Or if this module is imported by another active module.
+ * \brief Backend function for running parallel for loop.
  *
- * The all functions remains valid until TVMFuncFree is called.
+ * \note This API is supposed to be used by the backend;
+ * it is not supposed to be called by the user.
+ *
+ * \param begin The start of the iteration.
+ * \param end The end of the iteration.
+ * \param lambda The lambda function to be executed.
+ * \param env The environment of the lambda function.
+ *
+ * \return 0 when no error is thrown, -1 when a failure happens
  */
-TVM_DLL int TVMModFree(TVMModuleHandle mod);
+TVM_DLL int TVMBackendParallelFor(
+    int64_t begin,
+    int64_t end,
+    int (*lambda)(int64_t begin, int64_t end, void* env),
+    void* env);
 
 /*!
  * \brief Free the function when it is no longer needed.
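As a point of reference, the following stand-alone C++ sketch shows how the entry point declared above is meant to be driven. It is illustrative only and not part of this patch; FillChunk and ParallelFill are hypothetical names, and only the TVMBackendParallelFor signature and TVMGetLastError come from the header.

#include <tvm/runtime/c_runtime_api.h>
#include <cstdint>

// `env` is an opaque pointer chosen by the caller; here it is simply the
// output buffer. Each invocation handles the half-open range [begin, end)
// and returns 0 on success (a nonzero value reports an error).
static int FillChunk(int64_t begin, int64_t end, void* env) {
  float* out = static_cast<float*>(env);
  for (int64_t i = begin; i < end; ++i) {
    out[i] = static_cast<float>(i);
  }
  return 0;
}

// Fills out[0..n) using the runtime's worker threads. A return value of -1
// means at least one chunk failed; details are reported via TVMGetLastError().
int ParallelFill(float* out, int64_t n) {
  return TVMBackendParallelFor(0, n, FillChunk, out);
}
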
diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h
index ce0cd3420d69..1ce55622abea 100644
--- a/include/tvm/schedule.h
+++ b/include/tvm/schedule.h
@@ -34,7 +34,8 @@ enum AttachType : int {
 /*! \brief IterVar type */
 enum IterVarType : int {
   kUnrolled = 1,
-  kVectorized = 2
+  kVectorized = 2,
+  kParallel = 3
 };
 
 /*! \brief Stage, contains scheduling for a stage of computation. */
@@ -152,6 +153,12 @@ class Stage : public NodeRef {
    * \return reference to self.
    */
   Stage& unroll(IterVar var);   // NOLINT(*)
+  /*!
+   * \brief Parallelize iteration.
+   * \param var The axis to be parallelized.
+   * \return reference to self.
+   */
+  Stage& parallel(IterVar var);   // NOLINT(*)
   /*!
    * \brief whether the stage has been scheduled.
    * \return whether the stage has been scheduled.
diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py
index f0db2562d372..b54191b6502e 100644
--- a/python/tvm/schedule.py
+++ b/python/tvm/schedule.py
@@ -257,3 +257,13 @@ def unroll(self, var):
             The iteration to be unrolled.
         """
         _api_internal._StageUnroll(self, var)
+
+    def parallel(self, var):
+        """Parallelize the iteration.
+
+        Parameters
+        ----------
+        var : IterVar
+            The iteration to be parallelized.
+        """
+        _api_internal._StageParallel(self, var)
diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc
index ea49bbae18cf..23d9651e1733 100644
--- a/src/api/api_lang.cc
+++ b/src/api/api_lang.cc
@@ -280,6 +280,12 @@ TVM_REGISTER_API(_StageVectorize)
     .vectorize(args[1]);
   });
 
+TVM_REGISTER_API(_StageParallel)
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    args[0].operator Stage()
+        .parallel(args[1]);
+  });
+
 TVM_REGISTER_API(_ScheduleNormalize)
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     args[0].operator Schedule()
diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc
index 232ef4923aae..fbe455c691ea 100644
--- a/src/codegen/llvm/codegen_llvm.cc
+++ b/src/codegen/llvm/codegen_llvm.cc
@@ -5,6 +5,7 @@
 #ifdef TVM_LLVM_VERSION
 #include <tvm/runtime/c_runtime_api.h>
+#include <tvm/ir_pass.h>
 #include "./codegen_llvm.h"
 #include "../../arithmetic/compute_expr.h"
@@ -30,6 +31,7 @@ void CodeGenLLVM::Init(const std::string& module_name,
   t_int8_ = llvm::Type::getInt8Ty(*ctx);
   t_int16_ = llvm::Type::getInt16Ty(*ctx);
   t_int32_ = llvm::Type::getInt32Ty(*ctx);
+  t_int64_ = llvm::Type::getInt64Ty(*ctx);
   t_float64_ = llvm::Type::getDoubleTy(*ctx);
   t_tvm_index_ = llvm::Type::getIntNTy(*ctx, sizeof(tvm_index_t) * 8);
   t_tvm_context_ = llvm::StructType::create({t_int_, t_int_});
@@ -43,6 +45,8 @@ void CodeGenLLVM::Init(const std::string& module_name,
       t_tvm_type_, t_tvm_context_});
   t_tvm_value_ = llvm::StructType::create({t_float64_});
+  t_f_tvm_par_for_lambda_ = llvm::FunctionType::get(
+      t_int_, {t_int64_, t_int64_, t_void_p_}, false);
   md_builder_.reset(new llvm::MDBuilder(*ctx));
   md_very_likely_branch_ =
       md_builder_->createBranchWeights(1 << 30, 0);
@@ -70,7 +74,11 @@ void CodeGenLLVM::Init(const std::string& module_name,
   f_tvm_api_set_last_error_ = llvm::Function::Create(
       llvm::FunctionType::get(t_void_, {t_char_->getPointerTo()}, false),
       llvm::Function::ExternalLinkage, "TVMAPISetLastError", module_.get());
-
+  f_tvm_parallel_for_ = llvm::Function::Create(
+      llvm::FunctionType::get(t_int_, {
+          t_int64_, t_int64_, t_f_tvm_par_for_lambda_->getPointerTo(), t_void_p_}
+        , false),
+      llvm::Function::ExternalLinkage, "TVMBackendParallelFor", module_.get());
   this->InitTarget(target_triple);
   // initialize builder
   builder_.reset(new IRBuilder(*ctx));
@@ -141,7 +149,9 @@ void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) {
   }
   llvm::BasicBlock* block = llvm::BasicBlock::Create(*ctx_, "entry", function_);
   builder_->SetInsertPoint(block);
-  builder_->CreateRet(builder_->CreateCall(f, args));
+  llvm::CallInst* call = builder_->CreateCall(f, args);
+  call->setTailCall(true);
+  builder_->CreateRet(call);
 }
 
 class FPassManager : public llvm::legacy::FunctionPassManager {
@@ -545,7 +555,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
   return nullptr;
 }
 
-llvm::BasicBlock* CodeGenLLVM::CheckPackedCallSuccess(llvm::Value* retcode) {
+llvm::BasicBlock* CodeGenLLVM::CheckCallSuccess(llvm::Value* retcode) {
   // create emit codes that checks and load the function.
   using llvm::BasicBlock;
   BasicBlock* fail_block = BasicBlock::Create(
@@ -563,34 +573,15 @@ llvm::BasicBlock* CodeGenLLVM::CheckCallSuccess(llvm::Value* retcode) {
   return end_block;
 }
 void CodeGenLLVM::Visit_(const For* op) {
-  using llvm::BasicBlock;
-  BasicBlock* for_head = BasicBlock::Create(
-      *ctx_, "for_head", function_);
-  BasicBlock* for_body = BasicBlock::Create(
-      *ctx_, "for_body", function_);
-  BasicBlock* for_end = BasicBlock::Create(
-      *ctx_, "for_end", function_);
-  BasicBlock* pre_block = builder_->GetInsertBlock();
   CHECK(is_zero(op->min));
-  Type t = op->min.type();
-  llvm::Value* init = ConstInt32(0);
-  llvm::Value* extent = MakeValue(op->extent);
-  builder_->CreateBr(for_head);
-
-  builder_->SetInsertPoint(for_head);
-  llvm::PHINode* index = builder_->CreatePHI(LLVMType(t), 2);
-  index->addIncoming(init, pre_block);
-  llvm::Value* cond = CreateLT(t, index, extent);
-  builder_->CreateCondBr(cond, for_body, for_end, md_very_likely_branch_);
-  // body of for
-  builder_->SetInsertPoint(for_body);
-  var_map_[op->loop_var.get()] = index;
-  this->Visit(op->body);
-  llvm::Value* next_index = CreateAdd(t, index, ConstInt32(1));
-  index->addIncoming(next_index, builder_->GetInsertBlock());
-  builder_->CreateBr(for_head);
-  // end of for
-  builder_->SetInsertPoint(for_end);
+  if (op->for_type == ForType::Serial) {
+    CreateSerialFor(ConstInt32(0), MakeValue(op->extent),
+                    op->loop_var, op->body);
+  } else if (op->for_type == ForType::Parallel) {
+    CreateParallelFor(op);
+  } else {
+    LOG(FATAL) << "cannot handle for type " << op->for_type;
+  }
 }
 
 void CodeGenLLVM::Visit_(const IfThenElse* op) {
@@ -807,7 +798,7 @@ llvm::Value* CodeGenLLVM::GetPackedFuncHandle(const std::string& fname) {
     llvm::Value* ctx = builder_->CreateLoad(gv_mod_ctx_);
     llvm::Value* retcode = builder_->CreateCall(
         f_tvm_get_func_from_env_, {ctx, GetConstString(fname), out});
-    init_block = CheckPackedCallSuccess(retcode);
+    init_block = CheckCallSuccess(retcode);
     llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align);
     builder_->CreateBr(end_block);
     // end block
@@ -846,7 +837,7 @@ llvm::Value* CodeGenLLVM::CreateCallPacked(const Call* op) {
   }
   llvm::Value* ret_value = builder_->CreateAlloca(t_tvm_value_);
   llvm::Value* ret_tcode = builder_->CreateAlloca(t_int_);
-  CheckPackedCallSuccess(
+  CheckCallSuccess(
       builder_->CreateCall(
           f_tvm_func_call_,
           {handle, targs, tcodes, ConstInt32(nargs), ret_value, ret_tcode}));
@@ -934,6 +925,94 @@ llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) {
   }
 }
 
+void CodeGenLLVM::CreateParallelFor(const For* op) {
+  using llvm::BasicBlock;
+  llvm::Value* min = MakeValue(op->min);
+  llvm::Value* extent = MakeValue(op->extent);
+  min = builder_->CreateIntCast(min, t_int64_, op->min.type().is_int());
+  extent = builder_->CreateIntCast(extent, t_int64_, op->min.type().is_int());
+  // fields to be packed into closure.
+  Var loop_var(op->loop_var.node_);
+  Array<Var> vfields = ir::UndefinedVars(op->body, {loop_var});
+  std::vector<llvm::Type*> fields;
+  for (Var v : vfields) {
+    auto it = var_map_.find(v.get());
+    CHECK(it != var_map_.end());
+    fields.push_back(it->second->getType());
+  }
+  // closure data
+  llvm::StructType* tcdata = llvm::StructType::create(fields);
+  llvm::Function* f = llvm::Function::Create(
+      t_f_tvm_par_for_lambda_,
+      llvm::Function::PrivateLinkage,
+      "__tvm_par_for_lambda", module_.get());
+  // allocate and setup the closure, call the closure.
+  llvm::Value* cdata = builder_->CreateAlloca(tcdata, ConstInt32(1));
+  llvm::Value* zero = ConstInt32(0);
+
+  for (size_t i = 0; i < vfields.size(); ++i) {
+    builder_->CreateStore(
+        var_map_.at(vfields[i].get()),
+        builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)}));
+  }
+  BasicBlock* par_for_end = CheckCallSuccess(
+      builder_->CreateCall(
+          f_tvm_parallel_for_,
+          {min, extent, f, builder_->CreatePointerCast(cdata, t_void_p_)}));
+  // Setup the closure function.
+  BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
+  builder_->SetInsertPoint(lambda_entry);
+  auto it = f->arg_begin();
+  llvm::Value* begin = &(*it++);
+  llvm::Value* end = &(*it++);
+  cdata = &(*it++);
+  begin = CreateCast(Int(64), op->loop_var.type(), begin);
+  end = CreateCast(Int(64), op->loop_var.type(), end);
+  cdata = builder_->CreatePointerCast(cdata, tcdata->getPointerTo());
+  // setup new variable map, swap it with current var context.
+  std::unordered_map<const Variable*, llvm::Value*> new_vmap;
+  for (size_t i = 0; i < vfields.size(); ++i) {
+    new_vmap[vfields[i].get()] =
+        builder_->CreateLoad(builder_->CreateInBoundsGEP(
+            cdata, {zero, ConstInt32(i)}));
+  }
+  std::swap(function_, f);
+  std::swap(new_vmap, var_map_);
+  CreateSerialFor(begin, end, op->loop_var, op->body);
+  builder_->CreateRet(ConstInt32(0));
+  // swap the var map back, now we are back on track.
+  std::swap(new_vmap, var_map_);
+  std::swap(function_, f);
+  builder_->SetInsertPoint(par_for_end);
+}
+
+void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end,
+                                  const VarExpr& loop_var, const Stmt& body) {
+  using llvm::BasicBlock;
+  Type t = loop_var.type();
+  BasicBlock* for_head = BasicBlock::Create(
+      *ctx_, "for_head", function_);
+  BasicBlock* for_body = BasicBlock::Create(
+      *ctx_, "for_body", function_);
+  BasicBlock* for_end = BasicBlock::Create(
+      *ctx_, "for_end", function_);
+  BasicBlock* pre_block = builder_->GetInsertBlock();
+  builder_->CreateBr(for_head);
+  builder_->SetInsertPoint(for_head);
+  llvm::PHINode* index = builder_->CreatePHI(begin->getType(), 2);
+  index->addIncoming(begin, pre_block);
+  llvm::Value* cond = CreateLT(t, index, end);
+  builder_->CreateCondBr(cond, for_body, for_end, md_very_likely_branch_);
+  // body of for
+  builder_->SetInsertPoint(for_body);
+  var_map_[loop_var.get()] = index;
+  this->Visit(body);
+  llvm::Value* next_index = CreateAdd(t, index, ConstInt32(1));
+  index->addIncoming(next_index, builder_->GetInsertBlock());
+  builder_->CreateBr(for_head);
+  // end of for
+  builder_->SetInsertPoint(for_end);
+}
 
 }  // namespace codegen
 }  // namespace tvm
 #endif  // TVM_LLVM_VERSION
diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h
index 36fa0389ad45..3f7c197c270d 100644
--- a/src/codegen/llvm/codegen_llvm.h
+++ b/src/codegen/llvm/codegen_llvm.h
@@ -152,10 +152,12 @@ class CodeGenLLVM : public IRVisitor {
   llvm::StructType* t_tvm_type_{nullptr};
   llvm::StructType* t_tvm_array_{nullptr};
   llvm::StructType* t_tvm_value_{nullptr};
+  llvm::FunctionType* t_f_tvm_par_for_lambda_{nullptr};
   // tvm api functions
   llvm::Function* f_tvm_func_call_{nullptr};
   llvm::Function* f_tvm_get_func_from_env_{nullptr};
   llvm::Function* f_tvm_api_set_last_error_{nullptr};
+  llvm::Function* f_tvm_parallel_for_{nullptr};
   // The acting body
   llvm::BasicBlock* block_{nullptr};
   // Last value returned codegen call.
@@ -176,10 +178,15 @@ class CodeGenLLVM : public IRVisitor {
   llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index);
   llvm::Value* CreateCast(Type from, Type to, llvm::Value* value);
   llvm::Value* GetPackedFuncHandle(const std::string& str);
+  // Create parallel for.
+  void CreateParallelFor(const For* op);
+  // Create serial for.
+  void CreateSerialFor(llvm::Value* begin, llvm::Value* end,
+                       const VarExpr& loop_var, const Stmt& body);
   // Check if the call to packed function is successful
   // if not directly finalize function and pass on return code.
   // return the end block after the check
-  llvm::BasicBlock* CheckPackedCallSuccess(llvm::Value* retcode);
+  llvm::BasicBlock* CheckCallSuccess(llvm::Value* retcode);
   // Initialize target
   void InitTarget(const std::string& target);
   // Add a function to set global module context
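To make the lowering above easier to follow, here is a rough, hand-written C++ equivalent of what CreateParallelFor emits for a loop such as a parallel vector add. It is a sketch for illustration only: the real output is LLVM IR, and the Closure/EmittedHost names are invented here; only the __tvm_par_for_lambda symbol and the TVMBackendParallelFor call come from the patch.

#include <tvm/runtime/c_runtime_api.h>
#include <cstdint>

// The loop body's undefined vars (here A, B, C) become fields of a closure
// struct that is alloca'd on the host stack.
struct Closure {
  float* A;
  float* B;
  float* C;
};

// Private lambda emitted per parallel loop: unpack the closure, then run
// the plain serial loop (CreateSerialFor) over the assigned range.
static int __tvm_par_for_lambda(int64_t begin, int64_t end, void* env) {
  Closure* c = static_cast<Closure*>(env);
  for (int64_t i = begin; i < end; ++i) {
    c->C[i] = c->A[i] + c->B[i];
  }
  return 0;  // mirrors builder_->CreateRet(ConstInt32(0))
}

// Host side: store the captured values into the closure and hand the range
// to the runtime; the return code is then checked via CheckCallSuccess.
int EmittedHost(float* A, float* B, float* C, int64_t n) {
  Closure cdata = {A, B, C};
  return TVMBackendParallelFor(0, n, __tvm_par_for_lambda, &cdata);
}
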
diff --git a/src/lang/lowered_func.cc b/src/lang/lowered_func.cc
new file mode 100644
index 000000000000..c199e7faa65d
--- /dev/null
+++ b/src/lang/lowered_func.cc
@@ -0,0 +1,16 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \file lowered_func.cc
+ */
+#include <tvm/lowered_func.h>
+
+namespace tvm {
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<LoweredFuncNode>([](const LoweredFuncNode *op, IRPrinter *p) {
+    p->stream << "LoweredFunc(" << op->name << ", " << op << ")";
+});
+
+TVM_REGISTER_NODE_TYPE(LoweredFuncNode);
+
+}  // namespace tvm
diff --git a/src/pass/make_api.cc b/src/pass/make_api.cc
index a9bff236b163..e4a2e76267bf 100644
--- a/src/pass/make_api.cc
+++ b/src/pass/make_api.cc
@@ -188,7 +188,7 @@ LoweredFunc MakeAPI(Stmt body,
   n->is_packed_func = num_unpacked_args == 0;
   n->body = MergeNest({seq_init, seq_check}, body);
   LoweredFunc f(n);
-  Array<Var> undefined = UndefinedVars(f);
+  Array<Var> undefined = UndefinedVars(f->body, f->args);
   if (undefined.size() != 0) {
     std::ostringstream os;
     for (Var v : undefined) {
diff --git a/src/pass/split_host_device.cc b/src/pass/split_host_device.cc
index c832b726ffce..642c1ed12fd0 100644
--- a/src/pass/split_host_device.cc
+++ b/src/pass/split_host_device.cc
@@ -220,12 +220,12 @@ class HostDeviceSplitter : public IRMutator {
 };
 
 
-Array<Var> UndefinedVars(const LoweredFunc& f) {
+Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) {
   IRUseDefAnalysis m;
-  for (Var arg : f->args) {
+  for (Var arg : args) {
     m.use_count_[arg.get()] = 0;
   }
-  m.Mutate(f->body);
+  m.Mutate(stmt);
   return m.undefined_;
 }
diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc
index 3ffe02ed518e..837e884dd6e3 100644
--- a/src/runtime/c_runtime_api.cc
+++ b/src/runtime/c_runtime_api.cc
@@ -7,8 +7,11 @@
 #include
 #include
 #include
+#include
 #include
 #include
+#include
+#include
 #include "./runtime_base.h"
 #include "./device_api.h"
@@ -71,6 +74,24 @@ using namespace tvm::runtime;
 struct TVMRuntimeEntry {
   std::string ret_str;
   std::string last_error;
+  // threads used in parallel for
+  std::vector<std::thread> par_threads;
+  // errors created in parallel for.
+  std::vector<std::string> par_errors;
+  // number of parallel threads
+  int num_par_threads{1};
+
+  TVMRuntimeEntry() {
+    const char *val = getenv("TVM_NUM_THREADS");
+    if (val == nullptr) {
+      val = getenv("OMP_NUM_THREADS");
+    }
+    if (val != nullptr) {
+      num_par_threads = atoi(val);
+    } else {
+      num_par_threads = std::thread::hardware_concurrency();
+    }
+  }
 };
 
 typedef dmlc::ThreadLocalStore<TVMRuntimeEntry> TVMAPIRuntimeStore;
@@ -123,6 +144,12 @@ int TVMModPreCompile(TVMModuleHandle mod,
   API_END();
 }
 
+int TVMModFree(TVMModuleHandle mod) {
+  API_BEGIN();
+  delete static_cast(mod);
+  API_END();
+}
+
 int TVMBackendGetFuncFromEnv(void* mod_node,
                              const char* func_name,
                              TVMFunctionHandle *func) {
@@ -132,10 +159,44 @@ int TVMBackendGetFuncFromEnv(void* mod_node,
   API_END();
 }
 
-int TVMModFree(TVMModuleHandle mod) {
-  API_BEGIN();
-  delete static_cast(mod);
-  API_END();
+int TVMBackendParallelFor(
+    int64_t begin,
+    int64_t end,
+    int (*lambda)(int64_t begin, int64_t end, void* env),
+    void* env) {
+  TVMRuntimeEntry* rt = TVMAPIRuntimeStore::Get();
+  int nthread = rt->num_par_threads;
+  rt->par_threads.resize(nthread);
+  rt->par_errors.clear();
+  rt->par_errors.resize(nthread);
+  int64_t step = (end - begin + nthread - 1) / nthread;
+  auto fexec = [lambda, env, begin, end, step, rt](int i) {
+    int64_t ibegin = std::min(end, begin + step * i);
+    int64_t iend = std::min(end, begin + step * (i + 1));
+    int rv = (*lambda)(ibegin, iend, env);
+    if (rv != 0) {
+      std::ostringstream os;
+      os << "Thread " << i << " error:" << TVMGetLastError();
+      rt->par_errors[i] = os.str();
+    }
+  };
+  for (int i = 0; i < nthread; ++i) {
+    rt->par_threads[i] = std::thread(fexec, i);
+  }
+  int ret = 0;
+  for (int i = 0; i < nthread; ++i) {
+    rt->par_threads[i].join();
+    if (rt->par_errors[i].length() != 0) ret = -1;
+  }
+  if (ret == 0) return ret;
+  std::ostringstream os;
+  for (int i = 0; i < nthread; ++i) {
+    if (rt->par_errors[i].length() != 0) {
+      os << rt->par_errors[i] << '\n';
+    }
+  }
+  rt->last_error = os.str();
+  return -1;
 }
 
 int TVMFuncFree(TVMFunctionHandle func) {
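The implementation above assigns thread i the half-open range [begin + step * i, min(end, begin + step * (i + 1))) with step = (end - begin + nthread - 1) / nthread. The small stand-alone sketch below, which is not part of the patch, reproduces that partition; for begin = 0, end = 10, nthread = 4 it prints the chunks [0, 3), [3, 6), [6, 9), [9, 10).

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t begin = 0, end = 10;
  const int nthread = 4;
  // Same chunking as TVMBackendParallelFor: thread i gets
  // [begin + step * i, min(end, begin + step * (i + 1))).
  const int64_t step = (end - begin + nthread - 1) / nthread;
  for (int i = 0; i < nthread; ++i) {
    const int64_t ibegin = std::min(end, begin + step * i);
    const int64_t iend = std::min(end, begin + step * (i + 1));
    std::printf("thread %d: [%lld, %lld)\n", i,
                static_cast<long long>(ibegin),
                static_cast<long long>(iend));
  }
  return 0;
}
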
diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc
index 308070a8b702..c384d465f69c 100644
--- a/src/schedule/schedule_lang.cc
+++ b/src/schedule/schedule_lang.cc
@@ -69,6 +69,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
     switch (op->iter_type) {
       case kUnrolled: p->stream << "unroll"; break;
       case kVectorized: p->stream << "vectorize"; break;
+      case kParallel: p->stream << "parallel"; break;
     }
   });
 
@@ -246,6 +247,11 @@ Stage& Stage::unroll(IterVar var) {  // NOLINT(*)
   return *this;
 }
 
+Stage& Stage::parallel(IterVar var) {  // NOLINT(*)
+  SetAttr(operator->(), var, IterVarAttr(kParallel));
+  return *this;
+}
+
 Schedule::Schedule(Array<Operation> ops) {
   auto n = std::make_shared<ScheduleNode>();
   n->outputs = ops;
diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc
index 11b6d354dfab..f2f50750c145 100644
--- a/src/schedule/schedule_ops.cc
+++ b/src/schedule/schedule_ops.cc
@@ -189,6 +189,7 @@ MakeLoopNest(const Stage& sch,
     if (sch->iter_var_attrs.count(iv)) {
       switch (sch->iter_var_attrs[iv]->iter_type) {
         case kUnrolled: for_type = ForType::Unrolled; break;
+        case kParallel: for_type = ForType::Parallel; break;
        case kVectorized: for_type = ForType::Vectorized; break;
       }
     }
diff --git a/tests/python/unittest/test_codegen_llvm.py b/tests/python/unittest/test_codegen_llvm.py
new file mode 100644
index 000000000000..fed6cb6f283c
--- /dev/null
+++ b/tests/python/unittest/test_codegen_llvm.py
@@ -0,0 +1,31 @@
+import tvm
+import numpy as np
+
+def test_llvm_add_pipeline():
+    n = tvm.Var('n')
+    A = tvm.placeholder((n,), name='A')
+    B = tvm.placeholder((n,), name='B')
+    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
+    s = tvm.Schedule(C.op)
+    s[C].parallel(C.op.axis[0])
+
+    def check_llvm():
+        if not tvm.codegen.enabled("llvm"):
+            return
+        # build and invoke the kernel.
+        f = tvm.build(s, [A, B, C], "llvm")
+        ctx = tvm.cpu(0)
+        # launch the kernel.
+        n = 10270 * 2460
+        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
+        c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
+        for i in range(1000):
+            f(a, b, c)
+        np.testing.assert_allclose(
+            c.asnumpy(), a.asnumpy() + b.asnumpy())
+    check_llvm()
+
+
+if __name__ == "__main__":
+    test_llvm_add_pipeline()
diff --git a/tests/python/unittest/test_codegen_stack_llvm.py b/tests/python/unittest/test_codegen_vm_basic.py
similarity index 66%
rename from tests/python/unittest/test_codegen_stack_llvm.py
rename to tests/python/unittest/test_codegen_vm_basic.py
index caaa056baa01..13d42aa8d638 100644
--- a/tests/python/unittest/test_codegen_stack_llvm.py
+++ b/tests/python/unittest/test_codegen_vm_basic.py
@@ -78,40 +78,7 @@ def check(f):
     np.testing.assert_equal(a.asnumpy(), y)
     run_jit(fapi, check)
-
-def test_llvm_add_pipeline():
-    n = tvm.Var('n')
-    A = tvm.placeholder((n,), name='A')
-    B = tvm.placeholder((n,), name='B')
-    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
-    s = tvm.Schedule(C.op)
-    bounds = tvm.schedule.InferBound(s)
-    stmt = tvm.schedule.ScheduleOps(s, bounds)
-    Ab = tvm.Buffer(A.shape, A.dtype, name='A')
-    Bb = tvm.Buffer(B.shape, B.dtype, name='B')
-    Cb = tvm.Buffer(C.shape, C.dtype, name='C')
-    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb})
-    stmt = tvm.ir_pass.Simplify(stmt)
-    fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 0)
-
-    def check_llvm():
-        if not tvm.codegen.enabled("llvm"):
-            return
-        # build and invoke the kernel.
-        f = tvm.codegen.build(fapi, "llvm")
-        ctx = tvm.cpu(0)
-        # launch the kernel.
-        n = 1027
-        a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
-        b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx)
-        c = tvm.nd.array(np.zeros(n, dtype=Cb.dtype), ctx)
-        f(a, b, c)
-        np.testing.assert_allclose(
-            c.asnumpy(), a.asnumpy() + b.asnumpy())
-    check_llvm()
-
 
 if __name__ == "__main__":
     test_stack_vm_basic()
     test_stack_vm_cond()
     test_stack_vm_loop()
-    test_llvm_add_pipeline()