Skip to content

Commit

Permalink
[PYTHON] Check in a symbolic construction interface in python, start … (
Browse files Browse the repository at this point in the history
#4)

* [PYTHON] Check in a symbolic construction interface in python, start add graph API

* Graph API
  • Loading branch information
tqchen authored and sergei-mironov committed Aug 8, 2018
1 parent 2fe3d48 commit 2a907ab
Show file tree
Hide file tree
Showing 17 changed files with 1,143 additions and 16 deletions.
74 changes: 70 additions & 4 deletions nnvm/include/nnvm/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ typedef unsigned int nn_uint;
typedef void *AtomicSymbolCreator;
/*! \brief handle to a symbol that can be bind as operator */
typedef void *SymbolHandle;
/*! \brief handle to a AtomicSymbol */
typedef void *AtomicSymbolHandle;
/*! \brief handle to Graph */
typedef void *GraphHandle;

/*!
* \brief return str message of the last error
Expand Down Expand Up @@ -71,7 +71,7 @@ NNVM_DLL int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type = NULL);
const char **return_type);
/*!
* \brief Create an AtomicSymbol functor.
* \param creator the AtomicSymbolCreator
Expand Down Expand Up @@ -123,7 +123,18 @@ NNVM_DLL int NNSymbolCopy(SymbolHandle symbol, SymbolHandle *out);
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolPrint(SymbolHandle symbol, const char **out_str);

/*!
* \brief Get string attribute from symbol
* \param symbol the source symbol
* \param key The key of the symbol.
* \param out The result attribute, can be NULL if the attribute do not exist.
* \param success Whether the result is contained in out.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolGetAttr(SymbolHandle symbol,
const char* key,
const char** out,
int *success);
/*!
* \brief Set string attribute from symbol.
* NOTE: Setting attribute to a symbol can affect the semantics(mutable/immutable) of symbolic graph.
Expand Down Expand Up @@ -216,4 +227,59 @@ NNVM_DLL int NNSymbolCompose(SymbolHandle sym,
const char** keys,
SymbolHandle* args);

// Graph IR API
/*!
* \brief create a graph handle from symbol
* \param symbol The symbol representing the graph.
* \param graph The graph handle created.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphCreate(SymbolHandle symbol, GraphHandle *graph);
/*!
* \brief free the graph handle
* \param handle The handle to be freed.
*/
NNVM_DLL int NNGraphFree(GraphHandle handle);
/*!
* \brief Get a new symbol from the graph.
* \param graph The graph handle.
* \param symbol The corresponding symbol
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol);
/*!
* \brief Get Set a std::string typed attribute to graph.
* \param handle The graph handle.
* \param key The key to the attribute.
* \param value The value to be exposed.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphSetStrAttr(GraphHandle handle,
const char* key,
const char* value);
/*!
* \brief Get Set a std::string typed attribute from graph attribute.
* \param handle The graph handle.
* \param key The key to the attribute.
* \param out The result attribute, can be NULL if the attribute do not exist.
* \param success Whether the result is contained in out.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphGetStrAttr(SymbolHandle handle,
const char* key,
const char** out,
int *success);
/*!
* \brief Apply pass on the src graph.
* \param src The source graph handle.
* \param num_pass The number of pass to be applied.
* \param pass_names The names of the pass.
* \param dst The result graph.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphApplyPass(GraphHandle src,
nn_uint num_pass,
const char** pass_names,
GraphHandle *dst);

#endif // NNVM_C_API_H_
4 changes: 2 additions & 2 deletions nnvm/include/nnvm/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -323,10 +323,10 @@ inline Op& Op::attr( // NOLINT(*)
vec.resize(index_ + 1,
std::make_pair(ValueType(), 0));
std::pair<ValueType, int>& p = vec[index_];
CHECK(p.second == 0 || p.first == value)
CHECK(p.second == 0)
<< "Attribute " << attr_name
<< " of operator " << this->name
<< " is already registered to a different value";
<< " is already registered.";
vec[index_] = std::make_pair(value, 1);
});
return *this;
Expand Down
9 changes: 9 additions & 0 deletions nnvm/include/nnvm/symbolic.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,15 @@ class Symbol {
* \param attrs The attributes to set.
*/
void SetAttrs(const std::vector<std::pair<std::string, std::string> >& attrs);
/*!
* \brief Get attributes from the symbol.
* This only works for symbol with outputs from single operators.
* For grouped sybmbol, an error will be raised.
* \param key Key of the attribute. When key == "name", it returns the name attirbute.
* \param out the output value of the attribute.
* \return true if the attribute exists, false if the attribute do not exist.
*/
bool GetAttr(const std::string& key, std::string* out) const;
/*!
* \brief Get attribute dictionary from the symbol.
* For grouped sybmbol, an error will be raised.
Expand Down
10 changes: 10 additions & 0 deletions nnvm/python/nnvm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#!/usr/bin/env python
# coding: utf-8
"""NNVM python API for ease of use and help new framework establish python API. """
from __future__ import absolute_import

from . import base
from . import symbol as sym
from . import symbol

__version__ = base.__version__
62 changes: 62 additions & 0 deletions nnvm/python/nnvm/attribute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# coding: utf-8
"""Attribute scoping support for symbolic API."""
from __future__ import absolute_import

from .base import string_types

class AttrScope(object):
"""Attribute manager for scoping.
User can also inherit this object to change naming behavior.
Parameters
----------
kwargs
The attributes to set for all symbol creations in the scope.
"""
current = None

def __init__(self, **kwargs):
self._old_scope = None
for value in kwargs.values():
if not isinstance(value, string_types):
raise ValueError("Attributes need to be string")
self._attr = kwargs

def get(self, attr):
"""
Get the attribute dict given the attribute set by the symbol.
Parameters
----------
attr : dict of string to string
The attribute passed in by user during symbol creation.
Returns
-------
attr : dict of string to string
Updated attributes to add other scope related attributes.
"""
if self._attr:
ret = self._attr.copy()
if attr:
ret.update(attr)
return ret
else:
return attr

def __enter__(self):
# pylint: disable=protected-access
self._old_scope = AttrScope.current
attr = AttrScope.current._attr.copy()
attr.update(self._attr)
self._attr = attr
AttrScope.current = self
return self

def __exit__(self, ptype, value, trace):
assert self._old_scope
AttrScope.current = self._old_scope

AttrScope.current = AttrScope()

189 changes: 189 additions & 0 deletions nnvm/python/nnvm/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
# coding: utf-8
# pylint: disable=invalid-name
""" ctypes library of nnvm and helper functions """
from __future__ import absolute_import

import sys
import ctypes
import numpy as np
from . import libinfo

__all__ = ['NNNetError']
#----------------------------
# library loading
#----------------------------
if sys.version_info[0] == 3:
string_types = str,
numeric_types = (float, int, np.float32, np.int32)
# this function is needed for python3
# to convert ctypes.char_p .value back to python str
py_str = lambda x: x.decode('utf-8')
else:
string_types = basestring,
numeric_types = (float, int, long, np.float32, np.int32)
py_str = lambda x: x


class NNVMError(Exception):
"""Error that will be throwed by all nnvm functions"""
pass

def _load_lib():
"""Load libary by searching possible path."""
lib_path = libinfo.find_lib_path()
lib = ctypes.cdll.LoadLibrary(lib_path[0])
# DMatrix functions
lib.NNGetLastError.restype = ctypes.c_char_p
return lib

# version number
__version__ = libinfo.__version__
# library instance of nnvm
_LIB = _load_lib()

# type definitions
nn_uint = ctypes.c_uint
SymbolCreatorHandle = ctypes.c_void_p
SymbolHandle = ctypes.c_void_p
GraphHandle = ctypes.c_void_p

#----------------------------
# helper function definition
#----------------------------
def check_call(ret):
"""Check the return value of C API call
This function will raise exception when error occurs.
Wrap every API call with this function
Parameters
----------
ret : int
return value from API calls
"""
if ret != 0:
raise NNVMError(py_str(_LIB.NNGetLastError()))

def c_str(string):
"""Create ctypes char * from a python string
Parameters
----------
string : string type
python string
Returns
-------
str : c_char_p
A char pointer that can be passed to C API
"""
return ctypes.c_char_p(string.encode('utf-8'))


def c_array(ctype, values):
"""Create ctypes array from a python array
Parameters
----------
ctype : ctypes data type
data type of the array we want to convert to
values : tuple or list
data content
Returns
-------
out : ctypes array
Created ctypes array
"""
return (ctype * len(values))(*values)

def ctypes2buffer(cptr, length):
"""Convert ctypes pointer to buffer type.
Parameters
----------
cptr : ctypes.POINTER(ctypes.c_char)
pointer to the raw memory region
length : int
the length of the buffer
Returns
-------
buffer : bytearray
The raw byte memory buffer
"""
if not isinstance(cptr, ctypes.POINTER(ctypes.c_char)):
raise TypeError('expected char pointer')
res = bytearray(length)
rptr = (ctypes.c_char * length).from_buffer(res)
if not ctypes.memmove(rptr, cptr, length):
raise RuntimeError('memmove failed')
return res

def ctypes2numpy_shared(cptr, shape):
"""Convert a ctypes pointer to a numpy array
The result numpy array shares the memory with the pointer
Parameters
----------
cptr : ctypes.POINTER(mx_float)
pointer to the memory region
shape : tuple
shape of target ndarray
Returns
-------
out : numpy_array
A numpy array : numpy array
"""
if not isinstance(cptr, ctypes.POINTER(mx_float)):
raise RuntimeError('expected float pointer')
size = 1
for s in shape:
size *= s
dbuffer = (mx_float * size).from_address(ctypes.addressof(cptr.contents))
return np.frombuffer(dbuffer, dtype=np.float32).reshape(shape)


def ctypes2docstring(num_args, arg_names, arg_types, arg_descs, remove_dup=True):
"""Convert ctypes returned doc string information into parameters docstring.
num_args : nn_uint
Number of arguments.
arg_names : ctypes.POINTER(ctypes.c_char_p)
Argument names.
arg_types : ctypes.POINTER(ctypes.c_char_p)
Argument type information.
arg_descs : ctypes.POINTER(ctypes.c_char_p)
Argument description information.
remove_dup : boolean, optional
Whether remove duplication or not.
Returns
-------
docstr : str
Python docstring of parameter sections.
"""
param_keys = set()
param_str = []
for i in range(num_args.value):
key = py_str(arg_names[i])
if key in param_keys and remove_dup:
continue
param_keys.add(key)
type_info = py_str(arg_types[i])
ret = '%s : %s' % (key, type_info)
if len(arg_descs[i]) != 0:
ret += '\n ' + py_str(arg_descs[i])
param_str.append(ret)
doc_str = ('Parameters\n' +
'----------\n' +
'%s\n')
doc_str = doc_str % ('\n'.join(param_str))
return doc_str
Loading

0 comments on commit 2a907ab

Please sign in to comment.