Skip to content

Commit

Permalink
Add NameSupply and GlobalVarSupply
Browse files Browse the repository at this point in the history
  • Loading branch information
gigiblender committed Jul 12, 2022
1 parent fc419df commit a2fc6db
Show file tree
Hide file tree
Showing 41 changed files with 567 additions and 343 deletions.
5 changes: 4 additions & 1 deletion include/tvm/driver/driver_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name,
TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args,
const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
GlobalVarSupply global_var_supply,
bool simple_mode = false);

/*!
Expand All @@ -121,6 +122,7 @@ TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args,
TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args,
const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
GlobalVarSupply global_var_supply,
bool simple_mode = false);

/*!
Expand All @@ -133,7 +135,8 @@ TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args,
* \return The result module.
*/
IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds);
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
GlobalVarSupply global_var_supply);
/*!
* \brief Build a device and host module for a specific target from an IRModule.
* \param funcs The functions to be built.
Expand Down
61 changes: 61 additions & 0 deletions include/tvm/ir/global_var_supply.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#ifndef TVM_RELAY_BACKEND_GLOBAL_VAR_SUPPLY_H
#define TVM_RELAY_BACKEND_GLOBAL_VAR_SUPPLY_H

#include <string>
#include <unordered_map>

#include "tvm/ir/expr.h"
#include "tvm/ir/name_supply.h"

namespace tvm {

class GlobalVarSupplyNode : public Object {
public:
GlobalVarSupplyNode() : GlobalVarSupplyNode(NameSupply("")) {}

explicit GlobalVarSupplyNode(NameSupply name_supply);

GlobalVar FreshGlobal(String name, bool add_prefix = true);

GlobalVar UniqueGlobalFor(const String& name, bool add_prefix = true);

void VisitAttrs(AttrVisitor* v) {
v->Visit("name_supply", &name_supply_);
}

NameSupply name_supply_;

static constexpr const char* _type_key = "GlobalVarSupply";
static constexpr const bool _type_has_method_sequal_reduce = false;
static constexpr const bool _type_has_method_shash_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarSupplyNode, Object);

private:
std::unordered_map<std::string, GlobalVar> name_to_var_map_;

friend class GlobalVarSupply;
};

class GlobalVarSupply : public ObjectRef {
public:
TVM_DLL explicit GlobalVarSupply(const NameSupply& name_supply = NameSupply::NameSupplyWithPrefix(""),
std::unordered_map<std::string, GlobalVar> name_to_var_map = {});

TVM_DLL static GlobalVarSupply GlobalVarSupplyFromNameSupply(const NameSupply& name_supply);

TVM_DLL static GlobalVarSupply EmptySupply();

explicit GlobalVarSupply(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! \return mutable pointers to the node. */
GlobalVarSupplyNode* operator->() const {
auto* ptr = get_mutable();
ICHECK(ptr != nullptr);
return static_cast<GlobalVarSupplyNode*>(ptr);
}

TVM_DEFINE_OBJECT_REF_COW_METHOD(GlobalVarSupplyNode);
};

}

