Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR][IR] Move to runtime::String #5276

Merged
merged 4 commits into from
Apr 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,12 @@ class PrimExpr : public BaseExpr {
* \param value The value to be constructed.
*/
TVM_DLL PrimExpr(float value); // NOLINT(*)

/*!
* \brief construct from string.
* \param str The value to be constructed.
* \brief construct from runtime String.
* \param value The value to be constructed.
*/
TVM_DLL PrimExpr(std::string str); // NOLINT(*)
TVM_DLL PrimExpr(runtime::String value); // NOLINT(*)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we simply remove this constructor and explicily construct StringImm in the places that we need them

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I consider that as well, but it seems there are too many places using PrimExpr. Need sometime to see which one may take string.

Copy link
Member Author

@zhiics zhiics Apr 9, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Considering the number of files changed so far, I'd like to keep this for now. It seems a bit more work. I can send a separate PR to handle it.


/*! \return the data type of this expression. */
DataType dtype() const {
Expand Down
11 changes: 6 additions & 5 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
#define TVM_IR_TRANSFORM_H_

#include <tvm/support/with.h>
#include <tvm/runtime/container.h>
#include <tvm/node/container.h>
#include <tvm/ir/error.h>
#include <tvm/ir/module.h>
Expand Down Expand Up @@ -95,9 +96,9 @@ class PassContextNode : public Object {
int fallback_device{static_cast<int>(kDLCPU)};

/*! \brief The list of required passes. */
Array<PrimExpr> required_pass;
Array<runtime::String> required_pass;
/*! \brief The list of disabled passes. */
Array<PrimExpr> disabled_pass;
Array<runtime::String> disabled_pass;

TraceFunc trace_func;

Expand Down Expand Up @@ -197,7 +198,7 @@ class PassInfoNode : public Object {
std::string name;

/*! \brief The passes that are required to perform the current pass. */
Array<PrimExpr> required;
Array<runtime::String> required;

PassInfoNode() = default;

Expand Down Expand Up @@ -226,7 +227,7 @@ class PassInfo : public ObjectRef {
*/
TVM_DLL PassInfo(int opt_level,
std::string name,
Array<PrimExpr> required);
Array<runtime::String> required);

TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode);
};
Expand Down Expand Up @@ -346,7 +347,7 @@ Pass CreateModulePass(
const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const Array<PrimExpr>& required);
const Array<runtime::String>& required);

} // namespace transform
} // namespace tvm
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/node/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@

namespace tvm {

using runtime::String;
using runtime::StringObj;
using runtime::Object;
using runtime::ObjectPtr;
using runtime::ObjectRef;
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/node/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#define TVM_NODE_NODE_H_

#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
#include <tvm/node/reflection.h>
Expand Down Expand Up @@ -62,6 +63,7 @@ using runtime::make_object;
using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::String;

} // namespace tvm
#endif // TVM_NODE_NODE_H_
5 changes: 3 additions & 2 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#ifndef TVM_RELAY_TRANSFORM_H_
#define TVM_RELAY_TRANSFORM_H_

#include <tvm/runtime/container.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/ir/transform.h>
#include <tvm/relay/expr.h>
Expand Down Expand Up @@ -59,7 +60,7 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
Function(Function, IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::PrimExpr>& required);
const tvm::Array<runtime::String>& required);

/*! \brief Remove expressions which does not effect the program result.
*
Expand Down Expand Up @@ -355,7 +356,7 @@ TVM_DLL Pass Inline();
*
* \return The pass.
*/
TVM_DLL Pass RemoveUnusedFunctions(Array<tvm::PrimExpr> entry_functions);
TVM_DLL Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions);

} // namespace transform

