diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h
index 4095d6ca03972..498b2797ada58 100644
--- a/include/tvm/meta_schedule/mutator.h
+++ b/include/tvm/meta_schedule/mutator.h
@@ -139,6 +139,8 @@ class Mutator : public runtime::ObjectRef {
   TVM_DLL static Map<Mutator, FloatImm> DefaultCUDATensorCore();
   /*! \brief Create default mutators for Hexagon */
   TVM_DLL static Map<Mutator, FloatImm> DefaultHexagon();
+  /*! \brief Create default mutators for Micro */
+  TVM_DLL static Map<Mutator, FloatImm> DefaultMicro();
 
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Mutator, ObjectRef, MutatorNode);
 };
diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h
index 76f8d71ad65bc..4c7d66177cb8e 100644
--- a/include/tvm/meta_schedule/postproc.h
+++ b/include/tvm/meta_schedule/postproc.h
@@ -165,6 +165,8 @@ class Postproc : public runtime::ObjectRef {
   TVM_DLL static Array<Postproc> DefaultCUDATensorCore();
   /*! \brief Create default postprocessors for Hexagon */
   TVM_DLL static Array<Postproc> DefaultHexagon();
+  /*! \brief Create default postprocessors for Micro */
+  TVM_DLL static Array<Postproc> DefaultMicro();
 
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode);
 };
diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index 879dd076a8b51..16202e18bf95d 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -298,6 +298,8 @@ class ScheduleRule : public runtime::ObjectRef {
   TVM_DLL static Array<ScheduleRule> DefaultCUDATensorCore();
   /*! \brief Create default schedule rules for Hexagon */
   TVM_DLL static Array<ScheduleRule> DefaultHexagon();
+  /*! \brief Create default schedule rules for Micro */
+  TVM_DLL static Array<ScheduleRule> DefaultMicro();
 
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode);
 };
diff --git a/python/tvm/contrib/micro/meta_schedule/local_builder_micro.py b/python/tvm/contrib/micro/meta_schedule/local_builder_micro.py
new file mode 100644
index 0000000000000..df1e1fb750645
--- /dev/null
+++ b/python/tvm/contrib/micro/meta_schedule/local_builder_micro.py
@@ -0,0 +1,84 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Local builder for microTVM projects that compile on the local host""" + +import os +import tempfile +from typing import Optional, Dict +from tvm.ir import IRModule +from tvm.runtime import NDArray +from tvm.target import Target +from tvm.meta_schedule.builder import LocalBuilder +from tvm.driver.build_module import OperatorModule +from tvm import micro +from tvm.contrib.tar import tar +from tvm.relay.backend import Runtime +from tvm.driver import build as tvm_build +from tvm.tir.transform import RemoveWeightLayoutRewriteBlock + + +def get_micro_local_builder(): + """Return micro-compatible Builder for meta schedule.""" + + def _micro_build( + mod: IRModule, target: Target, _params: Optional[Dict[str, NDArray]] + ) -> OperatorModule: + """Build function for micro targets. + + Parameters + ---------- + mod : IRModule + The IRModule to be built. + target : Target + The target to be built. + _params : Optional[Dict[str, NDArray]] + The parameters to be used for the build. Must be None. + + Returns + ------- + rt_mod : OperatorModule + The built Module. + """ + + # Note: tvm_build assigns "global_symbol" to the name of generated C function + # changing it is necessary for micro targets, + # since the generated projects already include a main function. + prim_func = mod["main"].with_attr("global_symbol", "default_function") + mod = IRModule({"main": prim_func}) + runtime = Runtime("crt", {"system-lib": True}) + mod = RemoveWeightLayoutRewriteBlock(skip_ndarray_rewrite=True)(mod) + rt_mod = tvm_build(mod, target=target, runtime=runtime) + return rt_mod + + def _micro_export(mod: OperatorModule) -> str: + """Export function for micro targets. + + Parameters + ---------- + mod : OperatorModule + The Module to be exported. + + Returns + ------- + artifact_path : str + The path to the exported Module. + """ + artifact_path = os.path.join(tempfile.mkdtemp(), "tvm_tmp_mod." + tar.output_format) + micro.export_model_library_format(mod, artifact_path) + return artifact_path + + return LocalBuilder(f_build=_micro_build, f_export=_micro_export) diff --git a/python/tvm/contrib/micro/meta_schedule/rpc_runner_micro.py b/python/tvm/contrib/micro/meta_schedule/rpc_runner_micro.py new file mode 100644 index 0000000000000..e4c08351841d5 --- /dev/null +++ b/python/tvm/contrib/micro/meta_schedule/rpc_runner_micro.py @@ -0,0 +1,233 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""RPC Runner Micro""" + +from contextlib import contextmanager +from typing import Callable, List, Optional +from collections import namedtuple +import signal + +from tvm import micro +from tvm import nd +from tvm.contrib.popen_pool import PopenPoolExecutor +from tvm.rpc.server import Server +from tvm.rpc.tracker import Tracker +from tvm.meta_schedule.logging import get_logger +from tvm.meta_schedule.utils import cpu_count, derived_object +from tvm.meta_schedule.runner.config import EvaluatorConfig, RPCConfig +from tvm.meta_schedule.runner import PyRunner, RunnerFuture, RunnerInput +from tvm.meta_schedule.runner.rpc_runner import RPCRunnerFuture +from tvm.meta_schedule.runner.utils import T_ARG_INFO_JSON_OBJ_LIST + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +@derived_object +class RPCRunnerMicro(PyRunner): + """RPC based runner for tuning micro models.""" + + def __init__( + self, + platform: str = "crt", + project_options: Optional[dict] = None, + rpc_config: Optional[RPCConfig] = None, + evaluator_config: Optional[EvaluatorConfig] = None, + max_workers: Optional[int] = None, + initializer: Optional[Callable[[], None]] = None, + ) -> None: + """Constructor + + Parameters + ---------- + platform: str + The platform used for project generation. + project_options: dict + The options for the generated micro project. + rpc_config: RPCConfig + The rpc configuration. + evaluator_config: EvaluatorConfig + The evaluator configuration. + max_workers: Optional[int] = None + The maximum number of connections. Defaults to number of logical CPU cores. + initializer: Optional[Callable[[], None]] + The initializer function. + """ + super().__init__() + self.platform = platform + if project_options is None: + project_options = {} + self.project_options = project_options + self.rpc_config = RPCConfig._normalized(rpc_config) + self.evaluator_config = EvaluatorConfig._normalized(evaluator_config) + + if max_workers is None: + max_workers = cpu_count(logical=True) + logger.info("RPCRunner: max_workers = %d", max_workers) + self.pool = PopenPoolExecutor( + max_workers=max_workers, + timeout=rpc_config.session_timeout_sec, + initializer=initializer, + ) + + def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: + results: List[RunnerFuture] = [] + + for runner_input in runner_inputs: + future = RPCRunnerFuture( + future=self.pool.submit( + _worker_func, + self.platform, + self.project_options or {}, + self.rpc_config, + self.evaluator_config, + str(runner_input.artifact_path), + str(runner_input.device_type), + tuple(arg_info.as_json() for arg_info in runner_input.args_info), + ), + timeout_sec=self.rpc_config.session_timeout_sec, + ) + results.append(future) # type: ignore + return results + + +def _worker_func( + platform: str, + project_options: dict, + rpc_config: RPCConfig, + evaluator_config: EvaluatorConfig, + artifact_path: str, + device_type: str, + args_info: T_ARG_INFO_JSON_OBJ_LIST, +) -> List[float]: + + module_loader = micro.AutoTvmModuleLoader( + template_project_dir=micro.get_microtvm_template_projects(platform), + project_options=project_options, + ) + + remote_kw = { + "device_key": rpc_config.tracker_key, + "host": rpc_config.tracker_host, + "port": rpc_config.tracker_port, + "priority": 0, + "timeout": 100, + } + build_result = namedtuple("BuildResult", ["filename"])(artifact_path) + + with module_loader(remote_kw, build_result) as (remote, mod): + dev = remote.device(device_type, 0) + f_prepare = "" + if evaluator_config.enable_cpu_cache_flush: + 
f_prepare = "cache_flush_cpu_non_first_arg" + time_f = mod.time_evaluator( + mod.entry_name, + dev, + number=evaluator_config.number, + repeat=evaluator_config.repeat, + min_repeat_ms=evaluator_config.min_repeat_ms, + f_preproc=f_prepare, + ) + + random_fill = remote.get_function("tvm.contrib.random.random_fill") + args = [nd.empty(x[2], x[1], dev) for x in args_info] + for arg in args: + random_fill(arg) + dev.sync() + + costs = time_f(*args).results + return costs + + +@contextmanager +def get_rpc_runner_micro( + platform, + options, + rpc_config: RPCConfig = None, + evaluator_config: EvaluatorConfig = None, + session_timeout_sec=300, +): + """Parameters + ---------- + platform: str + The platform used for project generation. + project_options: dict + The options for the generated micro project. + rpc_config: RPCConfig + The rpc configuration. + evaluator_config: EvaluatorConfig + The evaluator configuration. + session_timeout_sec: int + The session timeout. if the number of candidates sent to runner is larger + than the runner workers, increase the timeout. + """ + if rpc_config is None: + tracker_host = "127.0.0.1" + tracker_port = 9000 + tracker_key = "$local$device$%d" % tracker_port + rpc_config = RPCConfig( + tracker_host=tracker_host, + tracker_port=tracker_port, + tracker_key=tracker_key, + session_priority=0, + session_timeout_sec=session_timeout_sec, + ) + tracker_port_end = rpc_config.tracker_port + 1000 + + if evaluator_config is None: + evaluator_config = EvaluatorConfig( + number=3, + repeat=1, + min_repeat_ms=100, + enable_cpu_cache_flush=False, + ) + + tracker = Tracker( + port=rpc_config.tracker_port, + port_end=tracker_port_end, + silent=True, + reuse_addr=True, + timeout=60, + ) + server = Server( + port=rpc_config.tracker_port, + port_end=tracker_port_end, + key=rpc_config.tracker_key, + silent=True, + tracker_addr=(rpc_config.tracker_host, rpc_config.tracker_port), + reuse_addr=True, + timeout=60, + ) + + def terminate(): + tracker.terminate() + server.terminate() + + def handle_SIGINT(signal, frame): + terminate() + raise KeyboardInterrupt("Received SIGINT") + + signal.signal(signal.SIGINT, handle_SIGINT) + + try: + yield RPCRunnerMicro( + platform=platform, + project_options=options, + rpc_config=rpc_config, + evaluator_config=evaluator_config, + ) + finally: + terminate() diff --git a/python/tvm/contrib/micro/meta_schedule/test_autotune_ms.py b/python/tvm/contrib/micro/meta_schedule/test_autotune_ms.py new file mode 100644 index 0000000000000..310a3f4ffcc81 --- /dev/null +++ b/python/tvm/contrib/micro/meta_schedule/test_autotune_ms.py @@ -0,0 +1,181 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import numpy as np +import pytest +from types import MappingProxyType +import pathlib +import json + +import tvm +from tvm import relay +from tvm.relay.backend import Executor +from tvm.contrib import graph_executor, utils +from tvm import meta_schedule as ms +from tvm.contrib.micro.meta_schedule.local_builder_micro import get_micro_local_builder +from tvm.contrib.micro.meta_schedule.rpc_runner_micro import get_rpc_runner_micro + + +def get_module(): + data_shape = (1, 3, 16, 16) + weight_shape = (8, 3, 5, 5) + data = relay.var("data", relay.TensorType(data_shape, "float32")) + weight = relay.var("weight", relay.TensorType(weight_shape, "float32")) + y = relay.nn.conv2d( + data, + weight, + padding=(2, 2), + kernel_size=(5, 5), + kernel_layout="OIHW", + out_dtype="float32", + ) + f = relay.Function([data, weight], y) + mod = tvm.IRModule.from_expr(f) + mod = relay.transform.InferType()(mod) + + weight_sample = np.random.rand( + weight_shape[0], weight_shape[1], weight_shape[2], weight_shape[3] + ).astype("float32") + params = {mod["main"].params[1].name_hint: weight_sample} + + model_info = { + "in_tensor": "data", + "in_shape": data_shape, + "in_dtype": "float32", + } + + return mod, params, model_info + + +@tvm.testing.requires_micro +@pytest.mark.parametrize( + "platform, options", + [ + pytest.param("crt", None), + pytest.param( + "zephyr", + { + "board": "qemu_x86", + "project_type": "host_driven", + }, + ), + ], +) +def test_micro_tuning_with_meta_schedule(platform, options): + if platform == "crt": + target = tvm.target.target.micro(model="host", options="-num-cores=1") + else: + boards_file = ( + pathlib.Path(tvm.micro.get_microtvm_template_projects("zephyr")) / "boards.json" + ) + with open(boards_file) as f: + boards = json.load(f) + target = tvm.target.target.micro( + model=boards[options["board"]]["model"], options="-mcpu=cortex-m4 -num-cores=1" + ) + + work_dir = utils.tempdir() + mod, params, model_info = get_module() + input_name = model_info["in_tensor"] + input_shape = model_info["in_shape"] + input_dtype = model_info["in_dtype"] + data_sample = np.random.rand(*input_shape).astype(input_dtype) + + runtime = relay.backend.Runtime("crt", {"system-lib": True}) + executor = Executor("aot", {"link-params": True}) + # This line is necessary for link-params to take effect during + # task extraction and relay.build(...). 
+ mod = mod.with_attr("executor", executor) + + builder = get_micro_local_builder() + with get_rpc_runner_micro( + platform=platform, options=options, session_timeout_sec=120 + ) as runner: + with ms.Profiler() as profiler: + db: ms.Database = ms.relay_integration.tune_relay( + mod=mod, + params=params, + target=target, + builder=builder, + runner=runner, + strategy="evolutionary", + num_trials_per_iter=2, + max_trials_per_task=10, + max_trials_global=100, + work_dir=str(work_dir), + module_equality="ignore-ndarray", + ) + + # Build model using meta_schedule logs + ms_mod: tvm.runtime.Module = ms.relay_integration.compile_relay( + database=db, + mod=mod, + target=target, + params=params, + pass_config=MappingProxyType( + { + "relay.backend.use_meta_schedule": True, + "relay.backend.tir_converter": "default", + "tir.disable_vectorize": True, + } + ), + executor=executor, + runtime=runtime, + ) + + project = tvm.micro.generate_project( + str(tvm.micro.get_microtvm_template_projects(platform)), + ms_mod, + str(work_dir / "project"), + options=options, + ) + project.build() + project.flash() + with tvm.micro.Session(project.transport()) as session: + aot_executor = tvm.runtime.executor.aot_executor.AotModule( + session.create_aot_executor() + ) + aot_executor.get_input(0).copyfrom(data_sample) + result = aot_executor.module.time_evaluator("run", session.device, number=3)() + output = aot_executor.get_output(0).numpy() + + print(profiler.table()) + + # Build reference model (without tuning) + dev = tvm.cpu() + target = tvm.target.target.micro(model="host", options="-num-cores=1") + with tvm.transform.PassContext( + opt_level=3, config={"tir.disable_vectorize": True}, disabled_pass=["AlterOpLayout"] + ): + ref_mod = relay.build( + mod, + target=target, + params=params, + runtime=runtime, + ) + ref_mod.export_library(work_dir / "compiled_lib2.so") + mod2: tvm.runtime.Module = tvm.runtime.load_module(work_dir / "compiled_lib2.so") + graph_mod = graph_executor.GraphModule(mod2["default"](dev)) + graph_mod.set_input(input_name, data_sample) + graph_mod.run() + ref_output = graph_mod.get_output(0).numpy() + + assert np.allclose(output, ref_output, rtol=1e-4, atol=2e-4), "FAILED" + work_dir.remove() + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py index 0b8705aafea9f..08f35ce3d3a15 100644 --- a/python/tvm/meta_schedule/relay_integration.py +++ b/python/tvm/meta_schedule/relay_integration.py @@ -73,12 +73,14 @@ def _normalize_params( params: Optional[Dict[str, NDArray]], pass_config: Mapping[str, Any], executor: Optional["relay.backend.Executor"], + runtime: Optional["relay.backend.Runtime"], ) -> Tuple[ IRModule, Target, Dict[str, NDArray], Dict[str, Any], Optional["relay.backend.Executor"], + Optional["relay.backend.Runtime"], ]: from tvm import relay # pylint: disable=import-outside-toplevel @@ -97,13 +99,16 @@ def _normalize_params( if executor is None: executor = relay.backend.Executor("graph") + if runtime is None: + runtime = relay.backend.Runtime("cpp") + if mod.get_attr("executor") is None: mod = mod.with_attr("executor", executor) else: executor = mod.get_attr("executor") pass_config = dict(pass_config) - return mod, target, relay_params, pass_config, executor + return mod, target, relay_params, pass_config, executor, runtime def extract_tasks( @@ -119,6 +124,7 @@ def extract_tasks( } ), executor: Optional["relay.backend.Executor"] = None, + runtime: Optional["relay.backend.Runtime"] 
= None, module_equality: str = "structural", ) -> List[ExtractedTask]: """Extract tuning tasks from a relay program. @@ -137,6 +143,8 @@ def extract_tasks( The pass configuration executor : Optional[relay.backend.Executor] The executor to use + runtime : Optional[relay.backend.Runtime] + The runtime to use module_equality : Optional[str] A string to specify the module equality testing and hashing method. It must be one of the followings: @@ -157,8 +165,13 @@ def extract_tasks( from tvm import autotvm # pylint: enable=import-outside-toplevel - mod, target, params, pass_config, _ = _normalize_params( - mod, target, params, pass_config, executor + mod, target, params, pass_config, _ex, _rt = _normalize_params( + mod, + target, + params, + pass_config, + executor, + runtime, ) if target.kind.name != "cuda" and isinstance( autotvm.DispatchContext.current, autotvm.FallbackContext @@ -345,6 +358,7 @@ def compile_relay( } ), executor: Optional["relay.backend.Executor"] = None, + runtime: Optional["relay.backend.Runtime"] = None, ): """Compile a relay program with a MetaSchedule database. @@ -368,6 +382,8 @@ def compile_relay( The pass configuration executor : Optional[relay.backend.Executor] The executor to use in relay.build. It is not supported by RelayVM. + runtime : Optional[relay.backend.Runtime] + The runtime to use in relay.build. It is not supported by RelayVM. Returns ------- @@ -378,8 +394,8 @@ def compile_relay( from tvm import relay # pylint: enable=import-outside-toplevel - mod, target, params, pass_config, executor = _normalize_params( - mod, target, params, pass_config, executor + mod, target, params, pass_config, executor, runtime = _normalize_params( + mod, target, params, pass_config, executor, runtime ) pass_config.setdefault("relay.backend.use_meta_schedule_dispatch", True) with Profiler.timeit("PostTuningCompilation"): @@ -389,7 +405,9 @@ def compile_relay( config=pass_config, ): if backend == "graph": - return relay.build(mod, target=target, params=params, executor=executor) + return relay.build( + mod, target=target, params=params, executor=executor, runtime=runtime + ) elif backend == "vm": return relay.vm.compile(mod, target=target, params=params) else: diff --git a/python/tvm/relay/backend/executor.py b/python/tvm/relay/backend/executor.py index ac5e5bf1f8293..854473f662c0b 100644 --- a/python/tvm/relay/backend/executor.py +++ b/python/tvm/relay/backend/executor.py @@ -33,15 +33,28 @@ def __init__(self, name, options=None) -> None: if options is None: options = {} self.__init_handle_by_constructor__(_backend.CreateExecutor, name, options) + self._init_wrapper() + + # Note: sometimes the _attrs field is not properly populated, + # most likely since __new__ is called instead of __init__ in tvm/_ffi/_ctypes/object.py + def _init_wrapper(self): self._attrs = _backend.GetExecutorAttrs(self) + self._init_wrapper_called = True + + def _check_init_wrapper(self): + if not (hasattr(self, "_init_wrapper_called") and self._init_wrapper_called): + self._init_wrapper() def __contains__(self, name): + self._check_init_wrapper() return name in self._attrs def __getitem__(self, name): + self._check_init_wrapper() return self._attrs[name] def __eq__(self, other): + self._check_init_wrapper() return str(other) == str(self) and dict(other._attrs) == dict(self._attrs) @staticmethod diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index 7932e98aa20cc..7563d29aae94f 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -319,6 +319,8 @@ def __init__( 
load_library=None, custom_addr=None, silent=False, + reuse_addr=False, + timeout=None, ): # start update @@ -332,6 +334,10 @@ def __init__( if not is_proxy: sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM) + if reuse_addr: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if timeout is not None: + sock.settimeout(timeout) self.port = None for my_port in range(port, port_end): try: @@ -371,6 +377,8 @@ def _popen_start_rpc_server( silent=False, no_fork=False, server_init_callback=None, + reuse_addr=False, + timeout=None, ): if no_fork: multiprocessing.set_start_method("spawn") @@ -382,7 +390,17 @@ def _popen_start_rpc_server( # Popen worker to run on a separate process. # Create and start the server in a different thread state = PopenRPCServerState( - host, port, port_end, is_proxy, tracker_addr, key, load_library, custom_addr, silent + host, + port, + port_end, + is_proxy, + tracker_addr, + key, + load_library, + custom_addr, + silent, + reuse_addr, + timeout, ) PopenRPCServerState.current = state # returns the port so that the main can get the port number. @@ -434,6 +452,12 @@ class Server(object): server_init_callback: Callable, optional Additional initialization function when starting the server. + reuse_addr: bool, optional + Allows the kernel to reuse a local socket in TIME_WAIT state. + + timeout: float, optional + set a timeout for all operations on the socket + Note ---- The RPC server only sees functions in the tvm namespace. @@ -464,6 +488,8 @@ def __init__( silent=False, no_fork=False, server_init_callback=None, + reuse_addr=False, + timeout=None, ): try: if _ffi_api.ServerLoop is None: @@ -486,6 +512,8 @@ def __init__( silent, no_fork, server_init_callback, + reuse_addr, + timeout, ], ) # receive the port diff --git a/python/tvm/rpc/tracker.py b/python/tvm/rpc/tracker.py index e65ed4a012f09..9b0edbe6c24f3 100644 --- a/python/tvm/rpc/tracker.py +++ b/python/tvm/rpc/tracker.py @@ -387,11 +387,17 @@ class PopenTrackerServerState(object): current = None - def __init__(self, host, port=9190, port_end=9199, silent=False): + def __init__( + self, host, port=9190, port_end=9199, silent=False, reuse_addr=False, timeout=None + ): if silent: logger.setLevel(logging.WARN) sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM) + if reuse_addr: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if timeout is not None: + sock.settimeout(timeout) self.port = None self.stop_key = base.random_key("tracker") for my_port in range(port, port_end): @@ -412,11 +418,13 @@ def __init__(self, host, port=9190, port_end=9199, silent=False): self.host = host -def _popen_start_tracker_server(host, port=9190, port_end=9199, silent=False): +def _popen_start_tracker_server( + host, port=9190, port_end=9199, silent=False, reuse_addr=False, timeout=None +): # This is a function that will be sent to the # Popen worker to run on a separate process. # Create and start the server in a different thread - state = PopenTrackerServerState(host, port, port_end, silent) + state = PopenTrackerServerState(host, port, port_end, silent, reuse_addr, timeout) PopenTrackerServerState.current = state # returns the port so that the main can get the port number. return (state.port, state.stop_key) @@ -440,9 +448,18 @@ class Tracker(object): silent: bool, optional Whether run in silent mode + + reuse_addr: bool, optional + Allows the kernel to reuse a local socket in TIME_WAIT state. 
+
+    timeout: float, optional
+        set a timeout for all operations on the socket
+
     """
 
-    def __init__(self, host="0.0.0.0", port=9190, port_end=9199, silent=False):
+    def __init__(
+        self, host="0.0.0.0", port=9190, port_end=9199, silent=False, reuse_addr=False, timeout=None
+    ):
         if silent:
             logger.setLevel(logging.WARN)
         self.proc = PopenWorker()
@@ -454,6 +471,8 @@ def __init__(self, host="0.0.0.0", port=9190, port_end=9199, silent=False):
                 port,
                 port_end,
                 silent,
+                reuse_addr,
+                timeout,
             ],
         )
         # receive the port
diff --git a/src/meta_schedule/mutator/mutator.cc b/src/meta_schedule/mutator/mutator.cc
index 8f3d14b6c4666..d5b3c360b50b3 100644
--- a/src/meta_schedule/mutator/mutator.cc
+++ b/src/meta_schedule/mutator/mutator.cc
@@ -78,6 +78,14 @@ Map<Mutator, FloatImm> Mutator::DefaultHexagon() {
       {Mutator::MutateParallel(/*max_jobs_per_core=*/16), FloatImm(DataType::Float(64), 0.02)}};
 }
 
+Map<Mutator, FloatImm> Mutator::DefaultMicro() {
+  return Map<Mutator, FloatImm>{
+      {Mutator::MutateTileSize(), FloatImm(DataType::Float(64), 0.9)},
+      {Mutator::MutateComputeLocation(), FloatImm(DataType::Float(64), 0.05)},
+      {Mutator::MutateUnroll(), FloatImm(DataType::Float(64), 0.03)},
+      {Mutator::MutateParallel(/*max_jobs_per_core=*/16), FloatImm(DataType::Float(64), 0.02)}};
+}
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<MutatorNode>([](const ObjectRef& n, ReprPrinter* p) {
       const auto* self = n.as<MutatorNode>();
@@ -104,6 +112,7 @@ TVM_REGISTER_GLOBAL("meta_schedule.MutatorDefaultCUDA").set_body_typed(Mutator::
 TVM_REGISTER_GLOBAL("meta_schedule.MutatorDefaultCUDATensorCore")
     .set_body_typed(Mutator::DefaultCUDATensorCore);
 TVM_REGISTER_GLOBAL("meta_schedule.MutatorDefaultHexagon").set_body_typed(Mutator::DefaultHexagon);
+TVM_REGISTER_GLOBAL("meta_schedule.MutatorDefaultMicro").set_body_typed(Mutator::DefaultMicro);
 
 }  // namespace meta_schedule
 }  // namespace tvm
diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc
index dba523d094bfb..058f82424eee2 100644
--- a/src/meta_schedule/postproc/postproc.cc
+++ b/src/meta_schedule/postproc/postproc.cc
@@ -100,6 +100,15 @@ Array<Postproc> Postproc::DefaultHexagon() {
   };
 }
 
+Array<Postproc> Postproc::DefaultMicro() {
+  return Array<Postproc>{
+      Postproc::DisallowDynamicLoop(),
+      Postproc::RewriteParallelVectorizeUnroll(),
+      Postproc::RewriteReductionBlock(),
+      Postproc::RewriteLayout(),
+  };
+}
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<PostprocNode>([](const ObjectRef& n, ReprPrinter* p) {
       const auto* self = n.as<PostprocNode>();
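The default mutators, postprocessors, and schedule rules added for micro targets are exposed to Python through the `meta_schedule.*DefaultMicro` packed functions registered in these files. As a rough, illustrative sketch (not part of the patch), the new mutator probabilities could be inspected on a TVM build that includes this change:

```python
# Illustrative only: look up the packed function registered above and print the
# default micro mutator probabilities (a Map of Mutator -> FloatImm).
import tvm

default_micro_mutators = tvm.get_global_func("meta_schedule.MutatorDefaultMicro")()
for mutator, prob in default_micro_mutators.items():
    print(mutator, prob.value)
```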
diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc
index e4f97c1fa6738..3bcd1ac37a477 100644
--- a/src/meta_schedule/schedule_rule/schedule_rule.cc
+++ b/src/meta_schedule/schedule_rule/schedule_rule.cc
@@ -251,6 +251,34 @@ Array<ScheduleRule> ScheduleRule::DefaultHexagon() {
   };
 }
 
+Array<ScheduleRule> ScheduleRule::DefaultMicro() {
+  return {
+      ScheduleRule::AutoInline(
+          /*into_producer=*/false,
+          /*into_consumer=*/true,
+          /*inline_const_tensor=*/true,
+          /*disallow_if_then_else=*/true,
+          /*require_injective=*/true,
+          /*require_ordered=*/true,
+          /*disallow_op=*/Array<String>{"tir.exp"}),
+      ScheduleRule::MultiLevelTiling(
+          /*structure=*/"SSRSRS",
+          /*tile_binds=*/NullOpt,
+          /*max_innermost_factor=*/Integer(64),
+          /*vector_load_lens=*/NullOpt,
+          /*reuse_read=*/NullOpt,
+          /*reuse_write=*/
+          Map<String, ObjectRef>{{"req", String("may")},
+                                 {"levels", Array<Integer>{1, 2}},
+                                 {"scope", String("global")}}),
+      ScheduleRule::ParallelizeVectorizeUnroll(
+          /*max_jobs_per_core=*/16,
+          /*max_vectorize_extent=*/1,
+          /*unroll_max_steps=*/Array<Integer>{0, 16, 64, 512},
+          /*unroll_explicit=*/true),
+  };
+}
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<ScheduleRuleNode>([](const ObjectRef& n, ReprPrinter* p) {
       const auto* self = n.as<ScheduleRuleNode>();
@@ -279,6 +307,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultCUDATensorCore")
     .set_body_typed(ScheduleRule::DefaultCUDATensorCore);
 TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultHexagon")
     .set_body_typed(ScheduleRule::DefaultHexagon);
+TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultMicro")
+    .set_body_typed(ScheduleRule::DefaultMicro);
 
 }  // namespace meta_schedule
 }  // namespace tvm
diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc
index 2d69727384a77..926f86cc4ff9a 100644
--- a/src/meta_schedule/space_generator/space_generator.cc
+++ b/src/meta_schedule/space_generator/space_generator.cc
@@ -57,6 +57,9 @@ String GetRuleKindFromTarget(const Target& target) {
     return "cuda";
   }
 
+  if (target->kind->name == "c") {
+    return "c";
+  }
   LOG(FATAL) << "Unsupported target: " << target;
   throw;
 }
@@ -90,6 +93,10 @@ void SpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) {
     default_sch_rules = ScheduleRule::DefaultVNNI();
     default_postprocs = Postproc::DefaultVNNI();
     default_mutator_probs = Mutator::DefaultVNNI();
+  } else if (kind == "c") {
+    default_sch_rules = ScheduleRule::DefaultMicro();
+    default_postprocs = Postproc::DefaultMicro();
+    default_mutator_probs = Mutator::DefaultMicro();
   } else {
     LOG(FATAL) << "Unsupported kind: " << kind;
     throw;
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index 969aa630df399..6039423844e85 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -525,6 +525,8 @@ inline ScheduleRule GetDefaultAutoInline(const std::string& target_name) {
     rules = ScheduleRule::DefaultLLVM();
   } else if (target_name == "hexagon") {
     rules = ScheduleRule::DefaultHexagon();
+  } else if (target_name == "c") {
+    rules = ScheduleRule::DefaultMicro();
   } else if (IsGPUTarget(target_name)) {
     rules = ScheduleRule::DefaultCUDA();
   } else {
diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc
index 0a89198c985a4..1d8071774e9e1 100644
--- a/src/target/source/codegen_c_host.cc
+++ b/src/target/source/codegen_c_host.cc
@@ -54,6 +54,7 @@ void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_d
   decl_stream << "#include \"tvm/runtime/c_runtime_api.h\"\n";
   decl_stream << "#include \"tvm/runtime/c_backend_api.h\"\n";
   decl_stream << "#include <math.h>\n";
+  decl_stream << "#include <stdbool.h>\n";
   if (devices.find("ethos-u") != devices.end()) {
     decl_stream << "#include <tvm_ethosu_runtime.h>\n";
   }
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index a87bb92c483b7..dc68ea3f86d70 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -309,6 +309,7 @@ TVM_REGISTER_TARGET_KIND("c", kDLCPU)
    .add_attr_option<String>("march")
    .add_attr_option<Integer>("workspace-byte-alignment")
    .add_attr_option<Integer>("constants-byte-alignment")
+    .add_attr_option<Integer>("num-cores")
    .set_default_keys({"cpu"})
    .set_target_parser(tvm::target::parsers::cpu::ParseTarget);
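Taken together, these changes let the meta schedule tuner treat the `c` target kind used by microTVM as a first-class backend. A minimal sketch of the intended flow on the CRT host platform, assuming the helper modules added above are importable; `mod` and `params` stand in for a real Relay module and its parameters, and the trial counts are illustrative only:

```python
# Rough usage sketch, mirroring the test added in this patch (not a drop-in recipe).
import tvm
from tvm import meta_schedule as ms
from tvm.contrib.micro.meta_schedule.local_builder_micro import get_micro_local_builder
from tvm.contrib.micro.meta_schedule.rpc_runner_micro import get_rpc_runner_micro

# "-num-cores=1" relies on the num-cores attribute registered for the "c" target kind above.
target = tvm.target.target.micro(model="host", options="-num-cores=1")
builder = get_micro_local_builder()

with get_rpc_runner_micro(platform="crt", options=None) as runner:
    database = ms.relay_integration.tune_relay(
        mod=mod,        # placeholder Relay IRModule
        params=params,  # placeholder parameter dict
        target=target,
        builder=builder,
        runner=runner,
        max_trials_global=32,
        work_dir="./tune_micro_ms",
        module_equality="ignore-ndarray",
    )
```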