[Target] Add support for target object with host field compatible with previous api #7534

Merged on Mar 31, 2021 · 74 commits · diff shown from 61 commits

Commits
6401b6f
Fix legacy code on target host
zxybazh Feb 25, 2021
0167a5f
Modify legacy code for target host change
zxybazh Feb 25, 2021
2a3c502
Add tests and fix merge issue
zxybazh Feb 25, 2021
511ce56
Add condition for same host
zxybazh Feb 25, 2021
69601a7
Modify all files for new target host api compatibility
zxybazh Feb 26, 2021
23187d8
Add newline
zxybazh Feb 26, 2021
85b27db
Change import format
zxybazh Feb 26, 2021
7e4eb0a
Optimize test file
zxybazh Feb 26, 2021
59457f6
Add match error info for unit tests
zxybazh Feb 26, 2021
b7e4c71
Fix for heterogeneous targets
zxybazh Mar 2, 2021
f5ccc50
Fix format for dict iteration
zxybazh Mar 2, 2021
11c77ba
Fix target host type error
zxybazh Mar 2, 2021
ca95bfd
Merge branch 'main' of https://github.com/zxybazh/tvm into target
zxybazh Mar 2, 2021
7543422
Skip one testcase for tvm infinite loop bug
zxybazh Mar 3, 2021
fbd597a
Fixed bug for target map compatibility
zxybazh Mar 3, 2021
4d11b7b
Fix another TargetsMap issue
zxybazh Mar 3, 2021
5a0f06b
Fix typo and infinite loop error
zxybazh Mar 3, 2021
0e01e13
Temporary fix for handle issue
zxybazh Mar 3, 2021
7db8327
Fix vm target
zxybazh Mar 4, 2021
f214410
Add condition support for str case
zxybazh Mar 4, 2021
38c4ec0
Add GetHost function and fix previous bugs
zxybazh Mar 4, 2021
8bacc8d
Fix measure_record.cc
zxybazh Mar 4, 2021
36153dd
Fix search_task.cc
zxybazh Mar 4, 2021
df1f6a1
Fix compiler.cc, memory_alloc.cc
zxybazh Mar 5, 2021
4539cff
Fix driver_api.cc
zxybazh Mar 5, 2021
b328525
Fix format
zxybazh Mar 5, 2021
ba427ec
Fix bugs and GetHost function usage
zxybazh Mar 5, 2021
915e3d3
Fix clang format
zxybazh Mar 5, 2021
1a9dcb5
Fix bug
zxybazh Mar 6, 2021
efacf81
Merged main branch, resolve conflicts
zxybazh Mar 6, 2021
606ec71
Modify python tests
zxybazh Mar 7, 2021
71e01d0
Change python unit tests to new target api
zxybazh Mar 7, 2021
95539d9
Fix test_runtime_heterogeneous.py
zxybazh Mar 8, 2021
858d901
Modify tutorials & remove extra print
zxybazh Mar 8, 2021
d99b560
Update more tests to new api
zxybazh Mar 8, 2021
62ec2d3
Refine the tutorial target usage
zxybazh Mar 8, 2021
6916758
change argument name for Target constructor function
zxybazh Mar 8, 2021
a762d7d
Fix target export function
zxybazh Mar 9, 2021
b01f6cc
Fix and validate all tutorial usage
zxybazh Mar 9, 2021
b480bee
Remove unused argument
zxybazh Mar 9, 2021
c17a18e
Fix format
zxybazh Mar 9, 2021
a64efd6
Fix bug in driver/build_module.py for heterogeneous target
zxybazh Mar 9, 2021
fa982a9
Fix bug in driver/build_module.py for heterogeneous target more
zxybazh Mar 9, 2021
33c4057
Fix target host type error
zxybazh Mar 10, 2021
88d2379
Merge branch 'main' of https://github.com/apache/tvm into target
zxybazh Mar 10, 2021
75d0f44
Fix cudnn target host bug
zxybazh Mar 10, 2021
47bcc4c
Fix according to reviews, add helper function in python
zxybazh Mar 13, 2021
5d8201e
Refactor code as helper function
zxybazh Mar 16, 2021
c9e1c9b
Expand helper function
zxybazh Mar 16, 2021
ec664ee
Fix bug add and update python helper function
zxybazh Mar 16, 2021
983108c
Update target hosts
zxybazh Mar 16, 2021
ddfdeb2
Fix format & refresh function
zxybazh Mar 16, 2021
cb206ec
Fix unit test bug
zxybazh Mar 16, 2021
ae4ca68
Fix bug in refreshing host
zxybazh Mar 16, 2021
26a8647
Fix bug
zxybazh Mar 16, 2021
83f290b
Add SetHost function
zxybazh Mar 16, 2021
47b072c
Update export function
zxybazh Mar 16, 2021
bef6fbb
Fix format
zxybazh Mar 17, 2021
6771f2d
Fix export bug in target
zxybazh Mar 17, 2021
4442fba
Fix bug on host referencing
zxybazh Mar 17, 2021
542c927
Additional tests
zxybazh Mar 17, 2021
8a537b4
Address review issues
zxybazh Mar 18, 2021
6f76c1d
Fix format target.py
zxybazh Mar 18, 2021
f46626f
Fix issues and format
zxybazh Mar 30, 2021
244cc40
Add some 3rd party dependencies
zxybazh Mar 30, 2021
fdfb93a
Merge main branch
zxybazh Mar 30, 2021
7f509bd
Merge branch 'main' into target
zxybazh Mar 30, 2021
3804269
Fix target.h format
zxybazh Mar 30, 2021
dd3787c
Remove redundant import
zxybazh Mar 30, 2021
6e114ca
Fix function name
zxybazh Mar 30, 2021
adec87f
Add parameter name
zxybazh Mar 30, 2021
34f1dac
Merge branch 'main' into target
zxybazh Mar 31, 2021
b71bd1a
Fix new code bug
zxybazh Mar 31, 2021
3a8080e
Fix bug in lowering
zxybazh Mar 31, 2021
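
At a high level, this PR folds the legacy target_host parameter into a host field on the Target object itself, while keeping old call sites working. Below is a minimal sketch of the two styles; it assumes an LLVM-enabled TVM build, and the host= argument name follows the constructor change in commit 6916758:

import tvm

# Legacy style: the host target is passed around separately.
cuda = tvm.target.Target("cuda")
host = tvm.target.Target("llvm")

# New style from this PR: the host lives inside the target object.
target = tvm.target.Target("cuda", host="llvm")
print(target.host.kind.name)  # prints "llvm"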
13 changes: 13 additions & 0 deletions include/tvm/target/target.h
@@ -24,6 +24,8 @@
#ifndef TVM_TARGET_TARGET_H_
#define TVM_TARGET_TARGET_H_

#include <tvm/ir/expr.h>
#include <tvm/ir/module.h>
#include <tvm/node/node.h>
#include <tvm/support/with.h>
#include <tvm/target/target_kind.h>
@@ -35,6 +37,7 @@
namespace tvm {

class TargetInternal;
class Target;

/*!
* \brief Compilation target.
@@ -60,6 +63,10 @@ class TargetNode : public Object {
TVM_DLL const std::string& str() const;
/*! \return Export target to JSON-like configuration */
TVM_DLL Map<String, ObjectRef> Export() const;
/*! \return The Optional<Target> typed target host of the TargetNode */
TVM_DLL Optional<Target> GetHost() const;
/*! \return Set target host of the TargetNode */
TVM_DLL void SetHost(Target);

void VisitAttrs(AttrVisitor* v) {
v->Visit("kind", &kind);
@@ -168,5 +175,11 @@ class Target : public ObjectRef {
TVM_DLL void ExitWithScope();
};

using TargetsMap = Map<Integer, Target>;

TVM_DLL void RefreshHost(Target*, Target*);
TVM_DLL void RefreshHost(TargetsMap*, Target*);
TVM_DLL void RefreshHost(Map<Target, IRModule>*, Target*);

Review comment (Member):
  1. Functions in the header file need clear documentation. More specifically, we need to document very clearly that those functions are not encouraged for common use :-)
  2. Please list the names of the arguments.
  3. Also, please consider a name that better indicates these functions are dedicated to legacy behavior. What about Target::SplitIntoLegacyTargetPair?
  4. Consider moving those functions into static functions of the Target class.

Reply (Member Author):

I found that I cannot move this part of the helper functions into the Target class, because it requires a Map<Target, ObjectRef> type that does not compile correctly there. Therefore, I will keep the functions outside the class for now.


} // namespace tvm
#endif // TVM_TARGET_TARGET_H_
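
For reference, a rough Python sketch of what the RefreshHost helpers above (surfaced to Python as refresh_host in the diffs below) appear to do, inferred from their call sites in this PR. The implementation is illustrative only, and the C++ overloads additionally cover TargetsMap and Map<Target, IRModule> inputs:

from tvm.target import Target

def refresh_host(target, target_host):
    # Illustrative sketch: normalize strings to Target objects, fold an
    # explicit target_host into target.host, and return the merged pair
    # so legacy callers can keep unpacking (target, target_host).
    if isinstance(target, str):
        target = Target(target)
    if isinstance(target_host, str):
        target_host = Target(target_host)
    if target is not None and target_host is not None:
        target = Target(target, host=target_host)
    if target is None:
        return None, target_host
    return target, target.host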
8 changes: 4 additions & 4 deletions python/tvm/auto_scheduler/measure.py
@@ -44,6 +44,7 @@
from tvm.ir import transform
from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
from tvm.contrib import tar, ndk
from tvm.target.target import refresh_host

from . import _ffi_api
from .loop_state import StateObject
@@ -221,10 +222,10 @@ def recover_measure_input(inp, rebuild_state=False):
from .search_task import SearchTask # lazily import to avoid recursive dependency

task = inp.task
task.target, task.target_host = refresh_host(task.target, task.target_host)
new_task = SearchTask(
workload_key=task.workload_key,
target=task.target,
target_host=task.target_host,
hardware_params=task.hardware_params,
layout_rewrite_option=task.layout_rewrite_option,
task_inputs=list(task.task_input_names),
@@ -621,10 +622,9 @@ def _timed_func(inp_serialized, build_func, verbose):
filename = os.path.join(dirname, "tmp_func." + build_func.output_format)

try:
task.target, task.target_host = refresh_host(task.target, task.target_host)
with transform.PassContext():
func = build_module.build(
sch, args, target=task.target, target_host=task.target_host
)
func = build_module.build(sch, args, target=task.target)
func.export_library(filename, build_func)
# pylint: disable=broad-except
except Exception:
8 changes: 3 additions & 5 deletions python/tvm/auto_scheduler/relay_integration.py
@@ -29,6 +29,7 @@
from tvm import autotvm, transform
from tvm.ir.transform import PassContext
from tvm.runtime import convert_to_object
from tvm.target.target import refresh_host
from tvm.te.tensor import ComputeOp, PlaceholderOp, Tensor
from tvm.tir import Reduce
from tvm.tir import expr as _expr
@@ -108,10 +109,7 @@ def extract_tasks(
"""
# pylint: disable=import-outside-toplevel

if isinstance(target, str):
target = tvm.target.Target(target)
if isinstance(target_host, str):
target_host = tvm.target.Target(target_host)
target, target_host = refresh_host(target, target_host)

# Run the compiler to collect all TOPI calls during compilation.
env = TracingEnvironment(
@@ -127,12 +125,12 @@
# create search tasks
tasks = []
weights = []
target, target_host = refresh_host(target, target_host)
for wkl_key, weight in env.wkl_key_to_weight.items():
tasks.append(
SearchTask(
workload_key=wkl_key,
target=target,
target_host=target_host,
hardware_params=hardware_params,
# When auto scheduler is used in end to end network, try to apply layout rewrite
# to improve the overall performance
12 changes: 6 additions & 6 deletions python/tvm/auto_scheduler/search_task.py
@@ -27,7 +27,7 @@
from tvm.runtime import Object, ndarray

from tvm.driver.build_module import build
from tvm.target import Target
from tvm.target.target import refresh_host
from .measure import LocalBuilder, LocalRunner
from .measure_record import load_best_record
from .workload_registry import make_workload_key
@@ -393,10 +393,8 @@ def __init__(
compute_dag = ComputeDAG(workload_key)

assert target is not None, "Must specify a target."
if isinstance(target, str):
target = Target(target)
if isinstance(target_host, str):
target_host = Target(target_host)

target, target_host = refresh_host(target, target_host)

if layout_rewrite_option is None:
layout_rewrite_option = LayoutRewriteOption.get_target_default(target)
@@ -506,6 +504,7 @@ def print_best(self, log_file, print_mode="schedule"):
raise ValueError("Invalid print_mode: %s" % print_mode)

def __getstate__(self):
self.target, self.target_host = refresh_host(self.target, self.target_host)
return {
"compute_dag": self.compute_dag,
"workload_key": self.workload_key,
@@ -530,12 +529,13 @@ def __setstate__(self, state):
if workload[0] not in WORKLOAD_FUNC_REGISTRY:
register_workload_tensors(state["workload_key"], state["compute_dag"].tensors)

state["target"], state["target_host"] = refresh_host(state["target"], state["target_host"])
self.__init_handle_by_constructor__(
_ffi_api.SearchTask,
state["compute_dag"],
state["workload_key"],
state["target"],
state["target_host"],
state["target"].host,
state["hardware_params"],
state["layout_rewrite_option"],
state["task_input_names"],
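
The __getstate__/__setstate__ changes above work because the host now serializes as part of the target itself (TVM objects pickle through their JSON form). A minimal sketch of the round trip:

import pickle
import tvm

target = tvm.target.Target("cuda", host="llvm")
restored = pickle.loads(pickle.dumps(target))
# The host field survives together with the target.
assert str(restored.host) == str(target.host)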
6 changes: 3 additions & 3 deletions python/tvm/autotvm/graph_tuner/base_graph_tuner.py
@@ -28,6 +28,7 @@
from tvm.autotvm.task import get_config
from tvm.autotvm.record import encode, load_from_file
from tvm.autotvm.measure import MeasureResult, MeasureInput
from tvm.target.target import refresh_host

from ...target import Target
from .utils import (
@@ -525,9 +526,8 @@ def _callback(_, inputs, results):
continue

records = []
task = autotvm.task.create(
"layout_transform", args=args, target=self._target, target_host=target_host
)
self._target, target_host = refresh_host(self._target, target_host)
task = autotvm.task.create("layout_transform", args=args, target=self._target)
tuner = autotvm.tuner.GridSearchTuner(task)
tuner.tune(n_trial=1, measure_option=measure_option, callbacks=[_log_to_list(records)])
if not isinstance(records[0][1].costs[0], float):
2 changes: 1 addition & 1 deletion python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
@@ -63,7 +63,7 @@ def expr2graph(expr, target_ops, node_dict, node_list):
for node_entry in node_list:
if node_entry["op"] in target_ops:
task_name, args = env.task_collection[task_pos]
task = autotvm.task.create(task_name, args, target="llvm", target_host=None)
task = autotvm.task.create(task_name, args, target="llvm")
node_entry["workloads"] = [task.workload]
node_entry["topi_op"] = [task_name]
task_pos += 1
4 changes: 4 additions & 0 deletions python/tvm/autotvm/measure/measure_methods.py
@@ -40,6 +40,7 @@
from tvm.error import TVMError
from tvm.driver import build
from tvm.contrib import nvcc, ndk, tar
from tvm.target.target import refresh_host

from ..utils import get_const_tuple
from ..env import AutotvmGlobalScope
@@ -418,6 +419,9 @@ def set_task(self, task):
def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_option=None):
"""Common part for building a configuration"""
target, task, config = measure_input

target, task.target_host = refresh_host(target, task.target_host)

with target:
s, args = task.instantiate(config)

9 changes: 7 additions & 2 deletions python/tvm/autotvm/task/relay_integration.py
@@ -25,6 +25,7 @@

import tvm
from tvm.autotvm.task.dispatcher import DispatchContext, FallbackContext
from tvm.target.target import refresh_host
from .task import create
from .topi_integration import TaskExtractEnv

@@ -89,7 +90,8 @@ def extract_from_program(mod, params, target, target_host=None, ops=None):
task: Array of autotvm.task.Task
collected tasks
"""
return extract_from_multiple_program([mod], [params], target, target_host, ops)
target, target_host = refresh_host(target, target_host)
return extract_from_multiple_program([mod], [params], target, ops=ops)


def extract_from_multiple_program(mods, params, target, target_host=None, ops=None):
@@ -122,6 +124,9 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=None):

env = TaskExtractEnv.get()

# merge target and target host
target, target_host = refresh_host(target, target_host)

# run compiler to collect all TOPI calls during compilation
env.reset(ops)
with env:
@@ -152,7 +157,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=None):
tasks = []
for task_name, args in env.get_tasks():
try:
tsk = create(task_name, args, target=target, target_host=target_host)
tsk = create(task_name, args, target=target)
tasks.append(tsk)
except topi.InvalidShapeError:
logger.warning("Invalid shape during AutoTVM task creation")
11 changes: 7 additions & 4 deletions python/tvm/autotvm/task/task.py
@@ -27,6 +27,7 @@
from tvm.ir import container
from tvm.target import Target
from tvm.te import placeholder, tensor
from tvm.target.target import refresh_host
from tvm.tir import expr


@@ -175,14 +176,15 @@ def __getstate__(self):
# and restore the function by name when unpickling it.
import cloudpickle # pylint: disable=import-outside-toplevel

self.target, self.target_host = refresh_host(self.target, self.target_host)
return {
"name": self.name,
"args": self.args,
"kwargs": self.kwargs,
"config_space": self.config_space,
"flop": self.flop,
"target": self.target,
"target_host": self.target_host,
"target_host": self.target.host,
"func": cloudpickle.dumps(self.func),
}

@@ -195,8 +197,7 @@ def __setstate__(self, state):
self.config_space = state["config_space"]
self.func = cloudpickle.loads(state["func"])
self.flop = state["flop"]
self.target = state["target"]
self.target_host = state["target_host"]
self.target, self.target_host = refresh_host(state["target"], state["target_host"])

def __repr__(self):
return "Task(func_name=%s, args=%s, kwargs=%s, workload=%s)" % (
@@ -448,6 +449,8 @@ def create(task_name, args, target, target_host=None):
if isinstance(target, str):
target = Target(target)

target, target_host = refresh_host(target, target_host)

# init config space
ret.config_space = ConfigSpace()

@@ -459,7 +462,7 @@

ret.flop = ret.config_space.flop or compute_flop(sch)
ret.target = target
ret.target_host = target_host
ret.target_host = target.host

return ret

15 changes: 12 additions & 3 deletions python/tvm/contrib/peak.py
@@ -20,6 +20,7 @@
import logging
import tvm
from tvm import te
from tvm.target.target import refresh_host
from . import utils
from .. import rpc

@@ -106,8 +107,10 @@ def measure_bandwidth_sum(
s[y].bind(yi, te.thread_axis("threadIdx.x"))
s[y].unroll(k)

target, target_host = refresh_host(target, target_host)

try:
func = tvm.build(s, [x, y], target, target_host=target_host)
func = tvm.build(s, [x, y], target)

x = tvm.nd.empty((n,), dtype=dtype, ctx=ctx)
y = tvm.nd.empty((n // m,), dtype=dtype, ctx=ctx)
@@ -153,6 +156,8 @@ def measure_bandwidth_all_types(
"""
max_threads = target.max_num_threads

target, target_host = refresh_host(target, target_host)

result = []
for base_type in ["float"]:
for bits in [32]:
@@ -229,6 +234,8 @@ def measure_compute_mad(

max_threads = target.max_num_threads

target, target_host = refresh_host(target, target_host)

base_type = str(base_type) + str(bits)
dtype = base_type if lanes == 1 else base_type + "x" + str(lanes)

@@ -272,7 +279,7 @@ def mad_func(x, y):
s = te.create_schedule(y.op)

try:
func = tvm.build(s, [y], target, target_host=target_host)
func = tvm.build(s, [y], target)
func = _convert_to_remote(func, remote)
time_f = func.time_evaluator(func.entry_name, ctx, number=n_times)
y = tvm.nd.empty((n,), dtype=dtype, ctx=ctx)
@@ -313,6 +320,8 @@ def measure_compute_all_types(
result: list
a list of (type_name, GFLOPS/GIOPS) pairs
"""
target, target_host = refresh_host(target, target_host)

result = []
for base_type in ["float", "int"]:
for bits in [16, 32, 64]:
@@ -357,7 +366,7 @@ def measure_peak_all(target, target_host, host, port):
port: int
"""

target = tvm.target.Target(target)
target, target_host = refresh_host(target, target_host)
remote = rpc.connect(host, port)
n_times = 20

10 changes: 7 additions & 3 deletions python/tvm/driver/build_module.py
@@ -30,6 +30,7 @@
from tvm.te import tensor
from tvm.te import schedule
from tvm.target import Target
from tvm.target.target import refresh_host


def get_binds(args, compact=False, binds=None):
@@ -231,8 +232,7 @@ def _build_for_device(input_mod, target, target_host):
mdev : tvm.module
A module that contains device code.
"""
target = Target(target)
target_host = Target(target_host)
target, target_host = refresh_host(target, target_host)
device_type = ndarray.context(target.kind.name, 0).device_type

mod_mixed = input_mod
@@ -399,8 +399,10 @@ def build(inputs, args=None, target=None, target_host=None, name="default_functi
if not isinstance(mod, tvm.IRModule):
raise ValueError("inputs must be Schedule, IRModule," "or dict of str to IRModule.")

target_input_mod, target_host = refresh_host(target_input_mod, target_host)

if not target_host:
for tar, _ in target_input_mod.items():
for tar, mod in target_input_mod.items():
tar = Target(tar)
device_type = ndarray.context(tar.kind.name, 0).device_type
if device_type == ndarray.cpu(0).device_type:
@@ -409,6 +411,8 @@
if not target_host:
target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm"

target_input_mod, target_host = refresh_host(target_input_mod, target_host)

mod_host_all = tvm.IRModule({})

device_modules = []
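
Putting it together, tvm.build can now be driven by a single host-carrying target instead of a separate target_host= argument. A minimal end-to-end sketch, assuming an LLVM-enabled TVM build:

import tvm
from tvm import te

# Toy compute: B[i] = A[i] + 1.
n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)

# No target_host= argument: the host rides along on the target.
mod = tvm.build(s, [A, B], target=tvm.target.Target("llvm", host="llvm"))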