Expand Down
10 changes: 9 additions & 1 deletion include/tvm/runtime/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,15 @@ class String : public ObjectRef {
* \note If user passes const reference, it will trigger copy. If it's rvalue,
* it will be moved into other.
*/
explicit String(std::string other);
String(std::string other); // NOLINT(*)

/*!
* \brief Construct a new String object
*
* \param other a char array.
*/
String(const char* other) // NOLINT(*)
: String(std::string(other)) {}

/*!
* \brief Change the value the reference object points to.
Expand Down
6 changes: 3 additions & 3 deletions include/tvm/target/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ class TargetNode : public Object {
/*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
int thread_warp_size = 1;
/*! \brief Keys for this target */
Array<PrimExpr> keys_array;
Array<runtime::String> keys_array;
/*! \brief Options for this target */
Array<PrimExpr> options_array;
Array<runtime::String> options_array;
/*! \brief Collection of imported libs */
Array<PrimExpr> libs_array;
Array<runtime::String> libs_array;

/*! \return the full device string to pass to codegen::Build */
TVM_DLL const std::string& str() const;
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -326,15 +326,15 @@ class StmtExprMutator :
* won't do further recursion.
* \param postorder The function called after recursive mutation.
* The recursive mutation result is passed to postorder for further mutation.
* \param only_enable List of StringImm.
* \param only_enable List of runtime::String.
* If it is empty, all IRNode will call preorder/postorder
* If it is not empty, preorder/postorder will only be called
* when the IRNode's type key is in the list.
*/
TVM_DLL Stmt IRTransform(Stmt node,
const runtime::PackedFunc& preorder,
const runtime::PackedFunc& postorder,
const Array<PrimExpr>& only_enable = {});
const Array<runtime::String>& only_enable = {});

/*!
* \brief recursively visit the ir in post DFS order node, apply fvisit
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::PrimExpr>& required);
const tvm::Array<runtime::String>& required);

/*!
* \brief Transform the high-level PrimFunc to a low-level version
Expand Down Expand Up @@ -100,7 +100,7 @@ TVM_DLL Pass MakePackedAPI(int num_unpacked_args);
*
* \return The pass.
*/
TVM_DLL Pass RemapThreadAxis(Map<PrimExpr, IterVar> axis_map);
TVM_DLL Pass RemapThreadAxis(Map<runtime::String, IterVar> axis_map);


