Skip to content

Commit

Permalink
[BYOC][Verilator] Refactor Verilator runtime (#7406)
Browse files Browse the repository at this point in the history
* new experiment

* save

* refactor

* refactor library

* add profiler

* refactor

* refactor

* add docs

* update comment

* add deallocator
  • Loading branch information
vegaluisjose authored Feb 16, 2021
1 parent 2264206 commit fc48514
Show file tree
Hide file tree
Showing 5 changed files with 307 additions and 129 deletions.
56 changes: 34 additions & 22 deletions src/relay/backend/contrib/verilator/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <sstream>

#include "../../../../runtime/contrib/json/json_node.h"
#include "../../../../runtime/contrib/verilator/verilator_runtime.h"
#include "../../utils.h"
#include "../codegen_json/codegen_json.h"

Expand Down Expand Up @@ -75,29 +76,34 @@ class VerilatorJSONSerializer : public backend::contrib::JSONSerializer {
}
};

/*! \brief Attributes to store the compiler options for Verilator */
struct VerilatorCompilerConfigNode : public tvm::AttrsNode<VerilatorCompilerConfigNode> {
String lib;

TVM_DECLARE_ATTRS(VerilatorCompilerConfigNode, "ext.attrs.VerilatorCompilerConfigNode") {
TVM_ATTR_FIELD(lib).set_default("libverilator.so");
/*! \brief Attributes to store options for Verilator */
struct VerilatorOptionsNode : public tvm::AttrsNode<VerilatorOptionsNode> {
String lib_path;
int reset_cycles;
bool profiler_enable;
int profiler_cycle_counter_id;

TVM_DECLARE_ATTRS(VerilatorOptionsNode, "ext.attrs.VerilatorOptionsNode") {
TVM_ATTR_FIELD(lib_path).describe("the design library path").set_default("libverilator.so");
TVM_ATTR_FIELD(reset_cycles).describe("the number of reset cycles").set_default(1);
TVM_ATTR_FIELD(profiler_enable).describe("enable profiler").set_default(false);
TVM_ATTR_FIELD(profiler_cycle_counter_id).describe("profiler cycle counter id").set_default(0);
}
};

class VerilatorCompilerConfig : public Attrs {
class VerilatorOptions : public Attrs {
public:
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(VerilatorCompilerConfig, Attrs,
VerilatorCompilerConfigNode);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(VerilatorOptions, Attrs, VerilatorOptionsNode);
};

TVM_REGISTER_NODE_TYPE(VerilatorCompilerConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.verilator.options", VerilatorCompilerConfig);
TVM_REGISTER_NODE_TYPE(VerilatorOptionsNode);
TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.verilator.options", VerilatorOptions);

/*!
* \brief The external compiler/codegen tool. It takes a Relay expression/module and
* compile it into a runtime module.
* \brief The Verilator codegen tool. It takes a Relay expression/module and
* compile it into a Verilator runtime module.
*/
runtime::Module VerilatorCompiler(const ObjectRef& ref) {
runtime::Module VerilatorBackend(const ObjectRef& ref) {
CHECK(ref->IsInstance<FunctionNode>());
auto func = Downcast<Function>(ref);
auto func_name = GetExtSymbol(func);
Expand All @@ -106,22 +112,28 @@ runtime::Module VerilatorCompiler(const ObjectRef& ref) {
std::string graph_json = serializer.GetJSON();
auto params = serializer.GetParams();

// Create runtime object
auto n = make_object<runtime::contrib::VerilatorRuntime>(func_name, graph_json, params);

// Get Verilator compiler options
auto ctx = transform::PassContext::Current();
auto cfg = ctx->GetConfig<VerilatorCompilerConfig>("relay.ext.verilator.options");
auto cfg = ctx->GetConfig<VerilatorOptions>("relay.ext.verilator.options");
if (!cfg.defined()) {
cfg = AttrsWithDefaultValues<VerilatorCompilerConfig>();
cfg = AttrsWithDefaultValues<VerilatorOptions>();
}

auto lib_name = cfg.value()->lib;
n->SetLibrary(cfg.value()->lib_path);
n->SetResetCycles(cfg.value()->reset_cycles);

if (cfg.value()->profiler_enable) {
n->EnableProfiler();
n->SetProfilerCycleCounterId(cfg.value()->profiler_cycle_counter_id);
}

const auto* pf = runtime::Registry::Get("runtime.verilator_runtime_create");
CHECK(pf != nullptr) << "Cannot find JSON runtime module to create";
auto mod = (*pf)(lib_name, func_name, graph_json, params);
return mod;
return runtime::Module(n);
}

TVM_REGISTER_GLOBAL("relay.ext.verilator").set_body_typed(VerilatorCompiler);
TVM_REGISTER_GLOBAL("relay.ext.verilator").set_body_typed(VerilatorBackend);

} // namespace contrib
} // namespace relay
Expand Down
39 changes: 33 additions & 6 deletions src/runtime/contrib/verilator/verilator_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,51 @@ namespace tvm {
namespace runtime {
namespace contrib {

/*! \brief Verilator device resource context */
typedef void* VerilatorHandle;

/* allocate Verilator object */
/*!
* \brief Allocate a verilator device resource handle
* \return The verilator device handle.
*/
extern "C" TVM_DLL VerilatorHandle VerilatorAlloc();

/* deallocate Verilator object */
/*!
* \brief Free a verilator device handle
* \param handle The verilator device handle to be freed.
*/
extern "C" TVM_DLL void VerilatorDealloc(VerilatorHandle handle);

/* read Verilator register or memory */
/*!
* \brief Read verilator register or memory
* \param handle The verilator device handle.
* \param id The register or memory identifier.
* \param addr The register or memory address (word-level).
* \return The value of register or memory.
*/
extern "C" TVM_DLL int VerilatorRead(VerilatorHandle handle, int id, int addr);

/* write Verilator register or memory */
/*!
* \brief Write verilator register or memory
* \param handle The verilator device handle.
* \param id The register or memory identifier.
* \param addr The register or memory address (word-level).
* \param value The value of register or memory.
*/
extern "C" TVM_DLL void VerilatorWrite(VerilatorHandle handle, int id, int addr, int value);

/* reset Verilator for n clock cycles */
/*!
* \brief Reset Verilator for n clock cycles
* \param handle The verilator device handle.
* \param n The number of reset cycles.
*/
extern "C" TVM_DLL void VerilatorReset(VerilatorHandle handle, int n);

/* run Verilator for n clock cycles */
/*!
* \brief Run Verilator for n clock cycles
* \param handle The verilator device handle.
* \param n The number of run cycles.
*/
extern "C" TVM_DLL void VerilatorRun(VerilatorHandle handle, int n);

} // namespace contrib
Expand Down
197 changes: 99 additions & 98 deletions src/runtime/contrib/verilator/verilator_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@

/*!
* \file src/runtime/contrib/verilator/verilator_runtime.cc
* \brief A simple JSON runtime for Verilator.
* \brief A runtime for Verilator.
*/

#include "verilator_runtime.h"

#include <dlfcn.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/registry.h>
Expand All @@ -40,124 +42,123 @@ namespace tvm {
namespace runtime {
namespace contrib {

typedef VerilatorHandle (*VerilatorAllocFunc)();
typedef void (*VerilatorResetFunc)(VerilatorHandle, int);
typedef void (*VerilatorAddFunc)(VerilatorHandle, int*, int*, int*, int, int);

using namespace tvm::runtime;
using namespace tvm::runtime::contrib;
using namespace tvm::runtime::json;

class VerilatorLibrary : public Library {
public:
~VerilatorLibrary() {
if (lib_handle_) Unload();
}
void Init(const std::string& name) { Load(name); }

void* GetSymbol(const char* name) final { return GetSymbol_(name); }

private:
// Library handle
void* lib_handle_{nullptr};
// load the library
void Load(const std::string& name) {
lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);
ICHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name << " " << dlerror();
}

void* GetSymbol_(const char* name) { return dlsym(lib_handle_, name); }

void Unload() {
VerilatorLibrary::~VerilatorLibrary() {
if (lib_handle_) {
dlclose(lib_handle_);
lib_handle_ = nullptr;
}
};
}

class VerilatorJSONRuntime : public JSONRuntimeBase {
public:
VerilatorJSONRuntime(const std::string& symbol_name, const std::string& graph_json,
const Array<String> const_names)
: JSONRuntimeBase(symbol_name, graph_json, const_names) {}
void VerilatorLibrary::Load(const std::string& name) {
lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);
ICHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name << " "
<< dlerror();
}

const char* type_key() const { return "verilator_json"; }
void* VerilatorLibrary::GetSymbol(const char* name) { return dlsym(lib_handle_, name); }

void LoadLibrary(const std::string& lib_name) {
lib_ = new VerilatorLibrary();
lib_->Init(lib_name);
}
void VerilatorProfiler::Clear() { cycle_counter = 0; }

void Init(const Array<NDArray>& consts) override {
// get symbols
auto alloc_func = reinterpret_cast<VerilatorAllocFunc>(lib_->GetSymbol("VerilatorAlloc"));
ICHECK(alloc_func != nullptr);
auto reset_func = reinterpret_cast<VerilatorResetFunc>(lib_->GetSymbol("VerilatorReset"));
ICHECK(reset_func != nullptr);
vadd_func_ = reinterpret_cast<VerilatorAddFunc>(lib_->GetSymbol("verilator_add"));
ICHECK(vadd_func_ != nullptr);
std::string VerilatorProfiler::AsJSON() {
std::ostringstream os;
os << "{\n"
<< " \"cycle_counter\":" << cycle_counter << "\n"
<< "}\n";
return os.str();
}

// alloc device
device_ = (*alloc_func)();
VerilatorProfiler* VerilatorProfiler::ThreadLocal() {
static thread_local VerilatorProfiler inst;
return &inst;
}

// reset for 10 cycles
(*reset_func)(device_, 10);
VerilatorRuntime::~VerilatorRuntime() {
auto dealloc = reinterpret_cast<VerilatorDeallocFunc>(lib_->GetSymbol("VerilatorDealloc"));
ICHECK(dealloc != nullptr);
dealloc(device_);
lib_->~VerilatorLibrary();
}

CHECK_EQ(consts.size(), const_idx_.size())
<< "The number of input constants must match the number of required.";
void VerilatorRuntime::SetLibrary(const std::string& lib_path) { lib_path_ = lib_path; }

// Setup constants entries for weights.
SetupConstants(consts);
}
void VerilatorRuntime::SetResetCycles(const int cycles) { reset_cycles_ = cycles; }

void Run() override {
std::vector<int*> in_ptr;
std::vector<int*> out_ptr;
for (size_t i = 0; i < input_nodes_.size(); ++i) {
uint32_t eid = EntryID(input_nodes_[i], 0);
int* data = static_cast<int*>(data_entry_[eid]->data);
in_ptr.push_back(data);
}
for (size_t i = 0; i < outputs_.size(); ++i) {
uint32_t eid = EntryID(outputs_[i]);
int* data = static_cast<int*>(data_entry_[eid]->data);
out_ptr.push_back(data);
}
for (size_t nid = 0; nid < nodes_.size(); ++nid) {
const auto& node = nodes_[nid];
if (node.GetOpType() == "kernel") {
CHECK_EQ(node.GetOpType(), "kernel");
auto op_name = node.GetOpName();
if ("add" == op_name) {
auto entry = node.GetInputs()[0];
auto shape = nodes_[entry.id_].GetOpShape()[entry.index_];
(*vadd_func_)(device_, in_ptr[0], in_ptr[1], out_ptr[0], shape[0], shape[1]);
} else {
LOG(FATAL) << "Unsupported op: " << op_name;
}
void VerilatorRuntime::EnableProfiler() { prof_enable_ = true; }

void VerilatorRuntime::SetProfilerCycleCounterId(const int id) { prof_cycle_counter_id_ = id; }

void VerilatorRuntime::Init(const Array<NDArray>& consts) {
lib_ = new VerilatorLibrary();
lib_->Load(lib_path_);
auto alloc = reinterpret_cast<VerilatorAllocFunc>(lib_->GetSymbol("VerilatorAlloc"));
ICHECK(alloc != nullptr);
auto reset = reinterpret_cast<VerilatorResetFunc>(lib_->GetSymbol("VerilatorReset"));
ICHECK(reset != nullptr);
read_ = reinterpret_cast<VerilatorReadFunc>(lib_->GetSymbol("VerilatorRead"));
ICHECK(read_ != nullptr);
add_op_ = reinterpret_cast<VerilatorAddFunc>(lib_->GetSymbol("verilator_add"));

// alloc verilator device
device_ = alloc();

// enable profiler
if (prof_enable_) prof_ = VerilatorProfiler::ThreadLocal();

// reset verilator device
reset(device_, reset_cycles_);

CHECK_EQ(consts.size(), const_idx_.size())
<< "The number of input constants must match the number of required.";

// Setup constants entries for weights.
SetupConstants(consts);
}

void VerilatorRuntime::Run() {
std::vector<int*> in_ptr;
std::vector<int*> out_ptr;
for (size_t i = 0; i < input_nodes_.size(); ++i) {
uint32_t eid = EntryID(input_nodes_[i], 0);
int* data = static_cast<int*>(data_entry_[eid]->data);
in_ptr.push_back(data);
}
for (size_t i = 0; i < outputs_.size(); ++i) {
uint32_t eid = EntryID(outputs_[i]);
int* data = static_cast<int*>(data_entry_[eid]->data);
out_ptr.push_back(data);
}
for (size_t nid = 0; nid < nodes_.size(); ++nid) {
const auto& node = nodes_[nid];
if (node.GetOpType() == "kernel") {
CHECK_EQ(node.GetOpType(), "kernel");
auto op_name = node.GetOpName();
if ("add" == op_name) {
auto entry = node.GetInputs()[0];
auto shape = nodes_[entry.id_].GetOpShape()[entry.index_];
ICHECK(add_op_ != nullptr);
add_op_(device_, in_ptr[0], in_ptr[1], out_ptr[0], shape[0], shape[1]);
} else {
LOG(FATAL) << "Unsupported op: " << op_name;
}
}
}

private:
/* The verilator device handle. */
VerilatorHandle device_{nullptr};
/* The verilator library handle. */
VerilatorLibrary* lib_{nullptr};
/* The verilator vadd function handle. */
VerilatorAddFunc vadd_func_{nullptr};
};

runtime::Module VerilatorJSONRuntimeCreate(String lib_name, String symbol_name, String graph_json,
const Array<String>& const_names) {
auto n = make_object<VerilatorJSONRuntime>(symbol_name, graph_json, const_names);
n->LoadLibrary(lib_name);
return runtime::Module(n);
if (prof_enable_) {
int cycles = read_(device_, prof_cycle_counter_id_, 0);
prof_->cycle_counter += cycles;
}
}

TVM_REGISTER_GLOBAL("runtime.verilator_runtime_create").set_body_typed(VerilatorJSONRuntimeCreate);
TVM_REGISTER_GLOBAL("verilator.profiler_clear").set_body([](TVMArgs args, TVMRetValue* rv) {
VerilatorProfiler::ThreadLocal()->Clear();
});

TVM_REGISTER_GLOBAL("runtime.module.loadbinary_verilator_json")
.set_body_typed(JSONRuntimeBase::LoadFromBinary<VerilatorJSONRuntime>);
TVM_REGISTER_GLOBAL("verilator.profiler_status").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = VerilatorProfiler::ThreadLocal()->AsJSON();
});

} // namespace contrib
} // namespace runtime
Expand Down
Loading

0 comments on commit fc48514

Please sign in to comment.