Added support for tuning microTVM models using meta_schedule.
mkatanbaf committed Dec 12, 2022
1 parent 760b10a commit fc04678
Showing 17 changed files with 652 additions and 11 deletions.
2 changes: 2 additions & 0 deletions include/tvm/meta_schedule/mutator.h
@@ -139,6 +139,8 @@ class Mutator : public runtime::ObjectRef {
TVM_DLL static Map<Mutator, FloatImm, void> DefaultCUDATensorCore();
/*! \brief Create default mutators for Hexagon */
TVM_DLL static Map<Mutator, FloatImm, void> DefaultHexagon();
/*! \brief Create default mutators for Micro */
TVM_DLL static Map<Mutator, FloatImm, void> DefaultMicro();

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Mutator, ObjectRef, MutatorNode);
};
2 changes: 2 additions & 0 deletions include/tvm/meta_schedule/postproc.h
@@ -165,6 +165,8 @@ class Postproc : public runtime::ObjectRef {
TVM_DLL static Array<Postproc, void> DefaultCUDATensorCore();
/*! \brief Create default postprocessors for Hexagon */
TVM_DLL static Array<Postproc, void> DefaultHexagon();
/*! \brief Create default postprocessors for Micro */
TVM_DLL static Array<Postproc, void> DefaultMicro();

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode);
};
2 changes: 2 additions & 0 deletions include/tvm/meta_schedule/schedule_rule.h
@@ -298,6 +298,8 @@ class ScheduleRule : public runtime::ObjectRef {
TVM_DLL static Array<ScheduleRule, void> DefaultCUDATensorCore();
/*! \brief Create default schedule rules for Hexagon */
TVM_DLL static Array<ScheduleRule, void> DefaultHexagon();
/*! \brief Create default schedule rules for Micro */
TVM_DLL static Array<ScheduleRule, void> DefaultMicro();

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode);
};
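The three header additions above register per-target defaults for mutators, postprocessors, and schedule rules under a new Micro entry, mirroring the existing LLVM, CUDA, and Hexagon variants. A hedged sketch of how the new defaults could surface on the Python side, assuming this commit wires a "micro" key into the usual create dispatch as the other targets already have (the key name is an assumption, not shown in these hunks):

# Hypothetical: list the default schedule rules registered for micro targets.
from tvm.meta_schedule.schedule_rule import ScheduleRule

for rule in ScheduleRule.create("micro"):  # "micro" key assumed, by analogy with "llvm"/"cuda"/"hexagon"
    print(type(rule).__name__)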
84 changes: 84 additions & 0 deletions python/tvm/contrib/micro/meta_schedule/local_builder_micro.py
@@ -0,0 +1,84 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Local builder for microTVM projects that compile on the local host"""

import os
import tempfile
from typing import Optional, Dict
from tvm.ir import IRModule
from tvm.runtime import NDArray
from tvm.target import Target
from tvm.meta_schedule.builder import LocalBuilder
from tvm.driver.build_module import OperatorModule
from tvm import micro
from tvm.contrib.tar import tar
from tvm.relay.backend import Runtime
from tvm.driver import build as tvm_build
from tvm.tir.transform import RemoveWeightLayoutRewriteBlock


def get_micro_local_builder():
"""Return micro-compatible Builder for meta schedule."""

def _micro_build(
mod: IRModule, target: Target, _params: Optional[Dict[str, NDArray]]
) -> OperatorModule:
"""Build function for micro targets.
Parameters
----------
mod : IRModule
The IRModule to be built.
target : Target
The target to be built.
_params : Optional[Dict[str, NDArray]]
The parameters to be used for the build. Must be None.
Returns
-------
rt_mod : OperatorModule
The built Module.
"""

        # Note: tvm_build assigns "global_symbol" to the name of the generated C
        # function. Renaming it is necessary for micro targets, since the generated
        # projects already include a main function.
prim_func = mod["main"].with_attr("global_symbol", "default_function")
mod = IRModule({"main": prim_func})
runtime = Runtime("crt", {"system-lib": True})
mod = RemoveWeightLayoutRewriteBlock(skip_ndarray_rewrite=True)(mod)
rt_mod = tvm_build(mod, target=target, runtime=runtime)
return rt_mod

def _micro_export(mod: OperatorModule) -> str:
"""Export function for micro targets.
Parameters
----------
mod : OperatorModule
The Module to be exported.
Returns
-------
artifact_path : str
The path to the exported Module.
"""
artifact_path = os.path.join(tempfile.mkdtemp(), "tvm_tmp_mod." + tar.output_format)
micro.export_model_library_format(mod, artifact_path)
return artifact_path

return LocalBuilder(f_build=_micro_build, f_export=_micro_export)
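A minimal usage sketch for the builder above; the toy TIR PrimFunc and the Cortex-M-style C target string are hypothetical placeholders, not part of this commit:

import tvm
from tvm.script import tir as T
from tvm.target import Target
from tvm.meta_schedule.builder import BuilderInput


@T.prim_func
def add_one(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128,), "float32")
    B = T.match_buffer(b, (128,), "float32")
    for i in range(128):
        with T.block("add_one"):
            vi = T.axis.spatial(128, i)
            B[vi] = A[vi] + T.float32(1)


builder = get_micro_local_builder()
mod = tvm.IRModule({"main": add_one})
# The target string is a plausible bare-metal C target; adjust it per board.
(result,) = builder.build([BuilderInput(mod, Target("c -keys=arm_cpu -mcpu=cortex-m4"))])
print(result.artifact_path or result.error_msg)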
233 changes: 233 additions & 0 deletions python/tvm/contrib/micro/meta_schedule/rpc_runner_micro.py
@@ -0,0 +1,233 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""RPC Runner Micro"""

from contextlib import contextmanager
from typing import Callable, List, Optional
from collections import namedtuple
import signal

from tvm import micro
from tvm import nd
from tvm.contrib.popen_pool import PopenPoolExecutor
from tvm.rpc.server import Server
from tvm.rpc.tracker import Tracker
from tvm.meta_schedule.logging import get_logger
from tvm.meta_schedule.utils import cpu_count, derived_object
from tvm.meta_schedule.runner.config import EvaluatorConfig, RPCConfig
from tvm.meta_schedule.runner import PyRunner, RunnerFuture, RunnerInput
from tvm.meta_schedule.runner.rpc_runner import RPCRunnerFuture
from tvm.meta_schedule.runner.utils import T_ARG_INFO_JSON_OBJ_LIST

logger = get_logger(__name__) # pylint: disable=invalid-name


@derived_object
class RPCRunnerMicro(PyRunner):
"""RPC based runner for tuning micro models."""

def __init__(
self,
platform: str = "crt",
project_options: Optional[dict] = None,
rpc_config: Optional[RPCConfig] = None,
evaluator_config: Optional[EvaluatorConfig] = None,
max_workers: Optional[int] = None,
initializer: Optional[Callable[[], None]] = None,
) -> None:
"""Constructor
Parameters
----------
platform: str
The platform used for project generation.
project_options: dict
The options for the generated micro project.
rpc_config: RPCConfig
The rpc configuration.
evaluator_config: EvaluatorConfig
The evaluator configuration.
max_workers: Optional[int] = None
The maximum number of connections. Defaults to number of logical CPU cores.
initializer: Optional[Callable[[], None]]
The initializer function.
"""
super().__init__()
self.platform = platform
if project_options is None:
project_options = {}
self.project_options = project_options
self.rpc_config = RPCConfig._normalized(rpc_config)
self.evaluator_config = EvaluatorConfig._normalized(evaluator_config)

if max_workers is None:
max_workers = cpu_count(logical=True)
logger.info("RPCRunner: max_workers = %d", max_workers)
self.pool = PopenPoolExecutor(
max_workers=max_workers,
            timeout=self.rpc_config.session_timeout_sec,
initializer=initializer,
)

def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
results: List[RunnerFuture] = []

for runner_input in runner_inputs:
future = RPCRunnerFuture(
future=self.pool.submit(
_worker_func,
self.platform,
self.project_options or {},
self.rpc_config,
self.evaluator_config,
str(runner_input.artifact_path),
str(runner_input.device_type),
tuple(arg_info.as_json() for arg_info in runner_input.args_info),
),
timeout_sec=self.rpc_config.session_timeout_sec,
)
results.append(future) # type: ignore
return results


def _worker_func(
platform: str,
project_options: dict,
rpc_config: RPCConfig,
evaluator_config: EvaluatorConfig,
artifact_path: str,
device_type: str,
args_info: T_ARG_INFO_JSON_OBJ_LIST,
) -> List[float]:

module_loader = micro.AutoTvmModuleLoader(
template_project_dir=micro.get_microtvm_template_projects(platform),
project_options=project_options,
)

remote_kw = {
"device_key": rpc_config.tracker_key,
"host": rpc_config.tracker_host,
"port": rpc_config.tracker_port,
"priority": 0,
"timeout": 100,
}
build_result = namedtuple("BuildResult", ["filename"])(artifact_path)

with module_loader(remote_kw, build_result) as (remote, mod):
dev = remote.device(device_type, 0)
f_prepare = ""
if evaluator_config.enable_cpu_cache_flush:
f_prepare = "cache_flush_cpu_non_first_arg"
time_f = mod.time_evaluator(
mod.entry_name,
dev,
number=evaluator_config.number,
repeat=evaluator_config.repeat,
min_repeat_ms=evaluator_config.min_repeat_ms,
f_preproc=f_prepare,
)

random_fill = remote.get_function("tvm.contrib.random.random_fill")
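        # Each entry of args_info is a JSON list like ["TENSOR", dtype, shape];
        # allocate an empty device tensor for each and fill it with random data.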
args = [nd.empty(x[2], x[1], dev) for x in args_info]
for arg in args:
random_fill(arg)
dev.sync()

costs = time_f(*args).results
return costs


@contextmanager
def get_rpc_runner_micro(
    platform: str,
    options: dict,
    rpc_config: Optional[RPCConfig] = None,
    evaluator_config: Optional[EvaluatorConfig] = None,
    session_timeout_sec: int = 300,
):
"""Parameters
----------
platform: str
The platform used for project generation.
project_options: dict
The options for the generated micro project.
rpc_config: RPCConfig
The rpc configuration.
evaluator_config: EvaluatorConfig
The evaluator configuration.
session_timeout_sec: int
The session timeout. if the number of candidates sent to runner is larger
than the runner workers, increase the timeout.
"""
if rpc_config is None:
tracker_host = "127.0.0.1"
tracker_port = 9000
tracker_key = "$local$device$%d" % tracker_port
rpc_config = RPCConfig(
tracker_host=tracker_host,
tracker_port=tracker_port,
tracker_key=tracker_key,
session_priority=0,
session_timeout_sec=session_timeout_sec,
)
tracker_port_end = rpc_config.tracker_port + 1000

if evaluator_config is None:
evaluator_config = EvaluatorConfig(
number=3,
repeat=1,
min_repeat_ms=100,
enable_cpu_cache_flush=False,
)

tracker = Tracker(
port=rpc_config.tracker_port,
port_end=tracker_port_end,
silent=True,
reuse_addr=True,
timeout=60,
)
server = Server(
port=rpc_config.tracker_port,
port_end=tracker_port_end,
key=rpc_config.tracker_key,
silent=True,
tracker_addr=(rpc_config.tracker_host, rpc_config.tracker_port),
reuse_addr=True,
timeout=60,
)

def terminate():
tracker.terminate()
server.terminate()

    def handle_sigint(signum, frame):  # renamed to avoid shadowing the signal module
        terminate()
        raise KeyboardInterrupt("Received SIGINT")

    signal.signal(signal.SIGINT, handle_sigint)

try:
yield RPCRunnerMicro(
platform=platform,
project_options=options,
rpc_config=rpc_config,
evaluator_config=evaluator_config,
)
finally:
terminate()
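Putting the pieces together, a hedged end-to-end sketch that tunes a toy Relay model with the micro local builder and this runner; the ms.relay_integration.tune_relay entry point, the project options, and the target string are assumptions based on the meta_schedule API of this period, not part of this diff:

import numpy as np
import tvm
from tvm import meta_schedule as ms
from tvm import relay

# A toy one-convolution model standing in for a real microTVM workload.
data = relay.var("data", shape=(1, 3, 16, 16), dtype="float32")
weight = relay.var("weight", shape=(8, 3, 3, 3), dtype="float32")
relay_mod = tvm.IRModule.from_expr(
    relay.Function([data, weight], relay.nn.conv2d(data, weight, padding=(1, 1)))
)
params = {"weight": tvm.nd.array(np.random.rand(8, 3, 3, 3).astype("float32"))}

builder = get_micro_local_builder()
# "project_type" is a real CRT project option; real boards need more options.
with get_rpc_runner_micro(platform="crt", options={"project_type": "host_driven"}) as runner:
    database = ms.relay_integration.tune_relay(
        mod=relay_mod,
        params=params,
        target=tvm.target.Target("c -keys=arm_cpu -mcpu=cortex-m4"),
        work_dir="./ms_work_dir",
        max_trials_global=8,
        builder=builder,
        runner=runner,
        module_equality="ignore-ndarray",  # assumed available in this era
    )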