
[microTVM] tuning on micro targets with meta-schedule #13514

Merged: 7 commits, Jan 11, 2023
2 changes: 2 additions & 0 deletions include/tvm/meta_schedule/mutator.h
@@ -139,6 +139,8 @@ class Mutator : public runtime::ObjectRef {
  TVM_DLL static Map<Mutator, FloatImm, void> DefaultCUDATensorCore();
  /*! \brief Create default mutators for Hexagon */
  TVM_DLL static Map<Mutator, FloatImm, void> DefaultHexagon();
  /*! \brief Create default mutators for Micro */
  TVM_DLL static Map<Mutator, FloatImm, void> DefaultMicro();

  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Mutator, ObjectRef, MutatorNode);
};
2 changes: 2 additions & 0 deletions include/tvm/meta_schedule/postproc.h
@@ -165,6 +165,8 @@ class Postproc : public runtime::ObjectRef {
  TVM_DLL static Array<Postproc, void> DefaultCUDATensorCore();
  /*! \brief Create default postprocessors for Hexagon */
  TVM_DLL static Array<Postproc, void> DefaultHexagon();
  /*! \brief Create default postprocessors for Micro */
  TVM_DLL static Array<Postproc, void> DefaultMicro();

  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode);
};
2 changes: 2 additions & 0 deletions include/tvm/meta_schedule/schedule_rule.h
@@ -298,6 +298,8 @@ class ScheduleRule : public runtime::ObjectRef {
  TVM_DLL static Array<ScheduleRule, void> DefaultCUDATensorCore();
  /*! \brief Create default schedule rules for Hexagon */
  TVM_DLL static Array<ScheduleRule, void> DefaultHexagon();
  /*! \brief Create default schedule rules for Micro */
  TVM_DLL static Array<ScheduleRule, void> DefaultMicro();

  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode);
};
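These three DefaultMicro factories mirror the existing per-target defaults (CUDA, Hexagon). On the Python side such defaults are normally reached by leaving the space generator's rule sets at "default", which resolve per target kind; the sketch below is a hedged illustration using the upstream meta_schedule PostOrderApply API, and the dispatch to the DefaultMicro factories is an assumption about this PR's C++ wiring:

    from tvm import meta_schedule as ms

    # "default" resolves to target-specific rule sets (e.g. the DefaultMicro
    # factories declared above) once the space generator is bound to a tuning
    # context with a micro target.
    space_gen = ms.space_generator.PostOrderApply(
        sch_rules="default",
        postprocs="default",
        mutator_probs="default",
    )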
84 changes: 84 additions & 0 deletions python/tvm/contrib/micro/meta_schedule/local_builder_micro.py
@@ -0,0 +1,84 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Local builder for microTVM projects that compile on the local host"""

import os
import tempfile
from typing import Optional, Dict
from tvm.ir import IRModule
from tvm.runtime import NDArray
from tvm.target import Target
from tvm.meta_schedule.builder import LocalBuilder
from tvm.driver.build_module import OperatorModule
from tvm import micro
from tvm.contrib.tar import tar
from tvm.relay.backend import Runtime
from tvm.driver import build as tvm_build
from tvm.tir.transform import RemoveWeightLayoutRewriteBlock


def get_local_builder_micro():
    """Return a micro-compatible Builder for meta schedule."""

    def _micro_build(
        mod: IRModule, target: Target, _params: Optional[Dict[str, NDArray]]
    ) -> OperatorModule:
        """Build function for micro targets.

        Parameters
        ----------
        mod : IRModule
            The IRModule to be built.
        target : Target
            The target to build for.
        _params : Optional[Dict[str, NDArray]]
            The parameters to be used for the build. Must be None.

        Returns
        -------
        rt_mod : OperatorModule
            The built Module.
        """

        # Note: tvm_build assigns "global_symbol" to the name of the generated C
        # function. Changing it is necessary for micro targets, since the
        # generated projects already include a main function.
        prim_func = mod["main"].with_attr("global_symbol", "default_function")
        mod = IRModule({"main": prim_func})
        runtime = Runtime("crt", {"system-lib": True})
        mod = RemoveWeightLayoutRewriteBlock(skip_ndarray_rewrite=True)(mod)
        rt_mod = tvm_build(mod, target=target, runtime=runtime)
Contributor:
Just to confirm -- are the relay.build changes needed in this PR? If not, could we remove them until we figure out how to wrap a TIR function in a Relay function?

Contributor Author (@mkatanbaf, Jan 3, 2023):
No, they are not. We continue to use tvm_build in tuning. The changes in relay_integration.py are needed for compiling the relay program using the MetaSchedule tuning database.
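A hedged sketch of the Relay path mentioned above (compile_relay follows the upstream tvm.meta_schedule.relay_integration API; database, relay_mod, target, and params are placeholder names):

    from tvm import meta_schedule as ms

    # Placeholder sketch: build a Relay module, applying the best schedules
    # recorded in a MetaSchedule tuning database.
    lib = ms.relay_integration.compile_relay(
        database=database,
        mod=relay_mod,
        target=target,
        params=params,
    )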

        return rt_mod

    def _micro_export(mod: OperatorModule) -> str:
        """Export function for micro targets.

        Parameters
        ----------
        mod : OperatorModule
            The Module to be exported.

        Returns
        -------
        artifact_path : str
            The path to the exported Module.
        """
        artifact_path = os.path.join(tempfile.mkdtemp(), "tvm_tmp_mod." + tar.output_format)
        micro.export_model_library_format(mod, artifact_path)
        return artifact_path

    return LocalBuilder(f_build=_micro_build, f_export=_micro_export)
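A minimal usage sketch for the builder, not part of the diff: BuilderInput and Builder.build are the upstream meta_schedule builder API, while candidate_mod and the target string are assumptions.

    from tvm.meta_schedule.builder import BuilderInput
    from tvm.target import Target

    builder = get_local_builder_micro()
    # `candidate_mod` stands in for a scheduled TIR IRModule produced during tuning.
    (result,) = builder.build([BuilderInput(candidate_mod, Target("c -keys=cpu -mcpu=cortex-m4"))])
    assert result.error_msg is None
    print(result.artifact_path)  # the exported model-library-format archive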
233 changes: 233 additions & 0 deletions python/tvm/contrib/micro/meta_schedule/rpc_runner_micro.py
@@ -0,0 +1,233 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""RPC Runner Micro"""

from contextlib import contextmanager
from typing import Callable, List, Optional
from collections import namedtuple
import signal

from tvm import micro
from tvm import nd
from tvm.contrib.popen_pool import PopenPoolExecutor
from tvm.rpc.server import Server
from tvm.rpc.tracker import Tracker
from tvm.meta_schedule.logging import get_logger
from tvm.meta_schedule.utils import cpu_count, derived_object
from tvm.meta_schedule.runner.config import EvaluatorConfig, RPCConfig
from tvm.meta_schedule.runner import PyRunner, RunnerFuture, RunnerInput
from tvm.meta_schedule.runner.rpc_runner import RPCRunnerFuture
from tvm.meta_schedule.runner.utils import T_ARG_INFO_JSON_OBJ_LIST

logger = get_logger(__name__) # pylint: disable=invalid-name


@derived_object
class RPCRunnerMicro(PyRunner):
"""RPC based runner for tuning micro models."""

def __init__(
self,
platform: str = "crt",
project_options: Optional[dict] = None,
rpc_config: Optional[RPCConfig] = None,
evaluator_config: Optional[EvaluatorConfig] = None,
max_workers: Optional[int] = None,
initializer: Optional[Callable[[], None]] = None,
) -> None:
"""Constructor

Parameters
----------
platform: str
The platform used for project generation.
project_options: dict
The options for the generated micro project.
rpc_config: RPCConfig
The rpc configuration.
evaluator_config: EvaluatorConfig
The evaluator configuration.
max_workers: Optional[int] = None
The maximum number of connections. Defaults to number of logical CPU cores.
initializer: Optional[Callable[[], None]]
The initializer function.
"""
super().__init__()
self.platform = platform
if project_options is None:
project_options = {}
self.project_options = project_options
self.rpc_config = RPCConfig._normalized(rpc_config)
self.evaluator_config = EvaluatorConfig._normalized(evaluator_config)

if max_workers is None:
max_workers = cpu_count(logical=True)
logger.info("RPCRunner: max_workers = %d", max_workers)
self.pool = PopenPoolExecutor(
max_workers=max_workers,
timeout=rpc_config.session_timeout_sec,
initializer=initializer,
)

def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
results: List[RunnerFuture] = []

for runner_input in runner_inputs:
future = RPCRunnerFuture(
future=self.pool.submit(
_worker_func,
self.platform,
self.project_options or {},
self.rpc_config,
self.evaluator_config,
str(runner_input.artifact_path),
str(runner_input.device_type),
tuple(arg_info.as_json() for arg_info in runner_input.args_info),
),
timeout_sec=self.rpc_config.session_timeout_sec,
)
results.append(future) # type: ignore
return results
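For contrast with the convenience context manager defined below, a hedged sketch of constructing the runner directly against a tracker and RPC server you manage yourself (all RPCConfig values here are assumptions):

    config = RPCConfig(
        tracker_host="127.0.0.1",
        tracker_port=9190,             # hypothetical tracker port
        tracker_key="my-micro-board",  # hypothetical key served by your RPC server
        session_priority=0,
        session_timeout_sec=300,
    )
    runner = RPCRunnerMicro(platform="crt", project_options={}, rpc_config=config)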


def _worker_func(
    platform: str,
    project_options: dict,
    rpc_config: RPCConfig,
    evaluator_config: EvaluatorConfig,
    artifact_path: str,
    device_type: str,
    args_info: T_ARG_INFO_JSON_OBJ_LIST,
) -> List[float]:

    module_loader = micro.AutoTvmModuleLoader(
        template_project_dir=micro.get_microtvm_template_projects(platform),
        project_options=project_options,
    )

    remote_kw = {
        "device_key": rpc_config.tracker_key,
        "host": rpc_config.tracker_host,
        "port": rpc_config.tracker_port,
        "priority": 0,
        "timeout": 100,
    }
    build_result = namedtuple("BuildResult", ["filename"])(artifact_path)

    with module_loader(remote_kw, build_result) as (remote, mod):
        dev = remote.device(device_type, 0)
        f_prepare = ""
        if evaluator_config.enable_cpu_cache_flush:
            f_prepare = "cache_flush_cpu_non_first_arg"
        time_f = mod.time_evaluator(
            mod.entry_name,
            dev,
            number=evaluator_config.number,
            repeat=evaluator_config.repeat,
            min_repeat_ms=evaluator_config.min_repeat_ms,
            f_preproc=f_prepare,
        )

        # Each entry of args_info is a JSON triple ("TENSOR", dtype, shape);
        # allocate matching device buffers and fill them with random data.
        random_fill = remote.get_function("tvm.contrib.random.random_fill")
        args = [nd.empty(x[2], x[1], dev) for x in args_info]
        for arg in args:
            random_fill(arg)
        dev.sync()

        costs = time_f(*args).results
    return costs


@contextmanager
def get_rpc_runner_micro(
    platform,
    options,
    rpc_config: Optional[RPCConfig] = None,
    evaluator_config: Optional[EvaluatorConfig] = None,
    session_timeout_sec=300,
):
    """Spin up a local RPC tracker and server and yield an RPCRunnerMicro.

    Parameters
    ----------
    platform: str
        The platform used for project generation.
    options: dict
        The options for the generated micro project.
    rpc_config: RPCConfig
        The rpc configuration.
    evaluator_config: EvaluatorConfig
        The evaluator configuration.
    session_timeout_sec: int
        The session timeout. If the number of candidates sent to the runner is
        larger than the number of runner workers, increase the timeout.
    """
    if rpc_config is None:
        tracker_host = "127.0.0.1"
        tracker_port = 9000
        tracker_key = "$local$device$%d" % tracker_port
        rpc_config = RPCConfig(
            tracker_host=tracker_host,
            tracker_port=tracker_port,
            tracker_key=tracker_key,
            session_priority=0,
            session_timeout_sec=session_timeout_sec,
        )
    tracker_port_end = rpc_config.tracker_port + 1000

    if evaluator_config is None:
        evaluator_config = EvaluatorConfig(
            number=3,
            repeat=1,
            min_repeat_ms=100,
            enable_cpu_cache_flush=False,
        )

    tracker = Tracker(
        port=rpc_config.tracker_port,
        port_end=tracker_port_end,
        silent=True,
        reuse_addr=True,
        timeout=60,
    )
    server = Server(
        port=rpc_config.tracker_port,
        port_end=tracker_port_end,
        key=rpc_config.tracker_key,
        silent=True,
        tracker_addr=(rpc_config.tracker_host, rpc_config.tracker_port),
        reuse_addr=True,
        timeout=60,
    )

    def terminate():
        tracker.terminate()
        server.terminate()

    def handle_SIGINT(signum, frame):  # avoid shadowing the `signal` module
        terminate()
        raise KeyboardInterrupt("Received SIGINT")

    signal.signal(signal.SIGINT, handle_SIGINT)

    try:
        yield RPCRunnerMicro(
            platform=platform,
            project_options=options,
            rpc_config=rpc_config,
            evaluator_config=evaluator_config,
        )
    finally:
        terminate()
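Finally, a hedged end-to-end sketch wiring the micro builder and runner into a tuning run; ms.tune_tir and its builder/runner parameters follow the upstream meta_schedule API, while tir_mod, the target string, the options, and the trial count are assumptions.

    from tvm import meta_schedule as ms
    from tvm.target import Target
    from tvm.contrib.micro.meta_schedule.local_builder_micro import get_local_builder_micro

    target = Target("c -keys=cpu -mcpu=cortex-m4")  # hypothetical micro target
    options = {"verbose": False}                    # hypothetical project options

    builder = get_local_builder_micro()
    with get_rpc_runner_micro(platform="crt", options=options) as runner:
        database = ms.tune_tir(
            mod=tir_mod,  # a hypothetical TIR IRModule to tune
            target=target,
            work_dir="./tune_logs",
            max_trials_global=64,
            builder=builder,
            runner=runner,
        )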