Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change target string to Target object in the TE compiler and interpreter #8835

Merged
merged 14 commits into from
Aug 31, 2021
9 changes: 4 additions & 5 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -669,11 +669,10 @@ class AOTExecutorCodegen : public ExprVisitor {
ret.lowered_funcs = lowered_module.per_target_module;
ret.external_mods = lowered_module.external_mods;

auto target_host_str = target_host_->str();
if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) {
ret.lowered_funcs[target_host_str]->Update(mod_run);
if (ret.lowered_funcs.find(target_host_) != ret.lowered_funcs.end()) {
ret.lowered_funcs[target_host_]->Update(mod_run);
} else {
ret.lowered_funcs.Set(target_host_str, mod_run);
ret.lowered_funcs.Set(target_host_, mod_run);
}

std::vector<String> input_var_names(input_vars_.size());
Expand Down Expand Up @@ -778,7 +777,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode {
return (*it).second.first;
}

Map<String, IRModule> get_irmodule() { return this->output_.lowered_funcs; }
Map<Target, IRModule> get_irmodule() { return this->output_.lowered_funcs; }

std::shared_ptr<AOTExecutorCodegen> codegen_;
LoweredOutput output_;
Expand Down
17 changes: 9 additions & 8 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ struct ExecutorCodegen {
return CallFunc<Array<tvm::runtime::Module>>("get_external_modules", nullptr);
}

Map<String, IRModule> GetIRModule() {
return CallFunc<Map<String, IRModule>>("get_irmodule", nullptr);
Map<Target, IRModule> GetIRModule() {
return CallFunc<Map<Target, IRModule>>("get_irmodule", nullptr);
}

runtime::Metadata GetMetadata() { return CallFunc<runtime::Metadata>("get_metadata"); }
Expand Down Expand Up @@ -491,8 +491,9 @@ class RelayBuildModule : public runtime::ModuleNode {
auto lowered_funcs = executor_codegen_->GetIRModule();

// No need to build for external functions.
if (lowered_funcs.find("ext_dev") != lowered_funcs.end()) {
lowered_funcs.Set("ext_dev", IRModule());
Target ext_dev("ext_dev");
if (lowered_funcs.find(ext_dev) != lowered_funcs.end()) {
lowered_funcs.Set(ext_dev, IRModule());
}

// Generate a placeholder function that attaches linked params as its arguments.
Expand All @@ -510,11 +511,11 @@ class RelayBuildModule : public runtime::ModuleNode {
DictAttrs attrs{dict};
auto prim = tir::PrimFunc(Array<tir::Var>(), tir::SeqStmt(Array<tir::Stmt>()), VoidType(),
Map<tir::Var, tir::Buffer>(), attrs);
if (lowered_funcs.find(target_host->str()) == lowered_funcs.end()) {
lowered_funcs.Set(target_host->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
if (lowered_funcs.find(target_host) == lowered_funcs.end()) {
lowered_funcs.Set(target_host, IRModule(Map<GlobalVar, BaseFunc>({})));
}
lowered_funcs[target_host->str()]->Add(
GlobalVar(::tvm::runtime::symbol::tvm_lookup_linked_param), prim);
lowered_funcs[target_host]->Add(GlobalVar(::tvm::runtime::symbol::tvm_lookup_linked_param),
prim);
}

// When there is no lowered_funcs due to reasons such as optimization.
Expand Down
25 changes: 14 additions & 11 deletions src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ namespace {
struct PairHash {
template <typename T1, typename T2>
std::size_t operator()(const std::pair<T1, T2>& k) const {
return std::hash<T1>()(k.first) ^ std::hash<T2>()(k.second);
return dmlc::HashCombine(std::hash<T1>()(k.first), std::hash<T2>()(k.second));
}
template <typename T2>
std::size_t operator()(const std::pair<Target, T2>& k) const {
return dmlc::HashCombine(ObjectHash()(k.first), std::hash<T2>()(k.second));
}
};

Expand Down Expand Up @@ -289,7 +293,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> {
public:
// TODO(mbs): Collapse mod and per_target_module once IRModule subsumes LoweredModule.
Interpreter(IRModule mod, Map<String, IRModule> per_target_module, Device device, Target target)
Interpreter(IRModule mod, Map<Target, IRModule> per_target_module, Device device, Target target)
: mod_(mod),
per_target_module_(per_target_module),
device_(device),
Expand Down Expand Up @@ -373,7 +377,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
*/
PackedFunc TIRToPackedFunc(const GlobalVar& tir_fn_var, const Array<GlobalVar>& all_tir_fn_vars,
Target target) {
std::pair<std::string, std::string> packed_func_key(target->str(), tir_fn_var->name_hint);
std::pair<Target, std::string> packed_func_key(target, tir_fn_var->name_hint);
auto packed_itr = compiled_packed_funcs_.find(packed_func_key);
if (packed_itr != compiled_packed_funcs_.end()) {
// Already compiled.
Expand All @@ -382,7 +386,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,

// Project out just the function(s) we need.
IRModule lowered_projected_mod;
auto mod_itr = per_target_module_.find(target->str());
auto mod_itr = per_target_module_.find(target);
ICHECK(mod_itr != per_target_module_.end())
<< "No target module for target '" << target->str() << "'";
const IRModule& target_module = (*mod_itr).second;
Expand All @@ -407,7 +411,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
PackedFunc packed_func = runtime_module.GetFunction(var->name_hint);
ICHECK(packed_func != nullptr) << "No packed function for global var '" << var->name_hint
<< "' in compiled module for target '" << target->str() << "'";
compiled_packed_funcs_.emplace(std::make_pair(target->str(), var->name_hint), packed_func);
compiled_packed_funcs_.emplace(std::make_pair(target, var->name_hint), packed_func);
}

// Return just what we need for this call.
Expand Down Expand Up @@ -874,11 +878,10 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
// Map from target key to lowered TIR functions derived from mod_.
// Note that primitives are implicitly executed on target_, while shape functions are implicitly
// executed on the default 'cpu' host. Thus this map has at most two entries.
Map<String, IRModule> per_target_module_;
Map<Target, IRModule> per_target_module_;
// Cached packed functions for the primitives and shape functions, keyed by target and
// global var name.
std::unordered_map<std::pair<std::string, std::string>, PackedFunc, PairHash>
compiled_packed_funcs_;
std::unordered_map<std::pair<Target, std::string>, PackedFunc, PairHash> compiled_packed_funcs_;
// Unique device on which primitives (but not shape functions) will be executed.
// (For simplicity we only run the interpreter on a single device.)
Device device_;
Expand All @@ -895,7 +898,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
* rewritten \p mod and target-specific modules containing bindings for all TIR primitive
* functions needed by the rewritten module.
*/
std::pair<IRModule, Map<String, IRModule>> Prepare(IRModule mod, Device device, Target target) {
std::pair<IRModule, Map<Target, IRModule>> Prepare(IRModule mod, Device device, Target target) {
// Run minimal transforms on module to establish invariants needed by interpreter.
transform::Sequential seq({transform::SimplifyInference(),
// FuseOps will mark wrapped calls to prim-ops with the 'Primitive'
Expand Down Expand Up @@ -1014,7 +1017,7 @@ TypedPackedFunc<ObjectRef(Array<Expr>)> EvalFunction(IRModule mod, Expr expr, De
// and can just eval it directly.
expr_to_eval = expr;
}
std::pair<IRModule, Map<String, IRModule>> main_and_lowered =
std::pair<IRModule, Map<Target, IRModule>> main_and_lowered =
Prepare(mod_with_expr, device, target);
std::shared_ptr<Interpreter> intrp = std::make_shared<Interpreter>(
/*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device,
Expand Down Expand Up @@ -1057,7 +1060,7 @@ ObjectRef Eval(Expr expr, Map<GlobalTypeVar, TypeData> type_definitions,
std::unordered_set<String> import_set, Device device, Target target) {
std::pair<IRModule, GlobalVar> mod_and_global =
IRModule::FromExprInContext(expr, /*global_funcs=*/{}, type_definitions, import_set);
std::pair<IRModule, Map<String, IRModule>> main_and_lowered =
std::pair<IRModule, Map<Target, IRModule>> main_and_lowered =
Prepare(mod_and_global.first, device, target);
Interpreter intrp(
/*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device,
Expand Down
22 changes: 11 additions & 11 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,30 +85,30 @@ class TECompilerImpl : public TECompilerNode {
return LowerShapeFuncInternal(key)->cached_func;
}

Map<String, IRModule> GetLoweredFunctions() {
Map<String, IRModule> lowered_functions;
Map<Target, IRModule> GetLoweredFunctions() {
Map<Target, IRModule> lowered_functions;
for (const auto& it : cache_) {
auto source_func = it.first;
auto lowered_func = it.second;
auto target = source_func->target;

if (!lowered_functions.count(target->str())) {
lowered_functions.Set(target->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
if (!lowered_functions.count(target)) {
lowered_functions.Set(target, IRModule(Map<GlobalVar, BaseFunc>({})));
}

lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
lowered_functions[target]->Update(lowered_func->cached_func->funcs);
}

for (const auto& it : shape_func_cache_) {
auto source_func = it.first;
auto lowered_func = it.second;
auto target = source_func->target;

if (!lowered_functions.count(target->str())) {
lowered_functions.Set(target->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
if (!lowered_functions.count(target)) {
lowered_functions.Set(target, IRModule(Map<GlobalVar, BaseFunc>({})));
}

lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
lowered_functions[target]->Update(lowered_func->cached_func->funcs);
}
return lowered_functions;
}
Expand Down Expand Up @@ -884,7 +884,7 @@ IRModule LoweredModuleToIRModule(LoweredModule mod) {

// Annotate the per-target functions with their target and add them to the unified module
for (const auto& kv : mod.per_target_module) {
const String target = kv.first;
const Target target = kv.first;
const IRModule target_module = kv.second;

// Right now, per-target functions are TIR functions, which don't have type definitions, so
Expand Down Expand Up @@ -926,15 +926,15 @@ LoweredModule IRModuleToLoweredModule(IRModule mod) {
main_mod->AddTypeDef(kv.first, kv.second);
}

Map<String, IRModule> per_target_modules;
Map<Target, IRModule> per_target_modules;
for (const auto& kv : mod->functions) {
const GlobalVar& var = kv.first;
const BaseFunc& func = kv.second;
if (func->IsInstance<relay::FunctionNode>()) {
main_mod->Add(var, func);
} else if (func->IsInstance<tir::PrimFuncNode>()) {
// Extract target
Optional<String> target = func->GetAttr<String>(tvm::attr::kTarget);
Optional<Target> target = func->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target) << "Target should be set at this point";

// Put the function in per_target_modules
Expand Down
4 changes: 2 additions & 2 deletions src/relay/backend/te_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class TECompilerNode : public Object {
virtual CachedFunc Lower(const CCacheKey& key, const String mod_name) = 0;

/* Return all functions which have been lowered by the compiler, keyed by target. */
virtual Map<String, IRModule> GetLoweredFunctions() = 0;
virtual Map<Target, IRModule> GetLoweredFunctions() = 0;

/*!
* \brief Just in time compile to get a PackedFunc.
Expand Down Expand Up @@ -144,7 +144,7 @@ struct LoweredModule {
/*! \brief The module which contains the Relay code. */
IRModule main_module;
/*! \brief The module which contains per target code. */
Map<String, IRModule> per_target_module;
Map<Target, IRModule> per_target_module;
/*! \brief The external runtime modules which must be combined with the lowered code. */
Array<tvm::runtime::Module> external_mods;
// TODO(@electriclilies): THis might need to become a map
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ int64_t CalculateRelayExprSizeBytes(const Type& expr_type);
*/
struct LoweredOutput {
std::string graph_json;
Map<String, IRModule> lowered_funcs;
Map<Target, IRModule> lowered_funcs;
Array<tvm::runtime::Module> external_mods;
Map<String, FunctionInfo> function_metadata;
std::unordered_map<std::string, std::pair<int, const tvm::runtime::NDArray>> params;
Expand Down