diff --git a/nnvm/Makefile b/nnvm/Makefile index 6c98ee342e15b..78a56b8158e29 100644 --- a/nnvm/Makefile +++ b/nnvm/Makefile @@ -5,7 +5,7 @@ export CFLAGS = -std=c++11 -Wall -O3 -msse2 -Wno-unknown-pragmas -funroll-loop # specify tensor path .PHONY: clean all test lint doc -all: lib/libnngraph.so lib/libnngraph.a cli_test +all: lib/libnnvm.so lib/libnnvm.a cli_test SRC = $(wildcard src/*.cc src/*/*.cc example/*.cc) ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC)) @@ -20,11 +20,11 @@ build/%.o: src/%.cc $(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d $(CXX) -c $(CFLAGS) -c $< -o $@ -lib/libnngraph.so: $(ALL_DEP) +lib/libnnvm.so: $(ALL_DEP) @mkdir -p $(@D) $(CXX) $(CFLAGS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS) -lib/libnngraph.a: $(ALL_DEP) +lib/libnnvm.a: $(ALL_DEP) @mkdir -p $(@D) ar crv $@ $(filter %.o, $?) @@ -32,7 +32,7 @@ cli_test: $(ALL_DEP) build/test_main.o $(CXX) $(CFLAGS) -o $@ $(filter %.o %.a, $^) $(LDFLAGS) lint: - python2 dmlc-core/scripts/lint.py nngraph cpp include src + python2 dmlc-core/scripts/lint.py nnvm cpp include src doc: doxygen docs/Doxyfile diff --git a/nnvm/include/nnvm/c_api.h b/nnvm/include/nnvm/c_api.h new file mode 100644 index 0000000000000..bade0100c59e2 --- /dev/null +++ b/nnvm/include/nnvm/c_api.h @@ -0,0 +1,219 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file c_api.h + * \brief C API of NNVM symbolic construction and pass. + * Enables construction and transformation of Graph + * in any other host languages. + */ +#ifndef NNVM_C_API_H_ +#define NNVM_C_API_H_ + +#ifdef __cplusplus +#define NNVM_EXTERN_C extern "C" +#endif + +/*! \brief NNVM_DLL prefix for windows */ +#ifdef _WIN32 +#ifdef NNVM_EXPORTS +#define NNVM_DLL NNVM_EXTERN_C __declspec(dllexport) +#else +#define NNVM_DLL NNVM_EXTERN_C __declspec(dllimport) +#endif +#else +#define NNVM_DLL NNVM_EXTERN_C +#endif + +/*! \brief manually define unsigned int */ +typedef unsigned int nn_uint; + +/*! \brief handle to a function that takes param and creates symbol */ +typedef void *AtomicSymbolCreator; +/*! \brief handle to a symbol that can be bind as operator */ +typedef void *SymbolHandle; +/*! \brief handle to a AtomicSymbol */ +typedef void *AtomicSymbolHandle; + +/*! + * \brief return str message of the last error + * all function in this file will return 0 when success + * and -1 when an error occured, + * NNGetLastError can be called to retrieve the error + * + * this function is threadsafe and can be called by different thread + * \return error info + */ +NNVM_DLL const char *NNGetLastError(); + +/*! + * \brief list all the available AtomicSymbolEntry + * \param out_size the size of returned array + * \param out_array the output AtomicSymbolCreator array + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolListAtomicSymbolCreators(nn_uint *out_size, + AtomicSymbolCreator **out_array); +/*! + * \brief Get the detailed information about atomic symbol. + * \param creator the AtomicSymbolCreator. + * \param name The returned name of the creator. + * \param description The returned description of the symbol. + * \param num_doc_args Number of arguments that contain documents. + * \param arg_names Name of the arguments of doc args + * \param arg_type_infos Type informations about the arguments. + * \param arg_descriptions Description information about the arguments. + * \param return_type Return type of the function, if any. + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator, + const char **name, + const char **description, + nn_uint *num_doc_args, + const char ***arg_names, + const char ***arg_type_infos, + const char ***arg_descriptions, + const char **return_type = NULL); +/*! + * \brief Create an AtomicSymbol functor. + * \param creator the AtomicSymbolCreator + * \param num_param the number of parameters + * \param keys the keys to the params + * \param vals the vals of the params + * \param out pointer to the created symbol handle + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, + nn_uint num_param, + const char **keys, + const char **vals, + SymbolHandle *out); +/*! + * \brief Create a Variable Symbol. + * \param name name of the variable + * \param out pointer to the created symbol handle + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolCreateVariable(const char *name, SymbolHandle *out); +/*! + * \brief Create a Symbol by grouping list of symbols together + * \param num_symbols number of symbols to be grouped + * \param symbols array of symbol handles + * \param out pointer to the created symbol handle + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolCreateGroup(nn_uint num_symbols, + SymbolHandle *symbols, + SymbolHandle *out); +/*! + * \brief Free the symbol handle. + * \param symbol the symbol + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolFree(SymbolHandle symbol); +/*! + * \brief Copy the symbol to another handle + * \param symbol the source symbol + * \param out used to hold the result of copy + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolCopy(SymbolHandle symbol, SymbolHandle *out); +/*! + * \brief Print the content of symbol, used for debug. + * \param symbol the symbol + * \param out_str pointer to hold the output string of the printing. + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolPrint(SymbolHandle symbol, const char **out_str); + +/*! + * \brief Set string attribute from symbol. + * NOTE: Setting attribute to a symbol can affect the semantics(mutable/immutable) of symbolic graph. + * + * Safe recommendaton: use immutable graph + * - Only allow set attributes during creation of new symbol as optional parameter + * + * Mutable graph (be careful about the semantics): + * - Allow set attr at any point. + * - Mutating an attribute of some common node of two graphs can cause confusion from user. + * + * \param symbol the source symbol + * \param num_param Number of parameters to set. + * \param keys The keys of the attribute + * \param values The value to be set + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolSetAttrs(SymbolHandle symbol, + nn_uint num_param, + const char** keys, + const char** values); +/*! + * \brief Get all attributes from symbol, including all descendents. + * \param symbol the source symbol + * \param recursive_option 0 for recursive, 1 for shallow. + * \param out_size The number of output attributes + * \param out 2*out_size strings representing key value pairs. + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolListAttrs(SymbolHandle symbol, + int recursive_option, + nn_uint *out_size, + const char*** out); +/*! + * \brief List arguments in the symbol. + * \param symbol the symbol + * \param out_size output size + * \param out_str_array pointer to hold the output string array + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolListArguments(SymbolHandle symbol, + nn_uint *out_size, + const char ***out_str_array); +/*! + * \brief List returns in the symbol. + * \param symbol the symbol + * \param out_size output size + * \param out_str_array pointer to hold the output string array + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolListOutputs(SymbolHandle symbol, + nn_uint *out_size, + const char ***out_str_array); +/*! + * \brief Get a symbol that contains all the internals. + * \param symbol The symbol + * \param out The output symbol whose outputs are all the internals. + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolGetInternals(SymbolHandle symbol, + SymbolHandle *out); +/*! + * \brief Get index-th outputs of the symbol. + * \param symbol The symbol + * \param index the Index of the output. + * \param out The output symbol whose outputs are the index-th symbol. + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolGetOutput(SymbolHandle symbol, + nn_uint index, + SymbolHandle *out); + +/*! + * \brief Compose the symbol on other symbols. + * + * This function will change the sym hanlde. + * To achieve function apply behavior, copy the symbol first + * before apply. + * + * \param sym the symbol to apply + * \param name the name of symbol + * \param num_args number of arguments + * \param keys the key of keyword args (optional) + * \param args arguments to sym + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolCompose(SymbolHandle sym, + const char* name, + nn_uint num_args, + const char** keys, + SymbolHandle* args); + +#endif // NNVM_C_API_H_ diff --git a/nnvm/include/nnvm/op.h b/nnvm/include/nnvm/op.h index bcbd4d5cb9399..516361f208d82 100644 --- a/nnvm/include/nnvm/op.h +++ b/nnvm/include/nnvm/op.h @@ -75,7 +75,10 @@ class Op { public: /*! \brief name of the operator */ std::string name; - /*! \brief detailed description of the operator */ + /*! + * \brief detailed description of the operator + * This can be used to generate docstring automatically for the operator. + */ std::string description; /*! * \brief number of inputs to the operator, @@ -339,7 +342,7 @@ inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*) return *this; } -inline Op& Op::set_num_inputs(uint32_t (*fn)(const NodeAttrs&)) { // NOLINT(*) +inline Op& Op::set_num_inputs(uint32_t (*fn)(const NodeAttrs& attr)) { // NOLINT(*) this->get_num_inputs = fn; return *this; } @@ -349,7 +352,7 @@ inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*) return *this; } -inline Op& Op::set_num_outputs(uint32_t (*fn)(const NodeAttrs&)) { // NOLINT(*) +inline Op& Op::set_num_outputs(uint32_t (*fn)(const NodeAttrs& attr)) { // NOLINT(*) this->get_num_outputs = fn; return *this; } diff --git a/nnvm/include/nnvm/symbolic.h b/nnvm/include/nnvm/symbolic.h index cba8c3bf8473a..1c82b86096ca6 100644 --- a/nnvm/include/nnvm/symbolic.h +++ b/nnvm/include/nnvm/symbolic.h @@ -26,9 +26,9 @@ class Symbol { /*! \brief option passed to ListAttr */ enum ListAttrOption { /*! \brief recursively list all attributes */ - kRecursive, + kRecursive = 0, /*! \brief only list attributes in current node */ - kShallow + kShallow = 1 }; /*! \brief output entries contained in the symbol */ @@ -69,7 +69,7 @@ class Symbol { * * The rest of the symbols will remain the same name. * - * \param positional arguments + * \param args positional arguments * \param kwargs keyword arguments for the symbol * \param name name of returned symbol. */ @@ -108,8 +108,7 @@ class Symbol { * * This function mutate the node's symbol and is not recommended. * - * \param key the key of the attribute - * \param value the value of the attribute. + * \param attrs The attributes to set. */ void SetAttrs(const std::vector >& attrs); /*! @@ -119,16 +118,15 @@ class Symbol { * The name of symbol will be pre-pended to each key. * \return The created attribute. */ - std::unordered_map ListAttr(ListAttrOption option) const; + std::unordered_map ListAttrs(ListAttrOption option) const; /*! * \brief create symbolic functor(AtomicSymbol) by given operator and attributes. - * \param op_name The name of the operator. + * \param op The operator. * \param attrs The additional attributes. - * * \return Symbol that can be used to call compose further. */ - static Symbol CreateFunctor(const std::string& op_name, - const std::unordered_map& attrs); + static Symbol CreateFunctor(const Op* op, + std::unordered_map&& attrs); /*! * \brief create variable symbol node * \param name name of the variable diff --git a/nnvm/src/c_api/c_api_common.h b/nnvm/src/c_api/c_api_common.h new file mode 100644 index 0000000000000..170ceb2e58d31 --- /dev/null +++ b/nnvm/src/c_api/c_api_common.h @@ -0,0 +1,59 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file c_api_error.h + * \brief Common fields of all C APIs + */ +#ifndef NNVM_C_API_C_API_COMMON_H_ +#define NNVM_C_API_C_API_COMMON_H_ + +#include +#include +#include +#include +#include +#include + +/*! \brief macro to guard beginning and end section of all functions */ +#define API_BEGIN() try { +/*! \brief every function starts with API_BEGIN(); + and finishes with API_END() or API_END_HANDLE_ERROR */ +#define API_END() } catch(dmlc::Error &_except_) { return NNAPIHandleException(_except_); } return 0; // NOLINT(*) +/*! + * \brief every function starts with API_BEGIN(); + * and finishes with API_END() or API_END_HANDLE_ERROR + * The finally clause contains procedure to cleanup states when an error happens. + */ +#define API_END_HANDLE_ERROR(Finalize) } catch(dmlc::Error &_except_) { Finalize; return NNAPIHandleException(_except_); } return 0; // NOLINT(*) + + +/*! \brief entry to to easily hold returning information */ +struct NNAPIThreadLocalEntry { + /*! \brief result holder for returning string */ + std::string ret_str; + /*! \brief result holder for returning strings */ + std::vector ret_vec_str; + /*! \brief result holder for returning string pointers */ + std::vector ret_vec_charp; + /*! \brief result holder for returning handles */ + std::vector ret_handles; +}; + +/*! \brief Thread local store that can be used to hold return values. */ +typedef dmlc::ThreadLocalStore NNAPIThreadLocalStore; + +/*! + * \brief Set the last error message needed by C API + * \param msg The error message to set. + */ +void NNAPISetLastError(const char* msg); +/*! + * \brief handle exception throwed out + * \param e the exception + * \return the return value of API after exception is handled + */ +inline int NNAPIHandleException(const dmlc::Error &e) { + NNAPISetLastError(e.what()); + return -1; +} + +#endif // NNVM_C_API_C_API_COMMON_H_ diff --git a/nnvm/src/c_api/c_api_error.cc b/nnvm/src/c_api/c_api_error.cc new file mode 100644 index 0000000000000..399268667dddb --- /dev/null +++ b/nnvm/src/c_api/c_api_error.cc @@ -0,0 +1,21 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file c_api_error.cc + * \brief C error handling + */ +#include +#include "./c_api_common.h" + +struct ErrorEntry { + std::string last_error; +}; + +typedef dmlc::ThreadLocalStore NNAPIErrorStore; + +const char *NNGetLastError() { + return NNAPIErrorStore::Get()->last_error.c_str(); +} + +void NNAPISetLastError(const char* msg) { + NNAPIErrorStore::Get()->last_error = msg; +} diff --git a/nnvm/src/c_api/c_api_symbolic.cc b/nnvm/src/c_api/c_api_symbolic.cc new file mode 100644 index 0000000000000..c023ee4a81865 --- /dev/null +++ b/nnvm/src/c_api/c_api_symbolic.cc @@ -0,0 +1,222 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file c_api_symbolic.cc + * \brief C API related to symbolic graph compsition. + */ +#include +#include +#include +#include +#include "./c_api_common.h" + +using namespace nnvm; + +int NNSymbolListAtomicSymbolCreators(nn_uint *out_size, + AtomicSymbolCreator **out_array) { + API_BEGIN(); + auto &vec = dmlc::Registry::List(); + *out_size = static_cast(vec.size()); + *out_array = (AtomicSymbolCreator*)(dmlc::BeginPtr(vec)); // NOLINT(*) + API_END(); +} + + +int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator, + const char **name, + const char **description, + nn_uint *num_doc_args, + const char ***arg_names, + const char ***arg_type_infos, + const char ***arg_descriptions, + const char **return_type) { + const Op *op = static_cast(creator); + + API_BEGIN(); + *name = op->name.c_str(); + *description = op->description.c_str(); + *num_doc_args = 0; + API_END(); +} + + +int NNSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, + nn_uint num_param, + const char **keys, + const char **vals, + SymbolHandle *out) { + Symbol *s = new Symbol(); + API_BEGIN(); + const Op* op = static_cast(creator); + std::unordered_map kwargs; + for (nn_uint i = 0; i < num_param; ++i) { + kwargs.insert({std::string(keys[i]), std::string(vals[i])}); + } + *s = Symbol::CreateFunctor(op, std::move(kwargs)); + *out = s; + API_END_HANDLE_ERROR(delete s;); +} + +int NNSymbolCreateVariable(const char *name, SymbolHandle *out) { + Symbol *s = new Symbol(); + API_BEGIN(); + *s = Symbol::CreateVariable(name); + *out = s; + API_END_HANDLE_ERROR(delete s); +} + +int NNSymbolCreateGroup(nn_uint num_symbols, + SymbolHandle *symbols, + SymbolHandle *out) { + Symbol *s = new Symbol(); + Symbol **sym_arr = (Symbol**)symbols; // NOLINT(*) + API_BEGIN(); + std::vector syms; + for (nn_uint i = 0; i < num_symbols; ++i) { + syms.push_back(*sym_arr[i]); + } + *s = Symbol::CreateGroup(syms); + *out = s; + API_END_HANDLE_ERROR(delete s); +} + +int NNSymbolGetOutput(SymbolHandle symbol, + nn_uint index, + SymbolHandle *out) { + Symbol *s = new Symbol(); + API_BEGIN(); + *s = (*static_cast(symbol))[index]; + *out = s; + API_END_HANDLE_ERROR(delete s); +} + +int NNSymbolGetInternals(SymbolHandle symbol, + SymbolHandle *out) { + Symbol *s = new Symbol(); + API_BEGIN(); + *s = static_cast(symbol)->GetInternals(); + *out = s; + API_END_HANDLE_ERROR(delete s); +} + +int NNSymbolFree(SymbolHandle symbol) { + API_BEGIN(); + delete static_cast(symbol); + API_END(); +} + +int NNSymbolCopy(SymbolHandle symbol, SymbolHandle *out) { + Symbol *s = new Symbol(); + API_BEGIN(); + *s = static_cast(symbol)->Copy(); + *out = s; + API_END_HANDLE_ERROR(delete s); +} + +int NNSymbolPrint(SymbolHandle symbol, const char **out_str) { + Symbol *s = static_cast(symbol); + NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); + API_BEGIN(); + std::ostringstream os; + s->Print(os); + ret->ret_str = os.str(); + *out_str = (ret->ret_str).c_str(); + API_END(); +} + +int MXSymbolSetAttrs(SymbolHandle symbol, + nn_uint num_param, + const char** keys, + const char** vals) { + Symbol *s = static_cast(symbol); + API_BEGIN(); + std::vector > kwargs; + for (nn_uint i = 0; i < num_param; ++i) { + kwargs.emplace_back( + std::make_pair(std::string(keys[i]), std::string(vals[i]))); + } + s->SetAttrs(kwargs); + API_END(); +} + +int NNSymbolListAttrs(SymbolHandle symbol, + int option, + nn_uint *out_size, + const char*** out) { + Symbol *s = static_cast(symbol); + NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); + API_BEGIN(); + std::unordered_map attr = + std::move(s->ListAttrs(static_cast(option))); // NOLINT(*) + + std::vector& attr_list = ret->ret_vec_str; + attr_list.clear(); + for (const auto& kv : attr) { + attr_list.push_back(kv.first); + attr_list.push_back(kv.second); + } + *out_size = attr.size(); + ret->ret_vec_charp.clear(); + for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { + ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); + } + *out = dmlc::BeginPtr(ret->ret_vec_charp); + API_END(); +} + +int NNSymbolListArguments(SymbolHandle symbol, + nn_uint *out_size, + const char ***out_str_array) { + Symbol *s = static_cast(symbol); + NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); + API_BEGIN(); + ret->ret_vec_str = std::move(s->ListArguments()); + ret->ret_vec_charp.clear(); + for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { + ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); + } + *out_size = static_cast(ret->ret_vec_charp.size()); + *out_str_array = dmlc::BeginPtr(ret->ret_vec_charp); + API_END(); +} + +int NNSymbolListOutputs(SymbolHandle symbol, + nn_uint *out_size, + const char ***out_str_array) { + Symbol *s = static_cast(symbol); + NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); + API_BEGIN(); + ret->ret_vec_str = std::move(s->ListOutputs()); + ret->ret_vec_charp.clear(); + for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { + ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); + } + *out_size = static_cast(ret->ret_vec_charp.size()); + *out_str_array = dmlc::BeginPtr(ret->ret_vec_charp); + API_END(); +} + +int NNSymbolCompose(SymbolHandle sym, + const char *name, + nn_uint num_args, + const char** keys, + SymbolHandle* args) { + API_BEGIN(); + std::string s_name; + if (name != nullptr) s_name = name; + + Symbol* s = static_cast(sym); + if (keys == nullptr && num_args != 0) { + std::vector pos_args; + for (nn_uint i = 0; i < num_args; ++i) { + pos_args.push_back(*((Symbol*)args[i])); // NOLINT(*) + } + s->Compose(pos_args, {}, s_name); + } else { + std::unordered_map kwargs; + for (nn_uint i = 0; i < num_args; ++i) { + kwargs[keys[i]] = *((Symbol*)args[i]); // NOLINT(*) + } + s->Compose({}, kwargs, s_name); + } + API_END(); +} diff --git a/nnvm/src/core/symbolic.cc b/nnvm/src/core/symbolic.cc index 03a6f065b0086..a6e70b29da52f 100644 --- a/nnvm/src/core/symbolic.cc +++ b/nnvm/src/core/symbolic.cc @@ -322,7 +322,7 @@ void Symbol::SetAttrs(const std::vector >& a } } -std::unordered_map Symbol::ListAttr(ListAttrOption option) const { +std::unordered_map Symbol::ListAttrs(ListAttrOption option) const { if (option == kRecursive) { std::unordered_map ret; DFSVisit(this->outputs, [&ret](const std::shared_ptr& n) { @@ -336,12 +336,12 @@ std::unordered_map Symbol::ListAttr(ListAttrOption opt } } -Symbol Symbol::CreateFunctor(const std::string& op_name, - const std::unordered_map& attrs) { +Symbol Symbol::CreateFunctor(const Op* op, + std::unordered_map&& attrs) { Symbol s; std::shared_ptr n = Node::Create(); - n->op = Op::Get(op_name); - n->attrs.dict = attrs; + n->op = op; + n->attrs.dict = std::move(attrs); if (n->op->attr_parser != nullptr) { (*n->op->attr_parser)(&(n->attrs)); }