Skip to content

Commit

Permalink
[DOCS][FRONTEND] Modify from_mxnet to also return params, update docs (
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and sergei-mironov committed Aug 8, 2018
1 parent a3ca64b commit d2bdc89
Show file tree
Hide file tree
Showing 28 changed files with 224 additions and 62 deletions.
19 changes: 19 additions & 0 deletions nnvm/docs/api/python/compiler.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
nnvm.compiler
-------------

.. automodule:: nnvm.compiler

.. autofunction:: nnvm.compiler.build

.. autofunction:: nnvm.compiler.build_config

.. autofunction:: nnvm.compiler.optimize

.. automodule:: nnvm.compiler.graph_util
:members:

.. automodule:: nnvm.compiler.graph_attr
:members:

.. automodule:: nnvm.compiler.compile_engine
:members:
7 changes: 7 additions & 0 deletions nnvm/docs/api/python/frontend.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
nnvm.frontend
-------------

.. automodule:: nnvm.frontend


.. autofunction:: nnvm.frontend.from_mxnet
8 changes: 8 additions & 0 deletions nnvm/docs/api/python/graph.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
nnvm.graph
----------
.. automodule:: nnvm.graph

.. autofunction:: nnvm.graph.create

.. autoclass:: nnvm.graph.Graph
:members:
16 changes: 16 additions & 0 deletions nnvm/docs/api/python/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
Python API
==========

This document contains the python API to NNVM compiler toolchain.
For user


.. toctree::
:maxdepth: 2

compiler
frontend
runtime
symbol
graph
top
8 changes: 8 additions & 0 deletions nnvm/docs/api/python/runtime.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
nnvm.runtime
------------
.. automodule:: nnvm.runtime

.. autofunction:: nnvm.runtime.create

.. autoclass:: nnvm.runtime.Module
:members:
7 changes: 7 additions & 0 deletions nnvm/docs/api/python/symbol.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
nnvm.symbol
-----------
.. automodule:: nnvm.symbol

.. autoclass:: nnvm.symbol.Symbol

.. autofunction:: nnvm.symbol.Group
13 changes: 13 additions & 0 deletions nnvm/docs/api/python/top.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
nnvm.top
--------
.. automodule:: nnvm.top

.. autofunction:: register_compute

.. autofunction:: register_schedule

.. autofunction:: register_pattern


.. autoclass:: nnvm.top.AttrDict
:members:
1 change: 1 addition & 0 deletions nnvm/docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ Contents
self
top
tutorials/index
api/python/index
dev/index
2 changes: 1 addition & 1 deletion nnvm/examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ NNVM Examples
=============
This folder contains example snippets of running NNVM Compilation.

- See also [Tutorials](tutorials) for tutorials with detailed explainations.
- See also [Tutorials](../tutorials) for tutorials with detailed explainations.
9 changes: 5 additions & 4 deletions nnvm/python/nnvm/compiler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
"""Namespace for NNVM-TVM compiler toolchain"""
"""NNVM compiler toolchain.
User only need to use :any:`build` and :any:`build_config` to do the compilation.
The other APIs are for more advanced interaction with the compiler toolchain.
"""
from __future__ import absolute_import

import tvm
Expand All @@ -10,9 +14,6 @@
from .. import symbol as _symbol
from .. import graph as _graph

from .registry import OpPattern
from .registry import register_compute, register_schedule, register_pattern

from .. import top as _top


Expand Down
13 changes: 11 additions & 2 deletions nnvm/python/nnvm/compiler/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ def _update_shape_dtype(shape, dtype, params):
def optimize(graph, shape, dtype="float32"):
"""Perform target and parameter invariant graph optimization.
This is an advanced function that usually do not need to be called.
Call build instead.
Parameters
----------
graph : Graph
Expand All @@ -126,7 +129,11 @@ def optimize(graph, shape, dtype="float32"):
def build(graph, target, shape, dtype="float32", params=None):
"""Build graph into runtime library.
This is the final step of graph compilation.
The build function will optimize the graph and do the compilation.
When params is provided, the compiler might split the graph to
pre-compute certain values, so the final execution graph can
be different from the original one.
Parameters
----------
Expand Down Expand Up @@ -255,8 +262,10 @@ def precompute_prune(graph, params):
graph._set_json_attr("param_name_list", list(params.keys()), "list_str")
graph = graph.apply("PrecomputePrune")
pre_graph = graph_attr._move_out_graph(graph, "precompute_graph")
if not pre_graph.symbol.list_output_names():
if pre_graph is None:
return graph, params
out_names = pre_graph.json_attr("output_names")
if not pre_graph.symbol.list_output_names():
return graph, params
out_arrs = _run_graph(pre_graph, params)
return graph, dict(zip(out_names, out_arrs))
10 changes: 8 additions & 2 deletions nnvm/python/nnvm/compiler/compile_engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# pylint: disable=invalid-name
"""Compiler engine interface to internal engine"""
"""Compiler engine interface to internal engine
You can get the engine singleton at ``nnvm.compiler.engine``
"""
import tvm

_list_cache_items = tvm.get_global_func("nnvm.compiler.ListCacheItems")
Expand Down Expand Up @@ -30,7 +33,10 @@ class GraphFunc(tvm.node.NodeBase):


class Engine(object):
"""Global singleton compilation engine."""
"""Global singleton compilation engine.
You can get the singleton at ``nnvm.compiler.engine``
"""
def items(self):
"""List the available cache key value pairs.
Expand Down
6 changes: 6 additions & 0 deletions nnvm/python/nnvm/compiler/graph_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ def infer_shape(graph, **shape):
graph : Graph
The graph to perform shape inference from
shape : dict of str to tuple
The specific input shape.
Returns
-------
in_shape : list of tuple
Expand All @@ -38,6 +41,9 @@ def infer_dtype(graph, **dtype):
graph : Graph
The graph to perform type inference from
dtype : dict of str to dtype
The specific input data type.
Returns
-------
in_dtype : list of tuple
Expand Down
2 changes: 1 addition & 1 deletion nnvm/python/nnvm/frontend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Frontend package."""
"""NNVM frontends."""
from __future__ import absolute_import
from .mxnet import from_mxnet
26 changes: 22 additions & 4 deletions nnvm/python/nnvm/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""MXNet symbol frontend."""
from __future__ import absolute_import as _abs
import json
import tvm
from .. import symbol as _sym

__all__ = ['from_mxnet']
Expand Down Expand Up @@ -288,17 +289,34 @@ def _from_mxnet_impl(symbol, graph):
return node


def from_mxnet(symbol):
"""Convert from mxnet.Symbol to compatible nnvm.Symbol
def from_mxnet(symbol, arg_params=None, aux_params=None):
"""Convert from MXNet's model into compatible NNVM format.
Parameters
----------
symbol : mxnet.Symbol
MXNet symbol
arg_params : dict of str to mx.NDArray
The argument parameters in mxnet
aux_params : dict of str to mx.NDArray
The auxiliary parameters in mxnet
Returns
-------
nnvm.Symbol
net: nnvm.Symbol
Compatible nnvm symbol
params : dict of str to tvm.NDArray
The parameter dict to be used by nnvm
"""
return _from_mxnet_impl(symbol, {})
sym = _from_mxnet_impl(symbol, {})
params = {}
arg_params = arg_params if arg_params else {}
aux_params = aux_params if aux_params else {}
for k, v in arg_params.items():
params[k] = tvm.nd.array(v.asnumpy())
for k, v in aux_params.items():
params[k] = tvm.nd.array(v.asnumpy())
return sym, params
5 changes: 4 additions & 1 deletion nnvm/python/nnvm/graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines
"""Symbolic configuration API."""
"""NNVM Graph IR API.
This is a developer API that is used to manipulate and transform graphs.
"""
from __future__ import absolute_import as _abs

import ctypes
Expand Down
6 changes: 5 additions & 1 deletion nnvm/python/nnvm/symbol.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# pylint: disable=invalid-name, unused-import
"""Symbolic configuration API."""
"""Symbolic graph construction API.
This namespace contains most of the registered operators.
For detailed list of operators, checkout ``Core Tensor Operators``
"""
from __future__ import absolute_import as _abs
import sys as _sys
import os as _os
Expand Down
8 changes: 7 additions & 1 deletion nnvm/python/nnvm/top/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
"""Declaration about Tensor operators"""
"""Tensor operator property registry
Provide information to lower and schedule tensor operators.
"""
from .attr_dict import AttrDict
from . import tensor
from . import nn
from . import transform
from . import reduction

from .registry import OpPattern
from .registry import register_compute, register_schedule, register_pattern
1 change: 1 addition & 0 deletions nnvm/python/nnvm/top/attr_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class AttrDict(object):
"""Attribute dictionary in nnvm.
Used by python registration of compute and schedule function.
AttrDict is passed as the first argument to schedule and compute function.
"""
_tvm_tcode = 18

Expand Down
23 changes: 20 additions & 3 deletions nnvm/python/nnvm/top/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import topi
from topi.util import get_const_int
from .tensor import _fschedule_broadcast
from ..compiler import registry as reg
from ..compiler import OpPattern
from . import registry as reg
from .registry import OpPattern

# relu
@reg.register_compute("relu")
Expand Down Expand Up @@ -55,9 +55,26 @@ def schedule_softmax(_, outs, target):
# naive schedule
return tvm.create_schedule([x.op for x in outs])

# Mark softmax as extern as we do not fuse it in call cases
reg.register_pattern("softmax", OpPattern.OPAQUE)

# log softmax
@reg.register_compute("log_softmax")
def compute_log_softmax(attrs, inputs, _):
"""Compute definition of softmax"""
axis = attrs.get_int("axis")
assert axis == -1, "only support axis == -1 for now"
return topi.nn.log_softmax(inputs[0])

@reg.register_schedule("log_softmax")
def schedule_log_softmax(_, outs, target):
"""Schedule definition of softmax"""
if target == "cuda":
return topi.cuda.schedule_softmax(outs)
# naive schedule
return tvm.create_schedule([x.op for x in outs])

# Mark softmax as extern as we do not fuse it in call cases
reg.register_pattern("log_softmax", OpPattern.OPAQUE)

# dense
@reg.register_compute("dense")
Expand Down
4 changes: 2 additions & 2 deletions nnvm/python/nnvm/top/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import tvm
import topi
import topi.cuda
from ..compiler import registry as reg
from ..compiler import OpPattern
from . import registry as reg
from .registry import OpPattern

def _schedule_reduce(_, outs, target):
"""Generic schedule for reduce"""
Expand Down
File renamed without changes.
4 changes: 2 additions & 2 deletions nnvm/python/nnvm/top/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import tvm
import topi
import topi.cuda
from ..compiler import registry as reg
from ..compiler import OpPattern
from . import registry as reg
from .registry import OpPattern

def _schedule_injective(_, outs, target):
"""Generic schedule for binary bcast"""
Expand Down
4 changes: 2 additions & 2 deletions nnvm/python/nnvm/top/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import tvm
import topi
from .tensor import _fschedule_broadcast, _fschedule_injective
from ..compiler import registry as reg
from ..compiler import OpPattern
from . import registry as reg
from .registry import OpPattern

# Need add reshape
@reg.register_compute("expand_dims")
Expand Down
9 changes: 7 additions & 2 deletions nnvm/src/compiler/packed_func_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,13 @@ TVM_REGISTER_GLOBAL("nnvm.graph._move_module")
TVM_REGISTER_GLOBAL("nnvm.graph._move_graph")
.set_body([](TVMArgs args, TVMRetValue *rv) {
const nnvm::Graph& g = args[0].AsExtension<Graph>();
*rv = const_cast<nnvm::Graph*>(&g)->
MoveCopyAttr<nnvm::Graph>(args[1]);
std::string key = args[1];
if (g.attrs.count(key)) {
*rv = const_cast<nnvm::Graph*>(&g)->
MoveCopyAttr<nnvm::Graph>(key);
} else {
*rv = nullptr;
}
});
} // namespace compiler
} // namespace nnvm
Loading

0 comments on commit d2bdc89

Please sign in to comment.