#endif // TVM_RELAY_BACKEND_GLOBAL_VAR_SUPPLY_H
16 changes: 8 additions & 8 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/ir/adt.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/function.h>
#include <tvm/ir/global_var_supply.h>
#include <tvm/ir/type.h>
#include <tvm/parser/source_map.h>
#include <tvm/runtime/container/array.h>
Expand Down Expand Up @@ -64,6 +65,8 @@ class IRModuleNode : public Object {
/* \brief Additional attributes storing meta-data about the module. */
DictAttrs attrs;

GlobalVarSupply global_var_supply;

/*!
* \brief Get a module attribute.
*
Expand Down Expand Up @@ -125,6 +128,7 @@ class IRModuleNode : public Object {
v->Visit("global_type_var_map_", &global_type_var_map_);
v->Visit("source_map", &source_map);
v->Visit("attrs", &attrs);
v->Visit("global_var_supply", &global_var_supply);
}

TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const;
Expand Down Expand Up @@ -323,14 +327,6 @@ class IRModuleNode : public Object {
/*! \brief Helper function for registering a typedef's constructors */
void RegisterConstructors(const GlobalTypeVar& var, const TypeData& type);

/*!
* \brief Returns a version of \p name which is unique amongst all function definitions in module.
*
* \param name The original name.
* \return Updated name which is unique.
*/
String GetUniqueName(const String& name);

/*! \brief A map from string names to global variables that
* ensures global uniqueness.
*/
Expand Down Expand Up @@ -368,6 +364,8 @@ class IRModule : public ObjectRef {
* \param attrs The module attributes.
*/
TVM_DLL explicit IRModule(Map<GlobalVar, BaseFunc> functions,
GlobalVarSupply global_var_supply =
GlobalVarSupply::EmptySupply(),
Map<GlobalTypeVar, TypeData> type_definitions = {},
std::unordered_set<String> import_set = {}, parser::SourceMap map = {},
DictAttrs attrs = {});
Expand Down Expand Up @@ -413,6 +411,7 @@ class IRModule : public ObjectRef {
*/
static std::pair<IRModule, GlobalVar> FromExprInContext(
const RelayExpr& expr, const Map<GlobalVar, BaseFunc>& global_funcs = {},
GlobalVarSupply global_var_supply = GlobalVarSupply::EmptySupply(),
const Map<GlobalTypeVar, TypeData>& type_definitions = {},
std::unordered_set<String> import_set = {});

Expand All @@ -422,6 +421,7 @@ class IRModule : public ObjectRef {
*/
TVM_DLL static IRModule FromExpr(const RelayExpr& expr,
const Map<GlobalVar, BaseFunc>& global_funcs = {},
GlobalVarSupply global_var_supply = GlobalVarSupply::EmptySupply(),
const Map<GlobalTypeVar, TypeData>& type_definitions = {});

/*!
Expand Down
70 changes: 70 additions & 0 deletions include/tvm/ir/name_supply.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#ifndef TVM_RELAY_BACKEND_NAME_SUPPLY_H
#define TVM_RELAY_BACKEND_NAME_SUPPLY_H

#include "tvm/ir/expr.h"
#include <string>
#include <unordered_map>

namespace tvm {

class NameSupplyNode : public Object {
public:
NameSupplyNode() : NameSupplyNode("") {}

explicit NameSupplyNode(const String& prefix);

String FreshName(const String& name, bool add_prefix = true);

String ReserveName(const String& name, bool add_prefix = true);

bool ContainsName(const String& name, bool add_prefix = true);

void Clear();

void VisitAttrs(AttrVisitor* v) {
v->Visit("prefix", &prefix_);
}

// Prefix for all GlobalVar names. It can be empty.
std::string prefix_;

static constexpr const char* _type_key = "NameSupply";
static constexpr const bool _type_has_method_sequal_reduce = false;
static constexpr const bool _type_has_method_shash_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(NameSupplyNode, Object);

private:
String prefix_module_name(const String& name);

std::string GetUniqueName(std::string name);

// Key is function_name. Value is a counter.
std::unordered_map<std::string, int> name_map;

friend class NameSupply;
};

class NameSupply : public ObjectRef {
public:
TVM_DLL NameSupply();

TVM_DLL explicit NameSupply(const String& prefix, std::unordered_map<std::string, int> name_map = {});

TVM_DLL static NameSupply NameSupplyWithPrefix(const String& prefix = "");

TVM_DLL static NameSupply EmptySupply();

explicit NameSupply(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! \return mutable pointers to the node. */
NameSupplyNode* operator->() const {
auto* ptr = get_mutable();
ICHECK(ptr != nullptr);
return static_cast<NameSupplyNode*>(ptr);
}

TVM_DEFINE_OBJECT_REF_COW_METHOD(NameSupplyNode);
};

}

#endif // TVM_NAME_SUPPLY_H
1 change: 1 addition & 0 deletions include/tvm/relay/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ TypedPackedFunc<ObjectRef(Array<Expr>)> EvalFunction(IRModule mod, Expr expr, De
* @return The object representing the result.
*/
ObjectRef Eval(Expr expr, Map<GlobalTypeVar, TypeData> type_definitions,
GlobalVarSupply global_var_supply,
std::unordered_set<String> import_set, Device device, Target target,
Map<String, ObjectRef> attrs = {});

Expand Down
12 changes: 8 additions & 4 deletions python/tvm/ir/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""IRModule that holds the functions and type definitions."""
from tvm._ffi.base import string_types
import tvm._ffi
from tvm.ir.supply import GlobalVarSupply

from .base import Node
from . import expr as _expr
Expand All @@ -36,7 +37,7 @@ class IRModule(Node):
Map of global var to BaseFunc
"""

def __init__(self, functions=None, type_definitions=None):
def __init__(self, functions=None, type_definitions=None, globar_var_supply=None):
if functions is None:
functions = {}
elif isinstance(functions, dict):
Expand All @@ -59,7 +60,9 @@ def __init__(self, functions=None, type_definitions=None):
raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]")
mapped_type_defs[k] = v
type_definitions = mapped_type_defs
self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions)
if globar_var_supply is None:
globar_var_supply = GlobalVarSupply()
self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions, globar_var_supply)

