diff --git a/src/relay/backend/contrib/verilator/codegen.cc b/src/relay/backend/contrib/verilator/codegen.cc index 2f61ae540395..b206288f7e96 100644 --- a/src/relay/backend/contrib/verilator/codegen.cc +++ b/src/relay/backend/contrib/verilator/codegen.cc @@ -34,6 +34,7 @@ #include #include "../../../../runtime/contrib/json/json_node.h" +#include "../../../../runtime/contrib/verilator/verilator_runtime.h" #include "../../utils.h" #include "../codegen_json/codegen_json.h" @@ -75,29 +76,34 @@ class VerilatorJSONSerializer : public backend::contrib::JSONSerializer { } }; -/*! \brief Attributes to store the compiler options for Verilator */ -struct VerilatorCompilerConfigNode : public tvm::AttrsNode { - String lib; - - TVM_DECLARE_ATTRS(VerilatorCompilerConfigNode, "ext.attrs.VerilatorCompilerConfigNode") { - TVM_ATTR_FIELD(lib).set_default("libverilator.so"); +/*! \brief Attributes to store options for Verilator */ +struct VerilatorOptionsNode : public tvm::AttrsNode { + String lib_path; + int reset_cycles; + bool profiler_enable; + int profiler_cycle_counter_id; + + TVM_DECLARE_ATTRS(VerilatorOptionsNode, "ext.attrs.VerilatorOptionsNode") { + TVM_ATTR_FIELD(lib_path).describe("the design library path").set_default("libverilator.so"); + TVM_ATTR_FIELD(reset_cycles).describe("the number of reset cycles").set_default(1); + TVM_ATTR_FIELD(profiler_enable).describe("enable profiler").set_default(false); + TVM_ATTR_FIELD(profiler_cycle_counter_id).describe("profiler cycle counter id").set_default(0); } }; -class VerilatorCompilerConfig : public Attrs { +class VerilatorOptions : public Attrs { public: - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(VerilatorCompilerConfig, Attrs, - VerilatorCompilerConfigNode); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(VerilatorOptions, Attrs, VerilatorOptionsNode); }; -TVM_REGISTER_NODE_TYPE(VerilatorCompilerConfigNode); -TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.verilator.options", VerilatorCompilerConfig); +TVM_REGISTER_NODE_TYPE(VerilatorOptionsNode); +TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.verilator.options", VerilatorOptions); /*! - * \brief The external compiler/codegen tool. It takes a Relay expression/module and - * compile it into a runtime module. + * \brief The Verilator codegen tool. It takes a Relay expression/module and + * compile it into a Verilator runtime module. */ -runtime::Module VerilatorCompiler(const ObjectRef& ref) { +runtime::Module VerilatorBackend(const ObjectRef& ref) { CHECK(ref->IsInstance()); auto func = Downcast(ref); auto func_name = GetExtSymbol(func); @@ -106,22 +112,28 @@ runtime::Module VerilatorCompiler(const ObjectRef& ref) { std::string graph_json = serializer.GetJSON(); auto params = serializer.GetParams(); + // Create runtime object + auto n = make_object(func_name, graph_json, params); + // Get Verilator compiler options auto ctx = transform::PassContext::Current(); - auto cfg = ctx->GetConfig("relay.ext.verilator.options"); + auto cfg = ctx->GetConfig("relay.ext.verilator.options"); if (!cfg.defined()) { - cfg = AttrsWithDefaultValues(); + cfg = AttrsWithDefaultValues(); } - auto lib_name = cfg.value()->lib; + n->SetLibrary(cfg.value()->lib_path); + n->SetResetCycles(cfg.value()->reset_cycles); + + if (cfg.value()->profiler_enable) { + n->EnableProfiler(); + n->SetProfilerCycleCounterId(cfg.value()->profiler_cycle_counter_id); + } - const auto* pf = runtime::Registry::Get("runtime.verilator_runtime_create"); - CHECK(pf != nullptr) << "Cannot find JSON runtime module to create"; - auto mod = (*pf)(lib_name, func_name, graph_json, params); - return mod; + return runtime::Module(n); } -TVM_REGISTER_GLOBAL("relay.ext.verilator").set_body_typed(VerilatorCompiler); +TVM_REGISTER_GLOBAL("relay.ext.verilator").set_body_typed(VerilatorBackend); } // namespace contrib } // namespace relay diff --git a/src/runtime/contrib/verilator/verilator_device.h b/src/runtime/contrib/verilator/verilator_device.h index acd91a53bcff..298e41c06daf 100644 --- a/src/runtime/contrib/verilator/verilator_device.h +++ b/src/runtime/contrib/verilator/verilator_device.h @@ -31,24 +31,51 @@ namespace tvm { namespace runtime { namespace contrib { +/*! \brief Verilator device resource context */ typedef void* VerilatorHandle; -/* allocate Verilator object */ +/*! + * \brief Allocate a verilator device resource handle + * \return The verilator device handle. + */ extern "C" TVM_DLL VerilatorHandle VerilatorAlloc(); -/* deallocate Verilator object */ +/*! + * \brief Free a verilator device handle + * \param handle The verilator device handle to be freed. + */ extern "C" TVM_DLL void VerilatorDealloc(VerilatorHandle handle); -/* read Verilator register or memory */ +/*! + * \brief Read verilator register or memory + * \param handle The verilator device handle. + * \param id The register or memory identifier. + * \param addr The register or memory address (word-level). + * \return The value of register or memory. + */ extern "C" TVM_DLL int VerilatorRead(VerilatorHandle handle, int id, int addr); -/* write Verilator register or memory */ +/*! + * \brief Write verilator register or memory + * \param handle The verilator device handle. + * \param id The register or memory identifier. + * \param addr The register or memory address (word-level). + * \param value The value of register or memory. + */ extern "C" TVM_DLL void VerilatorWrite(VerilatorHandle handle, int id, int addr, int value); -/* reset Verilator for n clock cycles */ +/*! + * \brief Reset Verilator for n clock cycles + * \param handle The verilator device handle. + * \param n The number of reset cycles. + */ extern "C" TVM_DLL void VerilatorReset(VerilatorHandle handle, int n); -/* run Verilator for n clock cycles */ +/*! + * \brief Run Verilator for n clock cycles + * \param handle The verilator device handle. + * \param n The number of run cycles. + */ extern "C" TVM_DLL void VerilatorRun(VerilatorHandle handle, int n); } // namespace contrib diff --git a/src/runtime/contrib/verilator/verilator_runtime.cc b/src/runtime/contrib/verilator/verilator_runtime.cc index 60f36e494da7..bc96b69f2ffe 100644 --- a/src/runtime/contrib/verilator/verilator_runtime.cc +++ b/src/runtime/contrib/verilator/verilator_runtime.cc @@ -19,9 +19,11 @@ /*! * \file src/runtime/contrib/verilator/verilator_runtime.cc - * \brief A simple JSON runtime for Verilator. + * \brief A runtime for Verilator. */ +#include "verilator_runtime.h" + #include #include #include @@ -40,124 +42,123 @@ namespace tvm { namespace runtime { namespace contrib { -typedef VerilatorHandle (*VerilatorAllocFunc)(); -typedef void (*VerilatorResetFunc)(VerilatorHandle, int); -typedef void (*VerilatorAddFunc)(VerilatorHandle, int*, int*, int*, int, int); - using namespace tvm::runtime; +using namespace tvm::runtime::contrib; using namespace tvm::runtime::json; -class VerilatorLibrary : public Library { - public: - ~VerilatorLibrary() { - if (lib_handle_) Unload(); - } - void Init(const std::string& name) { Load(name); } - - void* GetSymbol(const char* name) final { return GetSymbol_(name); } - - private: - // Library handle - void* lib_handle_{nullptr}; - // load the library - void Load(const std::string& name) { - lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL); - ICHECK(lib_handle_ != nullptr) - << "Failed to load dynamic shared library " << name << " " << dlerror(); - } - - void* GetSymbol_(const char* name) { return dlsym(lib_handle_, name); } - - void Unload() { +VerilatorLibrary::~VerilatorLibrary() { + if (lib_handle_) { dlclose(lib_handle_); lib_handle_ = nullptr; } -}; +} -class VerilatorJSONRuntime : public JSONRuntimeBase { - public: - VerilatorJSONRuntime(const std::string& symbol_name, const std::string& graph_json, - const Array const_names) - : JSONRuntimeBase(symbol_name, graph_json, const_names) {} +void VerilatorLibrary::Load(const std::string& name) { + lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL); + ICHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name << " " + << dlerror(); +} - const char* type_key() const { return "verilator_json"; } +void* VerilatorLibrary::GetSymbol(const char* name) { return dlsym(lib_handle_, name); } - void LoadLibrary(const std::string& lib_name) { - lib_ = new VerilatorLibrary(); - lib_->Init(lib_name); - } +void VerilatorProfiler::Clear() { cycle_counter = 0; } - void Init(const Array& consts) override { - // get symbols - auto alloc_func = reinterpret_cast(lib_->GetSymbol("VerilatorAlloc")); - ICHECK(alloc_func != nullptr); - auto reset_func = reinterpret_cast(lib_->GetSymbol("VerilatorReset")); - ICHECK(reset_func != nullptr); - vadd_func_ = reinterpret_cast(lib_->GetSymbol("verilator_add")); - ICHECK(vadd_func_ != nullptr); +std::string VerilatorProfiler::AsJSON() { + std::ostringstream os; + os << "{\n" + << " \"cycle_counter\":" << cycle_counter << "\n" + << "}\n"; + return os.str(); +} - // alloc device - device_ = (*alloc_func)(); +VerilatorProfiler* VerilatorProfiler::ThreadLocal() { + static thread_local VerilatorProfiler inst; + return &inst; +} - // reset for 10 cycles - (*reset_func)(device_, 10); +VerilatorRuntime::~VerilatorRuntime() { + auto dealloc = reinterpret_cast(lib_->GetSymbol("VerilatorDealloc")); + ICHECK(dealloc != nullptr); + dealloc(device_); + lib_->~VerilatorLibrary(); +} - CHECK_EQ(consts.size(), const_idx_.size()) - << "The number of input constants must match the number of required."; +void VerilatorRuntime::SetLibrary(const std::string& lib_path) { lib_path_ = lib_path; } - // Setup constants entries for weights. - SetupConstants(consts); - } +void VerilatorRuntime::SetResetCycles(const int cycles) { reset_cycles_ = cycles; } - void Run() override { - std::vector in_ptr; - std::vector out_ptr; - for (size_t i = 0; i < input_nodes_.size(); ++i) { - uint32_t eid = EntryID(input_nodes_[i], 0); - int* data = static_cast(data_entry_[eid]->data); - in_ptr.push_back(data); - } - for (size_t i = 0; i < outputs_.size(); ++i) { - uint32_t eid = EntryID(outputs_[i]); - int* data = static_cast(data_entry_[eid]->data); - out_ptr.push_back(data); - } - for (size_t nid = 0; nid < nodes_.size(); ++nid) { - const auto& node = nodes_[nid]; - if (node.GetOpType() == "kernel") { - CHECK_EQ(node.GetOpType(), "kernel"); - auto op_name = node.GetOpName(); - if ("add" == op_name) { - auto entry = node.GetInputs()[0]; - auto shape = nodes_[entry.id_].GetOpShape()[entry.index_]; - (*vadd_func_)(device_, in_ptr[0], in_ptr[1], out_ptr[0], shape[0], shape[1]); - } else { - LOG(FATAL) << "Unsupported op: " << op_name; - } +void VerilatorRuntime::EnableProfiler() { prof_enable_ = true; } + +void VerilatorRuntime::SetProfilerCycleCounterId(const int id) { prof_cycle_counter_id_ = id; } + +void VerilatorRuntime::Init(const Array& consts) { + lib_ = new VerilatorLibrary(); + lib_->Load(lib_path_); + auto alloc = reinterpret_cast(lib_->GetSymbol("VerilatorAlloc")); + ICHECK(alloc != nullptr); + auto reset = reinterpret_cast(lib_->GetSymbol("VerilatorReset")); + ICHECK(reset != nullptr); + read_ = reinterpret_cast(lib_->GetSymbol("VerilatorRead")); + ICHECK(read_ != nullptr); + add_op_ = reinterpret_cast(lib_->GetSymbol("verilator_add")); + + // alloc verilator device + device_ = alloc(); + + // enable profiler + if (prof_enable_) prof_ = VerilatorProfiler::ThreadLocal(); + + // reset verilator device + reset(device_, reset_cycles_); + + CHECK_EQ(consts.size(), const_idx_.size()) + << "The number of input constants must match the number of required."; + + // Setup constants entries for weights. + SetupConstants(consts); +} + +void VerilatorRuntime::Run() { + std::vector in_ptr; + std::vector out_ptr; + for (size_t i = 0; i < input_nodes_.size(); ++i) { + uint32_t eid = EntryID(input_nodes_[i], 0); + int* data = static_cast(data_entry_[eid]->data); + in_ptr.push_back(data); + } + for (size_t i = 0; i < outputs_.size(); ++i) { + uint32_t eid = EntryID(outputs_[i]); + int* data = static_cast(data_entry_[eid]->data); + out_ptr.push_back(data); + } + for (size_t nid = 0; nid < nodes_.size(); ++nid) { + const auto& node = nodes_[nid]; + if (node.GetOpType() == "kernel") { + CHECK_EQ(node.GetOpType(), "kernel"); + auto op_name = node.GetOpName(); + if ("add" == op_name) { + auto entry = node.GetInputs()[0]; + auto shape = nodes_[entry.id_].GetOpShape()[entry.index_]; + ICHECK(add_op_ != nullptr); + add_op_(device_, in_ptr[0], in_ptr[1], out_ptr[0], shape[0], shape[1]); + } else { + LOG(FATAL) << "Unsupported op: " << op_name; } } } - - private: - /* The verilator device handle. */ - VerilatorHandle device_{nullptr}; - /* The verilator library handle. */ - VerilatorLibrary* lib_{nullptr}; - /* The verilator vadd function handle. */ - VerilatorAddFunc vadd_func_{nullptr}; -}; - -runtime::Module VerilatorJSONRuntimeCreate(String lib_name, String symbol_name, String graph_json, - const Array& const_names) { - auto n = make_object(symbol_name, graph_json, const_names); - n->LoadLibrary(lib_name); - return runtime::Module(n); + if (prof_enable_) { + int cycles = read_(device_, prof_cycle_counter_id_, 0); + prof_->cycle_counter += cycles; + } } -TVM_REGISTER_GLOBAL("runtime.verilator_runtime_create").set_body_typed(VerilatorJSONRuntimeCreate); +TVM_REGISTER_GLOBAL("verilator.profiler_clear").set_body([](TVMArgs args, TVMRetValue* rv) { + VerilatorProfiler::ThreadLocal()->Clear(); +}); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_verilator_json") - .set_body_typed(JSONRuntimeBase::LoadFromBinary); +TVM_REGISTER_GLOBAL("verilator.profiler_status").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = VerilatorProfiler::ThreadLocal()->AsJSON(); +}); } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/verilator/verilator_runtime.h b/src/runtime/contrib/verilator/verilator_runtime.h new file mode 100644 index 000000000000..acdaa3b03ce2 --- /dev/null +++ b/src/runtime/contrib/verilator/verilator_runtime.h @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/verilator/verilator_runtime.h + * \brief A runtime for Verilator. + */ + +#ifndef TVM_RUNTIME_CONTRIB_VERILATOR_VERILATOR_RUNTIME_H_ +#define TVM_RUNTIME_CONTRIB_VERILATOR_VERILATOR_RUNTIME_H_ + +#include +#include +#include + +#include +#include +#include + +#include "../../library_module.h" +#include "../json/json_node.h" +#include "../json/json_runtime.h" +#include "verilator_device.h" +#include "verilator_kernel.h" + +namespace tvm { +namespace runtime { +namespace contrib { + +using namespace tvm::runtime; +using namespace tvm::runtime::contrib; +using namespace tvm::runtime::json; + +typedef VerilatorHandle (*VerilatorAllocFunc)(); +typedef void (*VerilatorDeallocFunc)(VerilatorHandle); +typedef void (*VerilatorResetFunc)(VerilatorHandle, int); +typedef void (*VerilatorAddFunc)(VerilatorHandle, int*, int*, int*, int, int); +typedef int (*VerilatorReadFunc)(VerilatorHandle, int, int); + +class VerilatorLibrary : public Library { + public: + ~VerilatorLibrary(); + + /*! \brief load library */ + void Load(const std::string& name); + + /*! \brief get symbol from libray */ + void* GetSymbol(const char* name) final; + + private: + /*! \brief the library handle */ + void* lib_handle_{nullptr}; +}; + +class VerilatorProfiler { + public: + /*! \brief the number of cycle counter */ + uint32_t cycle_counter{0}; + + /*! \brief clear the profiler */ + void Clear(); + + /*! \brief get profiler data */ + std::string AsJSON(); + + /*! \brief profiler constructor */ + static VerilatorProfiler* ThreadLocal(); +}; + +class VerilatorRuntime : public JSONRuntimeBase { + public: + VerilatorRuntime(const std::string& symbol_name, const std::string& graph_json, + const Array const_names) + : JSONRuntimeBase(symbol_name, graph_json, const_names) {} + + ~VerilatorRuntime(); + + const char* type_key() const { return "verilator"; } + + /*! \brief set verilator library */ + void SetLibrary(const std::string& lib_name); + + /*! \brief set the number of reset cycles */ + void SetResetCycles(const int cycles); + + /*! \brief enable profiler */ + void EnableProfiler(); + + /*! \brief set cycle counter register id */ + void SetProfilerCycleCounterId(const int id); + + /*! \brief init verilator runtime */ + void Init(const Array& consts) override; + + /*! \brief run verilator runtime */ + void Run() override; + + private: + /*! \brief the verilator library path */ + String lib_path_; + /*! \brief the verilator device */ + VerilatorHandle device_{nullptr}; + /*! \brief the verilator library */ + VerilatorLibrary* lib_{nullptr}; + /*! \brief the verilator profiler */ + VerilatorProfiler* prof_{nullptr}; + /*! \brief the verilator read function */ + VerilatorReadFunc read_{nullptr}; + /*! \brief the verilator add op function */ + VerilatorAddFunc add_op_{nullptr}; + /*! \brief the verilator reset cycles */ + int reset_cycles_{1}; + /*! \brief the verilator profiler status */ + bool prof_enable_{false}; + /*! \brief the verilator profiler cycle counter id */ + int prof_cycle_counter_id_{0}; +}; + +} // namespace contrib +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_CONTRIB_VERILATOR_VERILATOR_RUNTIME_H_ diff --git a/tests/python/contrib/test_verilator/infrastructure.py b/tests/python/contrib/test_verilator/infrastructure.py index e8fd943aa8a0..7e4c297853d5 100644 --- a/tests/python/contrib/test_verilator/infrastructure.py +++ b/tests/python/contrib/test_verilator/infrastructure.py @@ -102,9 +102,9 @@ def compile_module(mod): if not os.path.isfile(lib): compile_hardware() - with tvm.transform.PassContext( - opt_level=3, config={"relay.ext.verilator.options": {"lib": lib}} - ): + opts = {"lib_path": lib} + + with tvm.transform.PassContext(opt_level=3, config={"relay.ext.verilator.options": opts}): exe = relay.vm.compile(mod, target="llvm", params=None) code, lib = exe.save() return runtime.vm.Executable.load_exec(code, lib)