/*!
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/autotvm/task/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import numpy as np

from tvm import target as _target
from tvm import runtime
from tvm.ir import container
from tvm.tir import expr
from tvm.te import tensor, placeholder
Expand Down Expand Up @@ -55,6 +56,8 @@ def _encode(x):
return x
if isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
return x.value
if isinstance(x, runtime.container.String):
return str(x)
if x is None:
return None
raise RuntimeError('Do not support type "%s" in argument. Consider to use'
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/relay/backend/graph_runtime_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ def codegen(self, func):
lowered_func = self._get_irmodule()
param_names = self._list_params_name()
params = {}
for name in param_names:
key = name.value
for key in param_names:
arr = self._get_param_by_name(key)
param = empty(arr.shape, dtype=arr.dtype, ctx=arr.ctx)
arr.copyto(param)
Expand Down
63 changes: 54 additions & 9 deletions python/tvm/runtime/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
# under the License.
"""Runtime container structures."""
import tvm._ffi

from tvm._ffi.base import string_types
from tvm.runtime import Object, ObjectTypes
from tvm.runtime import _ffi_api

def getitem_helper(obj, elem_getter, length, idx):
"""Helper function to implement a pythonic getitem function.
Expand Down Expand Up @@ -75,18 +76,19 @@ def __init__(self, tag, fields):
for f in fields:
assert isinstance(f, ObjectTypes), "Expect object or " \
"tvm NDArray type, but received : {0}".format(type(f))
self.__init_handle_by_constructor__(_ADT, tag, *fields)
self.__init_handle_by_constructor__(_ffi_api.ADT, tag,
*fields)

@property
def tag(self):
return _GetADTTag(self)
return _ffi_api.GetADTTag(self)

def __getitem__(self, idx):
return getitem_helper(
self, _GetADTFields, len(self), idx)
self, _ffi_api.GetADTFields, len(self), idx)

def __len__(self):
return _GetADTSize(self)
return _ffi_api.GetADTSize(self)


def tuple_object(fields=None):
Expand All @@ -106,7 +108,7 @@ def tuple_object(fields=None):
for f in fields:
assert isinstance(f, ObjectTypes), "Expect object or tvm " \
"NDArray type, but received : {0}".format(type(f))
return _Tuple(*fields)
return _ffi_api.Tuple(*fields)


@tvm._ffi.register_object("runtime.String")
Expand All @@ -115,7 +117,7 @@ class String(Object):

Parameters
----------
string : Str
string : str
The string used to construct a runtime String object

Returns
Expand All @@ -124,7 +126,50 @@ class String(Object):
The created object.
"""
def __init__(self, string):
self.__init_handle_by_constructor__(_String, string)
self.__init_handle_by_constructor__(_ffi_api.String, string)

def __str__(self):
return _ffi_api.GetStdString(self)

def __len__(self):
return _ffi_api.GetStringSize(self)

def __hash__(self):
return _ffi_api.StringHash(self)

def __eq__(self, other):
if isinstance(other, string_types):
return self.__str__() == other

if not isinstance(other, String):
return False

return _ffi_api.CompareString(self, other) == 0

def __ne__(self, other):
return not self.__eq__(other)

def __gt__(self, other):
return _ffi_api.CompareString(self, other) > 0

def __lt__(self, other):
return _ffi_api.CompareString(self, other) < 0

def __getitem__(self, key):
return self.__str__()[key]

def startswith(self, string):
"""Check if the runtime string starts with a given string

Parameters
----------
string : str
The provided string

tvm._ffi._init_api("tvm.runtime.container")
Returns
-------
ret : boolean
Return true if the runtime string starts with the given string,
otherwise, false.
"""
return self.__str__().startswith(string)
4 changes: 2 additions & 2 deletions python/tvm/runtime/object_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from numbers import Number, Integral
from tvm._ffi.base import string_types

from . import _ffi_node_api
from . import _ffi_node_api, _ffi_api
from .object import ObjectBase, _set_class_object_generic
from .ndarray import NDArrayBase
from .packed_func import PackedFuncBase, convert_to_tvm_func
Expand Down Expand Up @@ -56,7 +56,7 @@ def convert_to_object(value):
if isinstance(value, Number):
return const(value)
if isinstance(value, string_types):
return _ffi_node_api.String(value)
return _ffi_api.String(value)
if isinstance(value, (list, tuple)):
value = [convert_to_object(x) for x in value]
return _ffi_node_api.Array(*value)
Expand Down
10 changes: 5 additions & 5 deletions python/tvm/target/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,26 +48,26 @@ def __new__(cls):
@property
def keys(self):
if not self._keys:
self._keys = [k.value for k in self.keys_array]
self._keys = [str(k) for k in self.keys_array]
return self._keys

@property
def options(self):
if not self._options:
self._options = [o.value for o in self.options_array]
self._options = [str(o) for o in self.options_array]
return self._options

@property
def libs(self):
if not self._libs:
self._libs = [l.value for l in self.libs_array]
self._libs = [str(l) for l in self.libs_array]
return self._libs

@property
def model(self):
for opt in self.options_array:
if opt.value.startswith('-model='):
return opt.value[7:]
if opt.startswith('-model='):
return opt[7:]
return 'unknown'

@property
Expand Down
8 changes: 4 additions & 4 deletions src/autotvm/touch_extractor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,9 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > >
for (auto var : vars) {
Array<Array<PrimExpr> > feature_row;
ItervarFeature &fea = touch_analyzer.itervar_map[var];
feature_row.push_back(Array<PrimExpr>{std::string("_itervar_"), var});
feature_row.push_back(Array<PrimExpr>{tvm::tir::StringImmNode::make("_itervar_"), var});

Array<PrimExpr> attr{std::string("_attr_"),
Array<PrimExpr> attr{tvm::tir::StringImmNode::make("_attr_"),
FloatImm(DataType::Float(32), trans(fea.length)),
IntImm(DataType::Int(32), fea.nest_level),
FloatImm(DataType::Float(32), trans(fea.topdown_product)),
Expand All @@ -267,7 +267,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > >
feature_row.push_back(attr);

// arithmetic
feature_row.push_back(Array<PrimExpr>{std::string("_arith_"),
feature_row.push_back(Array<PrimExpr>{tvm::tir::StringImmNode::make("_arith_"),
FloatImm(DataType::Float(32), trans(fea.add_ct)),
FloatImm(DataType::Float(32), trans(fea.mul_ct)),
FloatImm(DataType::Float(32), trans(fea.div_ct)),
Expand All @@ -282,7 +282,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > >
for (auto k : bufs) {
TouchPattern &v = fea.touch_feature[k];
feature_row.push_back(
Array<PrimExpr>{k,
Array<PrimExpr>{tvm::tir::StringImmNode::make(k),
FloatImm(DataType::Float(32), trans(v.stride)),
FloatImm(DataType::Float(32), trans(v.mod)),
FloatImm(DataType::Float(32), trans(v.count)),
Expand Down
Loading