diff --git a/src/script/builder/builder.cc b/src/script/builder/builder.cc index 4c8fc70f2615..f13f90f6953c 100644 --- a/src/script/builder/builder.cc +++ b/src/script/builder/builder.cc @@ -48,20 +48,6 @@ void Builder::ExitWithScope() { std::vector* stack = ThreadLocalBuilderStack(); ICHECK(!stack->empty()); stack->pop_back(); - // IRModuleFrame frame = Downcast(n->frames.back()); - // n->frames.pop_back(); - // if (!frame->stmts.empty()) { - // ICHECK(frame->global_vars.empty()); - // ICHECK(frame->functions.empty()); - // n->result = frame->stmts; - // } else { - // Map func_map; - // ICHECK_EQ(frame->functions.size(), frame->global_vars.size()); - // int m = frame->functions.size(); - // for (int i = 0; i < m; ++i) { - // func_map.Set(frame->global_vars[i], frame->functions[i]); - // } - // } } Builder Builder::Current() { @@ -70,6 +56,28 @@ Builder Builder::Current() { return stack->back(); } +Namer::FType& Namer::vtable() { + static FType inst; + return inst; +} + +void Namer::Name(ObjectRef node, String name) { + static const FType& f = vtable(); + CHECK(node.defined()) << "ValueError: Cannot name nullptr with: " << name; + CHECK(f.can_dispatch(node)) << "ValueError: Do not know how to name type \"" + << node->GetTypeKey(); + f(node, name); +} + +namespace details { + +ObjectRef DefImpl(String name, ObjectRef obj) { + Namer::Name(obj, name); + return obj; +} + +} // namespace details + TVM_REGISTER_NODE_TYPE(BuilderNode); } // namespace builder diff --git a/src/script/builder/builder.h b/src/script/builder/builder.h index 506ba2030d69..0bbfee9688e5 100644 --- a/src/script/builder/builder.h +++ b/src/script/builder/builder.h @@ -59,6 +59,26 @@ class Builder : public runtime::ObjectRef { static Builder Current(); }; +template +inline TObjectRef Def(String name, TObjectRef obj); + +namespace details { +ObjectRef DefImpl(String name, ObjectRef obj); +} + +class Namer { + public: + using FType = NodeFunctor; + static FType& vtable(); + + static void Name(ObjectRef node, String name); +}; + +template +inline TObjectRef Def(String name, TObjectRef obj) { + return Downcast(details::DefImpl(name, obj)); +} + template inline Optional BuilderNode::FindFrame() const { using TFrameNode = typename TFrame::ContainerType; diff --git a/src/script/builder/frame.cc b/src/script/builder/frame.cc index 4fe10c2cc630..9359868ef0e6 100644 --- a/src/script/builder/frame.cc +++ b/src/script/builder/frame.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include "./builder.h" namespace tvm { @@ -44,7 +46,20 @@ IRModuleFrame::IRModuleFrame() { data_ = std::move(n); } +void IRModuleFrameNode::ExitWithScope() { + ICHECK_EQ(functions.size(), global_vars.size()); + int n = functions.size(); + Map func_map; + for (int i = 0; i < n; ++i) { + func_map.Set(global_vars[i], functions[i]); + } + Builder builder = Builder::Current(); + ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; + builder->result = tvm::IRModule(func_map); +} + TVM_REGISTER_NODE_TYPE(FrameNode); +TVM_REGISTER_NODE_TYPE(IRModuleFrameNode); } // namespace builder } // namespace script diff --git a/src/script/builder/frame.h b/src/script/builder/frame.h index bcb1d90c88f9..0f86f326dafe 100644 --- a/src/script/builder/frame.h +++ b/src/script/builder/frame.h @@ -69,6 +69,9 @@ class IRModuleFrameNode : public FrameNode { static constexpr const char* _type_key = "script.builder.IRModuleFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleFrameNode, FrameNode); + + public: + void ExitWithScope() final; }; class IRModuleFrame : public Frame { diff --git a/src/script/builder/tir/base.cc b/src/script/builder/tir/base.cc index db0c1c3bf939..d5206c9a7348 100644 --- a/src/script/builder/tir/base.cc +++ b/src/script/builder/tir/base.cc @@ -42,18 +42,18 @@ void TestPOC() { With builder; { With _{T::PrimFunc_("main")}; - Buffer A = T::Arg(T::Buffer_({128, 128, 128}, DataType::Float(32))); - Buffer B = T::Arg(T::Buffer_({128, 128, 128}, DataType::Float(32))); + Buffer A = T::Arg("A", T::Buffer_({128, 128, 128}, DataType::Float(32))); + Buffer B = T::Arg("B", T::Buffer_({128, 128, 128}, DataType::Float(32))); { With _{T::Grid({128, 128, 128})}; - Var i = _()->vars[0]; - Var j = _()->vars[1]; - Var k = _()->vars[2]; + Var i = Def("i", _()->vars[0]); + Var j = Def("j", _()->vars[1]); + Var k = Def("k", _()->vars[2]); { With _{T::Block_("block")}; - IterVar vi = T::axis::Spatial(Range(0, 128), i); - IterVar vj = T::axis::Spatial(Range(0, 128), j); - IterVar vk = T::axis::Reduce(Range(0, 128), k); + IterVar vi = Def("vi", T::axis::Spatial(Range(0, 128), i)); + IterVar vj = Def("vj", T::axis::Spatial(Range(0, 128), j)); + IterVar vk = Def("vk", T::axis::Reduce(Range(0, 128), k)); } LOG(INFO) << "ForFrame:\n" << _()->stmts; } diff --git a/src/script/builder/tir/prim_func_frame.cc b/src/script/builder/tir/prim_func_frame.cc index d052624a6123..039a6ecdef56 100644 --- a/src/script/builder/tir/prim_func_frame.cc +++ b/src/script/builder/tir/prim_func_frame.cc @@ -55,14 +55,16 @@ PrimFuncFrame PrimFunc_(String name) { return PrimFuncFrame(n); } -tvm::tir::Var Arg(tvm::tir::Var var) { +tvm::tir::Var Arg(String name, tvm::tir::Var var) { + Namer::Name(var, name); PrimFuncFrame frame = Builder::Current()->FindFrame().value(); frame->args.push_back(var); return var; } -tvm::tir::Buffer Arg(tvm::tir::Buffer buffer) { +tvm::tir::Buffer Arg(String name, tvm::tir::Buffer buffer) { using namespace tvm::tir; + Namer::Name(buffer, name); PrimFuncFrame frame = Builder::Current()->FindFrame().value(); Var handle(buffer->name + "_handle", DataType::Handle()); frame->args.push_back(handle); diff --git a/src/script/builder/tir/prim_func_frame.h b/src/script/builder/tir/prim_func_frame.h index 7da51dbafbbe..11a6a564deff 100644 --- a/src/script/builder/tir/prim_func_frame.h +++ b/src/script/builder/tir/prim_func_frame.h @@ -54,8 +54,8 @@ class PrimFuncFrame : public TIRFrame { }; PrimFuncFrame PrimFunc_(String name); -tvm::tir::Var Arg(tvm::tir::Var var); -tvm::tir::Buffer Arg(tvm::tir::Buffer buffer); +tvm::tir::Var Arg(String name, tvm::tir::Var var); +tvm::tir::Buffer Arg(String name, tvm::tir::Buffer buffer); } // namespace tir } // namespace builder diff --git a/src/script/builder/tir/var.cc b/src/script/builder/tir/var.cc index f2e77d763e8e..01ea3a01aad8 100644 --- a/src/script/builder/tir/var.cc +++ b/src/script/builder/tir/var.cc @@ -27,6 +27,42 @@ tvm::tir::Buffer Buffer_(Array shape, DataType dtype, String name, Str return tvm::tir::decl_buffer(shape, dtype, name, storage_scope); } +TVM_STATIC_IR_FUNCTOR(Namer, vtable) + .set_dispatch([](const ObjectRef& node, String name) -> void { + using namespace tvm::tir; + BufferNode* buffer = const_cast(node.as()); + buffer->name = name; + Namer::Name(buffer->data, name + "_data"); + int n = buffer->strides.size(); + for (int i = 0; i < n; ++i) { + PrimExpr e = buffer->strides[i]; + if (const VarNode* v = e.as()) { + Namer::Name(GetRef(v), name + "_s" + std::to_string(i)); + } + } + }); + +TVM_STATIC_IR_FUNCTOR(Namer, vtable) + .set_dispatch([](const ObjectRef& node, String name) -> void { + using namespace tvm::tir; + SizeVarNode* var = const_cast(node.as()); + var->name_hint = name; + }); + +TVM_STATIC_IR_FUNCTOR(Namer, vtable) + .set_dispatch([](const ObjectRef& node, String name) -> void { + using namespace tvm::tir; + VarNode* var = const_cast(node.as()); + var->name_hint = name; + }); + +TVM_STATIC_IR_FUNCTOR(Namer, vtable) + .set_dispatch([](const ObjectRef& node, String name) -> void { + using namespace tvm::tir; + IterVarNode* var = const_cast(node.as()); + Namer::Name(var->var, name); + }); + } // namespace tir } // namespace builder } // namespace script