def __setitem__(self, var, val):
"""Add a mapping to the module.
Expand Down Expand Up @@ -217,7 +220,7 @@ def get_type(self, name):
return tuple([ty_var] + list(ty_data.constructors))

@staticmethod
def from_expr(expr, functions=None, type_defs=None):
def from_expr(expr, functions=None, type_defs=None, global_var_supply=None):
"""Construct a module from a standalone expression.
Parameters
Expand All @@ -238,9 +241,10 @@ def from_expr(expr, functions=None, type_defs=None):
where expr is set as the entry point
(wrapped in a function if necessary)
"""
global_var_supply = global_var_supply if global_var_supply is not None else GlobalVarSupply()
funcs = functions if functions is not None else {}
defs = type_defs if type_defs is not None else {}
return _ffi_api.Module_FromExpr(expr, funcs, defs)
return _ffi_api.Module_FromExpr(expr, funcs, global_var_supply, defs)

def _import(self, file_to_import):
return _ffi_api.Module_Import(self, file_to_import)
Expand Down
30 changes: 30 additions & 0 deletions python/tvm/ir/supply.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import tvm
from tvm import Object
from . import _ffi_api


@tvm._ffi.register_object("NameSupply")
class NameSupply(Object):

def __init__(self, prefix=""):
self.__init_handle_by_constructor__(_ffi_api.NameSupply, prefix)

def fresh_name(self, name, add_prefix=True):
return _ffi_api.NameSupply_FreshName(self, name, add_prefix)

def reserve_name(self, name, add_prefix=True):
return _ffi_api.NameSupply_ReserveName(self, name, add_prefix)


@tvm._ffi.register_object("GlobalVarSupply")
class GlobalVarSupply(Object):

def __init__(self, name_supply=None):
name_supply = name_supply if name_supply is not None else NameSupply("")
self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply, name_supply)

def fresh_global(self, name, add_prefix=True):
return _ffi_api.GlobalVarSupply_FreshGlobal(self, name, add_prefix)

def unique_global_for(self, name, add_prefix=True):
return _ffi_api.GlobalVarSupply_UniqueGlobalFor(self, name, add_prefix)
2 changes: 1 addition & 1 deletion src/auto_scheduler/feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1371,7 +1371,7 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i
auto pass_ctx = tvm::transform::PassContext::Current();

auto mod = ScheduleToModule(sch, Array<ObjectRef>{tensors.begin(), tensors.end()}, name,
std::unordered_map<te::Tensor, te::Buffer>());
std::unordered_map<te::Tensor, te::Buffer>(), GlobalVarSupply::EmptySupply());

bool disable_vectorize =
pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
Expand Down
Loading

0 comments on commit a2fc6db

Please sign in to comment.