From 45597d00061c155f7844b873149d36a48467c959 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 9 Feb 2017 12:15:47 -0800 Subject: [PATCH] [LANG/PASS] Support Vectorize (#37) --- include/tvm/ir_mutator.h | 1 + include/tvm/ir_pass.h | 6 + include/tvm/schedule.h | 57 ++- python/tvm/build.py | 1 + python/tvm/schedule.py | 20 + src/api/api_lang.cc | 12 + src/api/api_pass.cc | 1 + src/arithmetic/compute_expr.h | 18 + src/codegen/codegen_c.cc | 294 ++++++++++---- src/codegen/codegen_c.h | 67 ++-- src/codegen/codegen_cuda.cc | 105 ++++- src/codegen/codegen_cuda.h | 9 + src/codegen/codegen_opencl.cc | 87 ++++- src/codegen/codegen_opencl.h | 10 + src/pass/ir_mutator.cc | 32 +- src/pass/split_host_device.cc | 8 +- src/pass/unroll_loop.cc | 11 +- src/pass/vectorize_loop.cc | 385 +++++++++++++++++++ src/schedule/schedule_lang.cc | 39 ++ src/schedule/schedule_ops.cc | 11 +- tests/python/integration/test_ewise.py | 38 +- tests/python/unittest/test_lang_schedule.py | 16 + tests/python/unittest/test_pass_unroll.py | 6 +- tests/python/unittest/test_pass_vectorize.py | 24 ++ 24 files changed, 1119 insertions(+), 139 deletions(-) create mode 100644 src/pass/vectorize_loop.cc create mode 100644 tests/python/unittest/test_pass_vectorize.py diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h index b57bca25eb49..649ea0239ef1 100644 --- a/include/tvm/ir_mutator.h +++ b/include/tvm/ir_mutator.h @@ -62,6 +62,7 @@ class IRMutator { virtual Stmt Mutate_(const Realize* op, const Stmt& s); virtual Stmt Mutate_(const Store* op, const Stmt& s); virtual Stmt Mutate_(const Free* op, const Stmt& s); + virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s); virtual Expr Mutate_(const Call* op, const Expr& e); virtual Expr Mutate_(const Load* op, const Expr& s); virtual Expr Mutate_(const Variable* op, const Expr& e); diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 9e3e1b0a1d53..4ce90e3b7739 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -111,6 +111,12 @@ Stmt StorageFlatten(Stmt stmt, */ Stmt UnrollLoop(Stmt stmt, int max_auto_step); +/*! + * \brief vectorize the constant loops + * \param stmt The statment to be vectorized. + */ +Stmt VectorizeLoop(Stmt stmt); + /*! * \brief Make an user callable API LoweredFunc. * diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h index a7cd58c96524..b8c24903f88f 100644 --- a/include/tvm/schedule.h +++ b/include/tvm/schedule.h @@ -18,6 +18,8 @@ class StageNode; class ScheduleNode; // Node container for IterVarRelation class IterVarRelationNode; +// Attribute of itervar. +class IterVarAttrNode; /*! \brief the attachment type */ enum AttachType : int { @@ -27,6 +29,12 @@ enum AttachType : int { kScope = 3 }; +/*! \brief IterVar type */ +enum IterVarType : int { + kUnrolled = 1, + kVectorized = 2 +}; + /*! \brief Stage, contains scheduling for a stage of computation. */ class Stage : public NodeRef { public: @@ -123,12 +131,23 @@ class Stage : public NodeRef { IterVar* p_x_outer, IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner, Expr x_factor, Expr y_factor); + /*! + * \brief Vectorize iteration. + * \param var The axis to be vectorized. + * \return reference to self. + */ + Stage& vectorize(IterVar var); // NOLINT(*) + /*! + * \brief Unroll iteration. + * \param var The axis to be vectorized. + * \return reference to self. + */ + Stage& unroll(IterVar var); // NOLINT(*) /*! * \brief whether the stage has been scheduled. * \return whether the stage has been scheduled. */ inline bool is_scheduled() const; - // declare container type using ContainerType = StageNode; }; @@ -193,6 +212,21 @@ class IterVarRelation : public NodeRef { inline const IterVarRelationNode* operator->() const; }; +/*! + * \brief Additional scheduable attributes about IterVar. + */ +class IterVarAttr : public NodeRef { + public: + IterVarAttr() {} + explicit IterVarAttr(IterVarType t); + explicit IterVarAttr(std::shared_ptr n) : NodeRef(n) {} + /*! + * \brief access the internal node container + * \return the pointer to the internal node container + */ + inline const IterVarAttrNode* operator->() const; +}; + // defintion of node containers /*! * \brief represents the schedule of the tensor @@ -223,6 +257,8 @@ class StageNode : public Node { Array leaf_iter_vars; /*! \brief The relation bwteen of IterVars */ Array relations; + /*! \brief additional attributes about iter var. */ + Map iter_var_attrs; /*! \brief The attachment type of the schedule */ AttachType attach_type{kNone}; /*! \brief The attach point of this schedule. */ @@ -236,6 +272,7 @@ class StageNode : public Node { v->Visit("all_iter_vars", &all_iter_vars); v->Visit("leaf_iter_vars", &leaf_iter_vars); v->Visit("relations", &relations); + v->Visit("iter_var_attrs", &iter_var_attrs); v->Visit("attach_type", &attach_type); v->Visit("attach_ivar", &attach_ivar); v->Visit("attach_stage", &attach_stage); @@ -268,6 +305,20 @@ class ScheduleNode : public Node { TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode); }; +/*! \brief node container for IterVar attr */ +class IterVarAttrNode : public Node { + public: + /*! \brief The iteration type. */ + IterVarType iter_type; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("iter_type", &iter_type); + } + + static constexpr const char* _type_key = "IterVarAttr"; + TVM_DECLARE_NODE_TYPE_INFO(IterVarAttrNode); +}; + /*! \brief base node of iteration var */ class IterVarRelationNode : public Node { }; @@ -372,5 +423,9 @@ inline const IterVarRelationNode* IterVarRelation::operator->() const { return static_cast(node_.get()); } +inline const IterVarAttrNode* IterVarAttr::operator->() const { + return static_cast(node_.get()); +} + } // namespace tvm #endif // TVM_SCHEDULE_H_ diff --git a/python/tvm/build.py b/python/tvm/build.py index bb03e8395687..7ddabaf7631d 100644 --- a/python/tvm/build.py +++ b/python/tvm/build.py @@ -69,6 +69,7 @@ def build(sch, stmt = schedule.ScheduleOps(sch, bounds) stmt = ir_pass.StorageFlatten(stmt, binds) stmt = ir_pass.CanonicalSimplify(stmt) + stmt = ir_pass.VectorizeLoop(stmt) stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step) stmt = ir_pass.Simplify(stmt) fapi = ir_pass.MakeAPI(stmt, name, arg_list, len(arg_list)) diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py index fee0fb3b1274..842b9a605425 100644 --- a/python/tvm/schedule.py +++ b/python/tvm/schedule.py @@ -177,3 +177,23 @@ def tile(self, x_parent, y_parent, x_factor, y_factor): x_outer, y_outer, x_inner, y_inner = _api_internal._StageTile( self, x_parent, y_parent, x_factor, y_factor) return x_outer, y_outer, x_inner, y_inner + + def vectorize(self, var): + """Vectorize the iteration. + + Parameters + ---------- + var : IterVar + The iteration to be vectorize + """ + _api_internal._StageVectorize(self, var) + + def unroll(self, var): + """Unroll the iteration. + + Parameters + ---------- + var : IterVar + The iteration to be unrolled. + """ + _api_internal._StageUnroll(self, var) diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 96c61a76227b..769345fc415e 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -253,6 +253,18 @@ TVM_REGISTER_API(_StageTile) *ret = Array({x_outer, y_outer, x_inner, y_inner}); }); +TVM_REGISTER_API(_StageUnroll) + .set_body([](TVMArgs args, TVMRetValue* ret) { + args[0].operator Stage() + .unroll(args[1]); + }); + +TVM_REGISTER_API(_StageVectorize) + .set_body([](TVMArgs args, TVMRetValue* ret) { + args[0].operator Stage() + .vectorize(args[1]); + }); + TVM_REGISTER_API(_ScheduleNormalize) .set_body([](TVMArgs args, TVMRetValue* ret) { args[0].operator Schedule() diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index ff67ac7a867b..b8f3cbc3bd9e 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -62,6 +62,7 @@ REGISTER_PASS1(VerifySSA); REGISTER_PASS1(CanonicalSimplify); REGISTER_PASS4(Inline); REGISTER_PASS2(StorageFlatten); +REGISTER_PASS1(VectorizeLoop); REGISTER_PASS2(UnrollLoop); REGISTER_PASS2(StorageSync); REGISTER_PASS4(MakeAPI); diff --git a/src/arithmetic/compute_expr.h b/src/arithmetic/compute_expr.h index 9550c1c96d2c..471bda5634b4 100644 --- a/src/arithmetic/compute_expr.h +++ b/src/arithmetic/compute_expr.h @@ -9,6 +9,7 @@ #include #include +#include namespace tvm { namespace arith { @@ -52,6 +53,23 @@ inline bool GetConst(Expr e, uint64_t *out) { } } +// get a small constant int +inline bool GetConstInt(Expr e, int* out) { + int64_t v1 = 0; + uint64_t v2 = 0; + if (GetConst(e, &v1)) { + if (v1 > static_cast( + std::numeric_limits::max())) return false; + *out = static_cast(v1); return true; + } + if (GetConst(e, &v2)) { + if (v2 > static_cast( + std::numeric_limits::max())) return false; + *out = static_cast(v2); return true; + } + return false; +} + #define TVM_CONST_PROPAGATION(OP_NAME, OP) \ int64_t ia = 0, ib = 0; \ if (GetConst(a, &ia) && GetConst(b, &ib)) { \ diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index 2f61e3be920f..52acf8ead925 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -3,7 +3,9 @@ * \file codegen_c.cc */ #include +#include #include "./codegen_c.h" +#include "../arithmetic/compute_expr.h" namespace tvm { namespace codegen { @@ -14,10 +16,10 @@ std::string CodeGenC::Compile(LoweredFunc f, bool output_ssa) { print_ssa_form_ = output_ssa; // skip the first underscore, so SSA variable starts from _1 - if (print_ssa_form_) GetUniqueName("_"); + GetUniqueName("_"); // add to alloc buffer type. for (const auto & kv : f->handle_data_type) { - HandleTypeRegister(kv.first.get(), kv.second.type()); + RegisterHandleType(kv.first.get(), kv.second.type()); } this->stream << "void " << f->name << "("; @@ -26,7 +28,11 @@ std::string CodeGenC::Compile(LoweredFunc f, std::string vid = AllocVarID(v.get()); if (i != 0) stream << ", "; if (v.type().is_handle()) { - stream << arg_addr_space_; + auto it = alloc_storage_scope_.find(v.get()); + if (it != alloc_storage_scope_.end()) { + PrintStorageScope(it->second, stream); + } + stream << ' '; } if (handle_data_type_.count(v.get())) { PrintType(handle_data_type_.at(v.get()), stream); @@ -126,7 +132,7 @@ bool CodeGenC::HandleTypeMatch(const Variable* buf_var, Type t) const { return it->second == t; } -void CodeGenC::HandleTypeRegister(const Variable* buf_var, Type t) { +void CodeGenC::RegisterHandleType(const Variable* buf_var, Type t) { auto it = handle_data_type_.find(buf_var); if (it == handle_data_type_.end()) { handle_data_type_[buf_var] = t; @@ -259,23 +265,39 @@ inline void PrintBinaryExpr(const T* op, const char *opstr, std::ostream& os, // NOLINT(*) CodeGenC* p) { - os << '('; - p->PrintExpr(op->a, os); - os << opstr; - p->PrintExpr(op->b, os); - os << ')'; + if (op->type.lanes() == 1) { + if (isalpha(opstr[0])) { + os << opstr << '('; + p->PrintExpr(op->a, os); + os << ", "; + p->PrintExpr(op->b, os); + os << ')'; + } else { + os << '('; + p->PrintExpr(op->a, os); + os << ' ' << opstr << ' '; + p->PrintExpr(op->b, os); + os << ')'; + } + } else { + p->PrintVecBinaryOp(opstr, op->type, op->a, op->b, os); + } } inline void PrintBinaryIntrinsitc(const Call* op, const char *opstr, std::ostream& os, // NOLINT(*) CodeGenC* p) { - CHECK_EQ(op->args.size(), 2U); - os << '('; - p->PrintExpr(op->args[0], os); - os << opstr; - p->PrintExpr(op->args[1], os); - os << ')'; + if (op->type.lanes() == 1) { + CHECK_EQ(op->args.size(), 2U); + os << '('; + p->PrintExpr(op->args[0], os); + os << opstr; + p->PrintExpr(op->args[1], os); + os << ')'; + } else { + p->PrintVecBinaryOp(opstr, op->type, op->args[0], op->args[1], os); + } } TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr) @@ -289,57 +311,49 @@ TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr) os << p->GetVarID(op); }) .set_dispatch([](const Add *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintBinaryExpr(op, " + ", os, p); + PrintBinaryExpr(op, "+", os, p); }) .set_dispatch([](const Sub *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintBinaryExpr(op, " - ", os, p); + PrintBinaryExpr(op, "-", os, p); }) .set_dispatch([](const Mul *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintBinaryExpr(op, " * ", os, p); + PrintBinaryExpr(op, "*", os, p); }) .set_dispatch
([](const Div *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintBinaryExpr(op, " / ", os, p); + PrintBinaryExpr(op, "/", os, p); }) .set_dispatch([](const Mod *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintBinaryExpr(op, " % ", os, p); + PrintBinaryExpr(op, "%", os, p); }) .set_dispatch([](const Min *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - os << "min("; - p->PrintExpr(op->a, os); - os << ", "; - p->PrintExpr(op->b, os); - os << ")"; + PrintBinaryExpr(op, "min", os, p); }) .set_dispatch([](const Max *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - os << "max("; - p->PrintExpr(op->a, os); - os << ", "; - p->PrintExpr(op->b, os); - os << ")"; + PrintBinaryExpr(op, "max", os, p); }) .set_dispatch([](const EQ *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintBinaryExpr(op, " == ", os, p); + PrintBinaryExpr(op, "==", os, p); }) .set_dispatch([](const NE *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintBinaryExpr(op, " != ", os, p); + PrintBinaryExpr(op, "!=", os, p); }) .set_dispatch([](const LT *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintBinaryExpr(op, " < ", os, p); + PrintBinaryExpr(op, "<", os, p); }) .set_dispatch([](const LE *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintBinaryExpr(op, " <= ", os, p); + PrintBinaryExpr(op, "<=", os, p); }) .set_dispatch([](const GT *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintBinaryExpr(op, " > ", os, p); + PrintBinaryExpr(op, ">", os, p); }) .set_dispatch([](const GE *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintBinaryExpr(op, " >= ", os, p); + PrintBinaryExpr(op, ">=", os, p); }) .set_dispatch([](const And *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintBinaryExpr(op, " && ", os, p); + PrintBinaryExpr(op, "&&", os, p); }) .set_dispatch([](const Or *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintBinaryExpr(op, " || ", os, p); + PrintBinaryExpr(op, "||", os, p); }) .set_dispatch([](const Not *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) os << '!'; @@ -460,18 +474,179 @@ void CodeGenC::PrintExpr(const Call *op, std::ostream& os) { // NOLINT(*) } } +void CodeGenC::PrintVecBinaryOp( + const std::string&op, Type t, + Expr lhs, Expr rhs, std::ostream& os) { // NOLINT(*) + if (isalpha(op[0])) { + os << op << "("; + this->PrintExpr(lhs, os); + os << ", "; + this->PrintExpr(rhs, os); + os << ")"; + } else { + os <<"("; + this->PrintExpr(lhs, os); + os << ' ' << op << ' '; + this->PrintExpr(rhs, os); + os << ")"; + } +} + +inline bool TryGetRamp1Base(Expr index, int lanes, Expr *base) { + const Ramp* r = index.as(); + if (!r) return false; + if (!is_one(r->stride)) return false; + CHECK_EQ(r->lanes, lanes); + *base = r->base; + return true; +} + +// Print a reference expression to a buffer. +void CodeGenC::PrintBufferRef( + const Variable* buffer, + Type t, Expr index, + std::ostream& os) { // NOLINT(*) + std::string vid = GetVarID(buffer); + if (t.lanes() == 1) { + if (!HandleTypeMatch(buffer, t)) { + os << "(("; + PrintType(t, os); + os << "*)" << vid << ')'; + } else { + os << vid; + } + os << '['; + PrintExpr(index, os); + os << ']'; + } else { + // Buffer declared as vector type. + // optimize for case where it is in register, + if (HandleTypeMatch(buffer, t)) { + // optimize for constant access + int offset; + if (arith::GetConstInt(index, &offset)) { + CHECK_EQ(offset % t.lanes(), 0) + << "Find unaligned vector load to a vector type"; + os << vid << '[' << (offset / t.lanes()) << ']'; + return; + } + } + os << "(("; + PrintType(t, os); + os << "*)("; + if (!HandleTypeMatch(buffer, t.element_of())) { + os << '('; + PrintType(t.element_of(), os); + os << "*)"; + } + os << vid << " + "; + PrintExpr(index, os); + os << "))[0]"; + } +} + void CodeGenC::PrintExpr(const Load* op, std::ostream& os) { // NOLINT(*) - std::string vid = GetVarID(op->buffer_var.get()); - if (!HandleTypeMatch(op->buffer_var.get(), op->type)) { - os << "((const "; - PrintType(op->type, os); - os << "*)" << vid << ')'; + int lanes = op->type.lanes(); + if (op->type.lanes() == 1) { + this->PrintBufferRef(op->buffer_var.get(), op->type, op->index, os); } else { - os << vid; + Expr base; + if (TryGetRamp1Base(op->index, op->type.lanes(), &base)) { + this->PrintVecLoad(op->buffer_var.get(), op->type, base, os); + } else { + // Load elements seperately + std::string sindex = SSAGetID(PrintExpr(op->index), op->index.type()); + std::string svalue = GetUniqueName("_"); + { + // delcare type. + this->PrintIndent(); + this->PrintType(op->type, stream); + stream << ' ' << svalue << ";\n"; + } + std::string vid = GetVarID(op->buffer_var.get()); + Type elem_type = op->type.element_of(); + for (int i = 0; i < lanes; ++i) { + std::ostringstream value_temp; + if (!HandleTypeMatch(op->buffer_var.get(), elem_type)) { + value_temp << "(("; + PrintType(elem_type, os); + value_temp << "*)" << vid << ')'; + } else { + value_temp << vid; + } + value_temp << '['; + PrintVecElemLoad(sindex, op->index.type(), i, value_temp); + value_temp << ']'; + PrintVecElemStore(svalue, op->type, i, value_temp.str()); + } + os << svalue; + } } - os << '['; - PrintExpr(op->index, os); - os << ']'; +} + +void CodeGenC::PrintStmt(const Store* op) { + Type t = op->value.type(); + if (t.lanes() == 1) { + this->PrintIndent(); + std::string value = this->PrintExpr(op->value); + this->PrintBufferRef(op->buffer_var.get(), t, op->index, stream); + stream << " = " << value << ";\n"; + } else { + Expr base; + if (TryGetRamp1Base(op->index, t.lanes(), &base)) { + std::string value = this->PrintExpr(op->value); + this->PrintVecStore(op->buffer_var.get(), t, base, value); + } else { + // store elements seperately + std::string index = SSAGetID(PrintExpr(op->index), op->index.type()); + std::string value = SSAGetID(PrintExpr(op->value), op->value.type()); + std::string vid = GetVarID(op->buffer_var.get()); + for (int i = 0; i < t.lanes(); ++i) { + this->PrintIndent(); + Type elem_type = t.element_of(); + if (!HandleTypeMatch(op->buffer_var.get(), elem_type)) { + stream << "(("; + PrintType(elem_type, stream); + stream << "*)" << vid << ')'; + } else { + stream << vid; + } + stream << '['; + PrintVecElemLoad(index, op->index.type(), i, stream); + stream << "] = "; + PrintVecElemLoad(value, op->value.type(), i, stream); + stream << ";\n"; + } + } + } +} + +void CodeGenC::PrintVecElemLoad(const std::string& vec, + Type t, int i, + std::ostream& os) { // NOLINT(*) + os << vec << ".s" << std::hex << i; +} + +void CodeGenC::PrintVecElemStore(const std::string& vec, + Type t, int i, + const std::string& value) { + this->PrintIndent(); + stream << vec << ".s" << std::hex << i + << " = " << value << ";\n"; +} + +void CodeGenC::PrintVecLoad(const Variable* buffer, + Type t, Expr base, + std::ostream& os) { + PrintBufferRef(buffer, t, base, os); +} + +void CodeGenC::PrintVecStore(const Variable* buffer, + Type t, Expr base, + const std::string& value) { + this->PrintIndent(); + PrintBufferRef(buffer, t, base, stream); + stream << " = " << value << ";\n"; } void CodeGenC::PrintExpr(const Let* op, std::ostream& os) { // NOLINT(*) @@ -483,15 +658,15 @@ void CodeGenC::PrintExpr(const Let* op, std::ostream& os) { // NOLINT(*) } void CodeGenC::PrintExpr(const Ramp* op, std::ostream& os) { // NOLINT(*) - LOG(FATAL) << "not supported "; + LOG(FATAL) << "Ramp: not supported "; } void CodeGenC::PrintExpr(const Broadcast* op, std::ostream& os) { // NOLINT(*) - LOG(FATAL) << "not supported "; + LOG(FATAL) << "Broadcast: not supported "; } void CodeGenC::PrintExpr(const Select* op, std::ostream& os) { // NOLINT(*) - LOG(FATAL) << "not supported "; + LOG(FATAL) << "Select: not supported "; } // Disoatch back to member functions @@ -541,23 +716,6 @@ void CodeGenC::PrintStmt(const LetStmt* op) { PrintStmt(op->body); } -void CodeGenC::PrintStmt(const Store* op) { - std::string index = this->PrintExpr(op->index); - std::string value = this->PrintExpr(op->value); - this->PrintIndent(); - std::string vid = GetVarID(op->buffer_var.get()); - if (!HandleTypeMatch(op->buffer_var.get(), op->value.type())) { - this->stream << "(("; - PrintType(op->value.type(), this->stream); - this->stream << "*)" << vid << ')'; - } else { - this->stream << vid; - } - this->stream << '[' << index - << "] = " << value - << ";\n"; -} - void CodeGenC::PrintStmt(const Allocate* op) { CHECK(!is_zero(op->condition)); std::string vid = AllocVarID(op->buffer_var.get()); @@ -580,7 +738,7 @@ void CodeGenC::PrintStmt(const Allocate* op) { stream << ' '<< vid << '[' << constant_size << "];\n"; } - HandleTypeRegister(op->buffer_var.get(), op->type); + RegisterHandleType(op->buffer_var.get(), op->type); this->PrintStmt(op->body); } diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h index d4e70379eee9..04c22b3d2a0e 100644 --- a/src/codegen/codegen_c.h +++ b/src/codegen/codegen_c.h @@ -102,6 +102,20 @@ class CodeGenC { virtual void PrintExpr(const ir::Ramp* op, std::ostream& os); // NOLINT(*) virtual void PrintExpr(const ir::Broadcast* op, std::ostream& os); // NOLINT(*) virtual void PrintExpr(const ir::Select* op, std::ostream& os); // NOLINT(*) + // Binary vector op. + virtual void PrintVecBinaryOp( + const std::string&op, Type op_type, + Expr lhs, Expr rhs, std::ostream& os); // NOLINT(*) + virtual void PrintVecLoad(const Variable* buffer, + Type t, Expr base, + std::ostream& os); // NOLINT(*) + virtual void PrintVecStore(const Variable* buffer, + Type t, Expr base, + const std::string& value); // NOLINT(*) + virtual void PrintVecElemLoad( + const std::string& vec, Type t, int i, std::ostream& os); // NOLINT(*) + virtual void PrintVecElemStore( + const std::string& vec, Type t, int i, const std::string& value); /*! \brief function print into the ostream */ using FPrintExpr = IRFunctor; // NOLINT(*) /*! \brief function to to print normal code */ @@ -116,17 +130,10 @@ class CodeGenC { std::ostringstream stream; protected: - // additional string for arg addr_space. - std::string arg_addr_space_; - - private: - /*! \brief entry in ssa assign map */ - struct SSAEntry { - /*! \brief The value id */ - std::string vid; - /*! \brief The scope id */ - int scope_id; - }; + // print reference to a buffer as type t in index. + void PrintBufferRef(const Variable* buffer, + Type t, Expr index, + std::ostream& os); // NOLINT(*) /*! * \brief Get the SSA ID corresponds to src * If necessary, generate new assignment @@ -134,6 +141,19 @@ class CodeGenC { * \param t The type of the expression. */ std::string SSAGetID(std::string src, Type t); + /*! + * \brief get a unique name with the corresponding prefix + * \param prefix The prefix of the name + * \return The returned name. + */ + std::string GetUniqueName(std::string prefix); + /*! \brief entry in ssa assign map */ + struct SSAEntry { + /*! \brief The value id */ + std::string vid; + /*! \brief The scope id */ + int scope_id; + }; /*! * \brief mark the beginning of a new scope * \return The scope id. @@ -155,25 +175,28 @@ class CodeGenC { * \param buf_var The buffer variable. * \param t The type to be checked. */ - void HandleTypeRegister(const Variable* buf_var, Type t); + void RegisterHandleType(const Variable* buf_var, Type t); /*! - * \brief get a unique name with the corresponding prefix - * \param prefix The prefix of the name - * \return The returned name. + * \brief Get the storage scope of buf_var. + * \param buf_var The buf_var to be queryed. + * \return The storage scope. */ - std::string GetUniqueName(std::string prefix); - /*! \brief whether to print in SSA form */ - bool print_ssa_form_{true}; - /*! \brief name of each variable */ - std::unordered_map var_idmap_; - /*! \brief the data type of allocated buffers */ - std::unordered_map handle_data_type_; + std::string GetStorageScope(const Variable* buf_var) const; + /*! \brief the storage scope of allocation */ std::unordered_map alloc_storage_scope_; + + private: + /*! \brief whether to print in SSA form */ + bool print_ssa_form_{true}; /*! \brief name allocation map */ std::unordered_map name_alloc_map_; /*! \brief assignment map of ssa */ std::unordered_map ssa_assign_map_; + /*! \brief name of each variable */ + std::unordered_map var_idmap_; + /*! \brief the data type of allocated buffers */ + std::unordered_map handle_data_type_; /*! \brief array to check whether we are inside certain scope */ std::vector scope_mark_; }; diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index 06a8762a8244..deee2aa2fa36 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -22,6 +22,108 @@ std::string CodeGenCUDA::Compile( return CodeGenC::Compile(f, output_ssa); } +void CodeGenCUDA::PrintType(Type t, std::ostream& os) const { // NOLINT(*) + int lanes = t.lanes(); + if (t.is_handle()) { + CHECK_EQ(lanes, 1) + << "do not yet support vector types"; + os << "void*"; return; + } + bool fail = false; + if (t.is_float()) { + switch (t.bits()) { + case 16: os << "half"; break; + case 32: os << "float"; break; + case 64: os << "double"; break; + default: fail = true; break; + } + if (!fail && lanes == 1) return; + if (!fail && (lanes >= 2 && lanes <= 4)) { + os << lanes; return; + } + } else if (t.is_uint() || t.is_int()) { + if (t.is_uint()) { + os << 'u'; + } + if (t.bits() == 8 && t.lanes() == 4) { + // directly 4 8 bit int in integer. + os << "int"; return; + } + switch (t.bits()) { + case 8: os << "char"; break; + case 16: os << "short"; break; + case 32: os << "int"; break; + case 64: { + if (lanes != 1 && sizeof(long) == 64) { // NOLINT(*) + os << "long"; break; + } else { + os << "int64_t"; break; + } + } + case 1: os << "int"; break; + default: fail = true; break; + } + if (!fail && lanes == 1) return; + if (!fail && (lanes >= 2 && lanes <= 4)) { + os << lanes; return; + } + } + LOG(FATAL) << "Cannot convert type " << t << " to CUDA type"; +} + +void CodeGenCUDA::PrintVecBinaryOp( + const std::string&op, Type t, + Expr lhs, Expr rhs, std::ostream& os) { // NOLINT(*) + // unpacking operations. + int lanes = t.lanes(); + + { + // default: unpack into individual ops. + std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.type()); + std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.type()); + std::string sret = GetUniqueName("_"); + { + // delcare type. + this->PrintIndent(); + this->PrintType(t, stream); + stream << ' ' << sret << ";\n"; + } + for (int i = 0; i < lanes; ++i) { + std::ostringstream value_temp; + if (isalpha(op[0])) { + value_temp << op << "("; + PrintVecElemLoad(vlhs, lhs.type(), i, value_temp); + value_temp << ", "; + PrintVecElemLoad(vrhs, rhs.type(), i, value_temp); + value_temp << ")"; + } else { + value_temp << "("; + PrintVecElemLoad(vlhs, lhs.type(), i, value_temp); + value_temp << op; + PrintVecElemLoad(vrhs, rhs.type(), i, value_temp); + value_temp << ")"; + } + PrintVecElemStore(sret, t, i, value_temp.str()); + } + os << sret; + } +} + +void CodeGenCUDA::PrintVecElemLoad( + const std::string& vec, Type t, int i, std::ostream& os) { // NOLINT(*) + const char access[] = {'x', 'y', 'z', 'w'}; + CHECK(i >= 0 && i < 4); + os << vec << "." << access[i]; +} + +void CodeGenCUDA::PrintVecElemStore( + const std::string& vec, Type t, int i, const std::string& value) { + this->PrintIndent(); + const char access[] = {'x', 'y', 'z', 'w'}; + CHECK(i >= 0 && i < 4); + stream << vec << "." << access[i] << " = " << value << ";\n"; +} + void CodeGenCUDA::PrintStorageSync(const std::string& sync) { if (sync == "shared") { this->PrintIndent(); @@ -43,8 +145,6 @@ void CodeGenCUDA::PrintStorageScope( std::unordered_map MakeNVRTC(Array funcs) { std::ostringstream os; - os << "typedef int int32_t;\n" - << "typedef unsigned unt32_t;\n"; bool output_ssa = false; for (LoweredFunc f : funcs) { os << CodeGenCUDA().Compile(f, output_ssa); @@ -56,6 +156,7 @@ MakeNVRTC(Array funcs) { const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_postproc"); code = f(code).operator std::string(); } + LOG(INFO) << code; std::string ptx; if (PackedFunc::GlobalExist("tvm_callback_cuda_compile")) { const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_compile"); diff --git a/src/codegen/codegen_cuda.h b/src/codegen/codegen_cuda.h index a8cca432f49a..478faf76b74c 100644 --- a/src/codegen/codegen_cuda.h +++ b/src/codegen/codegen_cuda.h @@ -25,9 +25,18 @@ class CodeGenCUDA : public CodeGenC { */ std::string Compile(LoweredFunc f, bool output_ssa); + // override behavior void PrintStorageSync(const std::string& sync) final; void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) + void PrintVecBinaryOp( + const std::string&op, Type t, + Expr lhs, Expr rhs, std::ostream& os) final; // NOLINT(*) + void PrintType(Type t, std::ostream& os) const final; // NOLINT(*) + void PrintVecElemLoad( + const std::string& vec, Type t, int i, std::ostream& os) final; // NOLINT(*) + void PrintVecElemStore( + const std::string& vec, Type t, int i, const std::string& value) final; }; } // namespace codegen diff --git a/src/codegen/codegen_opencl.cc b/src/codegen/codegen_opencl.cc index bafb56deb656..36cb4fff636c 100644 --- a/src/codegen/codegen_opencl.cc +++ b/src/codegen/codegen_opencl.cc @@ -19,7 +19,11 @@ std::string CodeGenOpenCL::Compile( LoweredFunc f, bool output_ssa) { this->stream << " __kernel "; - this->arg_addr_space_ = "__global "; + for (Var arg : f->args) { + if (arg.type().is_handle()) { + alloc_storage_scope_[arg.get()] = "global"; + } + } return CodeGenC::Compile(f, output_ssa); } @@ -34,6 +38,80 @@ void CodeGenOpenCL::PrintThreadIndexExpr( } } +void CodeGenOpenCL::PrintType(Type t, std::ostream& os) const { // NOLINT(*) + int lanes = t.lanes(); + if (t.is_handle()) { + CHECK_EQ(lanes, 1) + << "do not yet support vector types"; + os << "void*"; return; + } + bool fail = false; + if (t.is_float()) { + switch (t.bits()) { + case 16: os << "half"; break; + case 32: os << "float"; break; + case 64: os << "double"; break; + default: fail = true; break; + } + if (!fail && lanes == 1) return; + if (!fail && (lanes >= 2 && lanes <= 16)) { + os << lanes; return; + } + } else if (t.is_uint() || t.is_int()) { + if (t.is_uint()) { + os << 'u'; + } + if (t.bits() == 8 && t.lanes() == 4) { + // directly 4 8 bit int in integer. + os << "int"; return; + } + switch (t.bits()) { + case 8: os << "char"; break; + case 16: os << "short"; break; + case 32: os << "int"; break; + case 64: os << "long"; break; + case 1: os << "int"; break; + default: fail = true; break; + } + if (!fail && lanes == 1) return; + if (!fail && (lanes >= 2 && lanes <= 16)) { + os << lanes; return; + } + } + LOG(FATAL) << "Cannot convert type " << t << " to OpenCL type"; +} + +void CodeGenOpenCL::PrintVecAddr(const Variable* buffer, Type t, + Expr base, std::ostream& os) { // NOLINT(*) + if (!HandleTypeMatch(buffer, t.element_of())) { + os << '('; + auto it = alloc_storage_scope_.find(buffer); + if (it != alloc_storage_scope_.end()) { + PrintStorageScope(it->second, os); + } + os << ' '; + PrintType(t.element_of(), os); + os << "*)"; + } + os << GetVarID(buffer) << " + "; + PrintExpr(base, os); +} +void CodeGenOpenCL::PrintVecLoad(const Variable* buffer, + Type t, Expr base, + std::ostream& os) { + os << "vload" << t.lanes() << "(0, "; + PrintVecAddr(buffer, t, base, os); + os << ")"; +} + +void CodeGenOpenCL::PrintVecStore(const Variable* buffer, + Type t, Expr base, + const std::string& value) { + this->PrintIndent(); + stream << "vstore" << t.lanes() << "(" << value << ", 0, "; + PrintVecAddr(buffer, t, base, stream); + stream << ");\n"; +} void CodeGenOpenCL::PrintStorageSync(const std::string& sync) { if (sync == "shared") { @@ -45,8 +123,9 @@ void CodeGenOpenCL::PrintStorageSync(const std::string& sync) { } void CodeGenOpenCL::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) - CHECK_NE(scope, "global"); - if (scope == "shared") { + if (scope == "global") { + os << "__global"; + } else if (scope == "shared") { os << "__local "; } } @@ -55,8 +134,6 @@ void CodeGenOpenCL::PrintStorageScope(const std::string& scope, std::ostream& os std::unordered_map MakeOpenCL(Array funcs) { std::ostringstream os; - os << "typedef int int32_t;\n" - << "typedef unsigned unt32_t;\n"; bool output_ssa = false; for (LoweredFunc f : funcs) { os << CodeGenOpenCL().Compile(f, output_ssa); diff --git a/src/codegen/codegen_opencl.h b/src/codegen/codegen_opencl.h index a0b8120f1c30..58147f8b9d79 100644 --- a/src/codegen/codegen_opencl.h +++ b/src/codegen/codegen_opencl.h @@ -30,6 +30,16 @@ class CodeGenOpenCL : public CodeGenC { std::string tag, std::ostream& os) final; // NOLINT(*) void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageSync(const std::string& scope) final; // NOLINT(*) + void PrintType(Type t, std::ostream& os) const final; // NOLINT(*) + void PrintVecLoad(const Variable* buffer, + Type t, Expr base, + std::ostream& os) final; // NOLINT(*) + void PrintVecStore(const Variable* buffer, + Type t, Expr base, + const std::string& value) final; // NOLINT(*) + // the address of load/store + void PrintVecAddr(const Variable* buffer, Type t, + Expr base, std::ostream& os); // NOLINT(*) }; } // namespace codegen diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc index c6b7e6b51c85..3c39b6c50afc 100644 --- a/src/pass/ir_mutator.cc +++ b/src/pass/ir_mutator.cc @@ -74,6 +74,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) .DISPATCH_TO_MUTATE_STMT(Provide) .DISPATCH_TO_MUTATE_STMT(Realize) .DISPATCH_TO_MUTATE_STMT(Store) +.DISPATCH_TO_MUTATE_STMT(IfThenElse) .DISPATCH_TO_MUTATE_STMT(For) .DISPATCH_TO_MUTATE_STMT(Allocate) .DISPATCH_TO_MUTATE_STMT(Free); @@ -195,6 +196,22 @@ Stmt IRMutator::Mutate_(const Free *op, const Stmt& s) { return s; } +Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) { + Expr condition = this->Mutate(op->condition); + Stmt then_case = this->Mutate(op->then_case); + Stmt else_case; + if (else_case.defined()) { + else_case = this->Mutate(op->else_case); + } + if (condition.same_as(op->condition) && + then_case.same_as(op->then_case) && + else_case.same_as(op->else_case)) { + return s; + } else { + return IfThenElse::make(condition, then_case, else_case); + } +} + TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) .DISPATCH_TO_MUTATE_EXPR(Call) .DISPATCH_TO_MUTATE_EXPR(Let) @@ -363,21 +380,6 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) return Block::make(first, rest); } }) -.set_dispatch([](const IfThenElse *op, const Stmt& s, IRMutator* m) { - Expr condition = m->Mutate(op->condition); - Stmt then_case = m->Mutate(op->then_case); - Stmt else_case; - if (else_case.defined()) { - else_case = m->Mutate(op->else_case); - } - if (condition.same_as(op->condition) && - then_case.same_as(op->then_case) && - else_case.same_as(op->else_case)) { - return s; - } else { - return IfThenElse::make(condition, then_case, else_case); - } - }) .set_dispatch([](const Evaluate *op, const Stmt& s, IRMutator* m) { Expr v = m->Mutate(op->value); if (v.same_as(op->value)) { diff --git a/src/pass/split_host_device.cc b/src/pass/split_host_device.cc index 186733fa2f71..0a26493803bc 100644 --- a/src/pass/split_host_device.cc +++ b/src/pass/split_host_device.cc @@ -101,9 +101,14 @@ class IRUseDefAnalysis : public IRMutator { } void HandleDef(const Variable* v) { + CHECK(!def_count_.count(v)) + << "variable " << v->name_hint + << " has already been defined, the Stmt is not SSA"; CHECK(!use_count_.count(v)) - << "variable is already defined"; + << "variable " << v->name_hint + << " has been used before definition!"; use_count_[v] = 0; + def_count_[v] = 1; } void HandleUse(const Expr& v) { @@ -127,6 +132,7 @@ class IRUseDefAnalysis : public IRMutator { Array thread_axis_; Array thread_extent_; std::unordered_map use_count_; + std::unordered_map def_count_; }; class HostDeviceSplitter : public IRMutator { diff --git a/src/pass/unroll_loop.cc b/src/pass/unroll_loop.cc index 555e5b970a05..ddecf4475766 100644 --- a/src/pass/unroll_loop.cc +++ b/src/pass/unroll_loop.cc @@ -1,7 +1,7 @@ /*! - * Copyright (c) 2016 by Contributors - * SSA related checks and pass. - * \file ssa.cc + * Copyright (c) 2017 by Contributors + * Loop unrolling. + * \file unroll_loop.cc */ #include #include @@ -9,7 +9,7 @@ #include #include #include -#include "../arithmetic//compute_expr.h" +#include "../arithmetic/compute_expr.h" namespace tvm { namespace ir { @@ -33,7 +33,8 @@ class LoopUnroller : public IRMutator { if (v2 != nullptr) { value = static_cast(v2->value); } - bool allow_unroll = value >= 0 && value <= max_auto_step_; + bool allow_unroll = (op->for_type == ForType::Serial && + value >= 0 && value <= max_auto_step_); if (op->for_type == ForType::Unrolled) { CHECK_GE(value, 0) << "Cannot unroll non-constant loop"; diff --git a/src/pass/vectorize_loop.cc b/src/pass/vectorize_loop.cc new file mode 100644 index 000000000000..109b1326f8cb --- /dev/null +++ b/src/pass/vectorize_loop.cc @@ -0,0 +1,385 @@ +/*! + * Copyright (c) 2017 by Contributors + * Vectorize the loop + * \file vectorize_loop.cc + */ +#include +#include +#include +#include +#include +#include +#include "../arithmetic/compute_expr.h" + +namespace tvm { +namespace ir { + +inline Expr BroadcastTo(Expr e, int lanes) { + if (e.type().lanes() == lanes) return e; + CHECK_EQ(e.type().lanes(), 1) + << "Cannot broadcast lane=" << e.type().lanes() + << " to " << lanes; + return Broadcast::make(e, lanes); +} + +// Rewrite vectorized allocation access +// s[i] = s[i * lanes + var] +class VecAllocAccess : public IRMutator { + public: + VecAllocAccess(const Variable* buf, Var var, int var_lanes) + : buf_(buf), var_(var), var_lanes_(var_lanes) {} + // Load + Expr Mutate_(const Load* op, const Expr& e) final { + Expr expr = IRMutator::Mutate_(op, e); + op = expr.as(); + if (op->buffer_var.get() == buf_) { + return Load::make(op->type, op->buffer_var, + op->index * var_lanes_ + var_); + } else { + return expr; + } + } + // Store + Stmt Mutate_(const Store* op, const Stmt& s) final { + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as(); + if (op->buffer_var.get() == buf_) { + return Store::make(op->buffer_var, + op->value, + op->index * var_lanes_ + var_); + } else { + return stmt; + } + } + + private: + // buffer var + const Variable* buf_; + // variable to be replaced + Var var_; + // the lanes. + int var_lanes_; +}; + +class Vectorizer : public IRMutator { + public: + Vectorizer(Var var, int var_lanes) + : var_(var), var_lanes_(var_lanes) { + ramp_ = Ramp::make(0, 1, var_lanes); + } + // user mutate from parent. + using IRMutator::Mutate; + // override mutate + Expr Mutate(Expr expr) final { + static const FMutateExpr& f = Vectorizer::vtable_expr(); + return (f.can_dispatch(expr) ? + f(expr, expr, this) : IRMutator::Mutate(expr)); + } + // Variable + Expr Mutate_(const Variable* v, const Expr& e) final { + if (v == var_.get()) { + return ramp_; + } else if (lets_.count(v)) { + return lets_[v]; + } else { + return e; + } + } + // Call + Expr Mutate_(const Call* op, const Expr& e) final { + int lane = 0; + Array new_args = MutateArray(op->args, &lane); + if (op->args.same_as(new_args)) { + return e; + } else { + return Call::make( + op->type.with_lanes(lane), op->name, new_args, + op->call_type, op->func, op->value_index); + } + } + // Load + Expr Mutate_(const Load* op, const Expr& e) final { + Expr index = this->Mutate(op->index); + if (index.same_as(op->index)) { + return e; + } else { + return Load::make(op->type.with_lanes(index.type().lanes()), + op->buffer_var, index); + } + } + // Let + Expr Mutate_(const Let* op, const Expr& e) final { + Expr value = this->Mutate(op->value); + CHECK(!lets_.count(op->var.get())) << "not SSA"; + if (value.type().lanes() != op->value.type().lanes()) { + Var v(op->var->name_hint, value.type()); + lets_[op->var.get()] = v; + return Let::make(v, value, Mutate(op->body)); + } else { + Expr body = this->Mutate(op->body); + if (value.same_as(op->value) && + body.same_as(op->body)) { + return e; + } else { + return Let::make(op->var, value, body); + } + } + } + // Provide + Stmt Mutate_(const Provide* op, const Stmt& s) final { + Expr new_value = this->Mutate(op->value); + int lane = new_value.type().lanes(); + Array new_args = MutateArray(op->args, &lane); + if (op->args.same_as(new_args) && op->value.same_as(new_value)) { + return s; + } else { + new_value = BroadcastTo(new_value, lane); + return Provide::make(op->func, op->value_index, new_value, new_args); + } + } + // Store + Stmt Mutate_(const Store* op, const Stmt& s) final { + Expr value = this->Mutate(op->value); + Expr index = this->Mutate(op->index); + if (value.same_as(op->value) && index.same_as(op->index)) { + return s; + } else { + int lanes = std::max(value.type().lanes(), index.type().lanes()); + return Store::make(op->buffer_var, + BroadcastTo(value, lanes), + BroadcastTo(index, lanes)); + } + } + // For + Stmt Mutate_(const For* op, const Stmt& s) final { + if (op->for_type == ForType::Vectorized) { + LOG(WARNING) << "Detect vectorize inside vectorized loop, ignoring..."; + } + CHECK(is_zero(op->min)); + CHECK(!op->extent.type().is_vector()); + Expr extent = Mutate(op->extent); + if (extent.type().is_vector()) { + LOG(WARNING) << "Detect vectorized extent type, scalarizing..."; + return Scalarize(s); + } + Stmt body = Mutate(op->body); + if (extent.same_as(op->extent) && + body.same_as(op->body)) { + return s; + } else { + return For::make( + op->loop_var, op->min, extent, + op->for_type, op->device_api, body); + } + } + // IfThenElse + Stmt Mutate_(const IfThenElse* op, const Stmt& s) final { + CHECK(!op->condition.type().is_vector()); + Expr condition = this->Mutate(op->condition); + if (condition.type().is_vector()) { + LOG(WARNING) << "Detect vector condition in Vectorized Loop, scalarizing..."; + return Scalarize(s); + } + Stmt then_case = this->Mutate(op->then_case); + Stmt else_case; + if (else_case.defined()) { + else_case = this->Mutate(op->else_case); + } + if (condition.same_as(op->condition) && + then_case.same_as(op->then_case) && + else_case.same_as(op->else_case)) { + return s; + } else { + return IfThenElse::make(condition, then_case, else_case); + } + } + // LetStmt + Stmt Mutate_(const LetStmt* op, const Stmt& s) final { + LOG(WARNING) << "Cannot vectorize with LetStmt, remove it with Simplify Before Vectorize"; + return Scalarize(s); + } + // Allocate + Stmt Mutate_(const Allocate* op, const Stmt& s) final { + if (op->new_expr.defined()) { + LOG(WARNING) << "Cannot vectorize with new expr"; + return Scalarize(s); + } + Expr condition = Mutate(op->condition); + if (condition.type().is_vector()) { + LOG(WARNING) << "Cannot handle vector extent in alloc "; + return Scalarize(s); + } + Array extents; + for (size_t i = 0; i < op->extents.size(); i++) { + Expr new_ext = Mutate(op->extents[i]); + if (new_ext.type().is_vector()) { + LOG(WARNING) << "Cannot handle vector extent in alloc "; + return Scalarize(s); + } + extents.push_back(new_ext); + } + // place the vector lanes in least significant dimension. + extents.push_back(var_lanes_); + // rewrite access to buffer internally. + Stmt body = VecAllocAccess( + op->buffer_var.get(), var_, var_lanes_).Mutate(op->body); + body = Mutate(body); + return Allocate::make( + op->buffer_var, op->type, + extents, condition, body, + op->new_expr, op->free_function); + } + // scalarize the statment + Stmt Scalarize(Stmt stmt) { + Var idx(var_->name_hint + ".s", var_->type); + stmt = Substitute(stmt, {{var_, idx}}); + return For::make(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt); + } + // The overloads for vectorize. + static FMutateExpr& vtable_expr() { // NOLINT(*) + static FMutateExpr inst; return inst; + } + + private: + // variable to be replaced + Var var_; + // the lanes. + int var_lanes_; + // ramp representing the var. + Expr ramp_; + // The lets + std::unordered_map lets_; + // mutate array, with given lane requirement + // when finished, p_lane updates the lane requirement. + Array MutateArray(Array arr, int* p_lanes) { + if (arr.size() == 0) return arr; + int& lanes = *p_lanes; + bool changed = false; + std::vector new_arr(arr.size()); + for (size_t i = 0; i < arr.size(); i++) { + Expr old_elem = arr[i]; + Expr new_elem = this->Mutate(old_elem); + if (!new_elem.same_as(old_elem)) changed = true; + new_arr[i] = new_elem; + lanes = std::max(lanes, new_elem.type().lanes()); + } + + for (size_t i = 0; i < arr.size(); ++i) { + if (new_arr[i].type().lanes() != lanes) { + new_arr[i] = BroadcastTo(new_arr[i], lanes); + changed = true; + } + } + if (!changed) return arr; + return Array(new_arr); + } +}; + +// binary vectorize +template +inline Expr BinaryVec(const T* op, const Expr& e, IRMutator* m) { + Expr a = m->Mutate(op->a); + Expr b = m->Mutate(op->b); + if (a.same_as(op->a) && + b.same_as(op->b)) { + return e; + } else { + int lanes = std::max(a.type().lanes(), b.type().lanes()); + return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); + } +} + +template +inline Expr AddSubVec(const T* op, const Expr& e, IRMutator* m) { + Expr a = m->Mutate(op->a); + Expr b = m->Mutate(op->b); + if (a.same_as(op->a) && + b.same_as(op->b)) { + return e; + } else { + int lanes = std::max(a.type().lanes(), b.type().lanes()); + if (lanes != 1) { + const Ramp* b_ramp = b.as(); + const Ramp* a_ramp = a.as(); + if (a.type().lanes() == 1 && b_ramp) { + return Ramp::make( + arith::ComputeExpr(a, b_ramp->base), b_ramp->stride, b_ramp->lanes); + } + if (b.type().lanes() == 1 && a_ramp) { + return Ramp::make( + arith::ComputeExpr(a_ramp->base, b), a_ramp->stride, a_ramp->lanes); + } + } + return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); + } +} + +TVM_STATIC_IR_FUNCTOR(Vectorizer, vtable_expr) +.set_dispatch(AddSubVec) +.set_dispatch(AddSubVec) +.set_dispatch(BinaryVec) +.set_dispatch
(BinaryVec
) +.set_dispatch(BinaryVec) +.set_dispatch(BinaryVec) +.set_dispatch(BinaryVec) +.set_dispatch(BinaryVec) +.set_dispatch(BinaryVec) +.set_dispatch(BinaryVec) +.set_dispatch(BinaryVec) +.set_dispatch(BinaryVec) +.set_dispatch(BinaryVec) +.set_dispatch(BinaryVec) +.set_dispatch(BinaryVec); + + +TVM_STATIC_IR_FUNCTOR(Vectorizer, vtable_expr) +.set_dispatch