From 0f54159712dd60de694715bfb4ac73ada795feba Mon Sep 17 00:00:00 2001
From: Junru Shao <junrushao1994@gmail.com>
Date: Wed, 29 Sep 2021 09:35:56 -0700
Subject: [PATCH] [Meta Schedule][M3b] Runner (#9111)

This PR is part of the meta schedule project (#8473). It adds the
asynchronous program runner interface, as well as RPCRunner, a reference
implementation that runs built artifacts on remote devices through RPC.
LocalRunner, based on the PopenPool executor, will be added in a
follow-up PR.
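
A minimal sketch of the intended usage, with the interfaces added in this
patch (the tracker address, artifact path and argument info are
illustrative):

    from tvm.meta_schedule.runner import (
        EvaluatorConfig, RPCConfig, RPCRunner, RunnerInput,
    )

    runner = RPCRunner(
        rpc_config=RPCConfig(
            tracker_host="127.0.0.1", tracker_port=9190, tracker_key="local"
        ),
        evaluator_config=EvaluatorConfig(number=3, repeat=1, min_repeat_ms=40),
    )
    (future,) = runner.run([RunnerInput(artifact_path, "llvm", args_info)])
    result = future.result()  # blocks; RunnerResult carries run_secs or error_msg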

Co-authored-by: Xiyou Zhou <xiyou@octoml.ai>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>

Address comments

Co-authored-by: Cody Yu <comaniac0422@gmail.com>

fix lint
---
 include/tvm/meta_schedule/runner.h            | 169 +++++-
 python/tvm/meta_schedule/__init__.py          |   5 +-
 .../meta_schedule/builder/local_builder.py    |  17 +-
 python/tvm/meta_schedule/runner/__init__.py   |   9 +-
 python/tvm/meta_schedule/runner/config.py     | 190 ++++++
 python/tvm/meta_schedule/runner/rpc_runner.py | 567 +++++++++++++++++
 python/tvm/meta_schedule/runner/runner.py     | 111 ++++
 python/tvm/meta_schedule/testing.py           |  74 +++
 python/tvm/meta_schedule/tune_context.py      |   4 +-
 python/tvm/meta_schedule/utils.py             |  37 +-
 src/meta_schedule/runner/runner.cc            |  45 +-
 .../unittest/test_meta_schedule_runner.py     | 571 ++++++++++++++++++
 12 files changed, 1776 insertions(+), 23 deletions(-)
 create mode 100644 python/tvm/meta_schedule/runner/config.py
 create mode 100644 python/tvm/meta_schedule/runner/rpc_runner.py
 create mode 100644 python/tvm/meta_schedule/testing.py
 create mode 100644 tests/python/unittest/test_meta_schedule_runner.py

diff --git a/include/tvm/meta_schedule/runner.h b/include/tvm/meta_schedule/runner.h
index 36d07024559d..a45a4898d64a 100644
--- a/include/tvm/meta_schedule/runner.h
+++ b/include/tvm/meta_schedule/runner.h
@@ -20,16 +20,53 @@
 #define TVM_META_SCHEDULE_RUNNER_H_
 
 #include <tvm/ir/expr.h>
+#include <tvm/meta_schedule/arg_info.h>
 
 namespace tvm {
 namespace meta_schedule {
 
-/*! \brief Runner's output containing measurement result of MeasureCandidate or error msg if any. */
+/*! \brief The runner's input. */
+class RunnerInputNode : public runtime::Object {
+ public:
+  /*! \brief The path to the built artifact. */
+  String artifact_path;
+  /*! \brief The type of device. */
+  String device_type;
+  /*! \brief The argument information. */
+  Array<ArgInfo> args_info;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("artifact_path", &artifact_path);
+    v->Visit("device_type", &device_type);
+    v->Visit("args_info", &args_info);
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.RunnerInput";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RunnerInputNode, runtime::Object);
+};
+
+/*!
+ * \brief Managed reference to RunnerInputNode
+ * \sa RunnerInputNode
+ */
+class RunnerInput : public runtime::ObjectRef {
+ public:
+  /*!
+   * \brief Constructor of RunnerInput
+   * \param artifact_path The path to the built artifact.
+   * \param device_type The type of device.
+   * \param args_info The argument information.
+   */
+  TVM_DLL explicit RunnerInput(String artifact_path, String device_type, Array<ArgInfo> args_info);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerInput, runtime::ObjectRef, RunnerInputNode);
+};
+
+/*! \brief The runner's output. */
 class RunnerResultNode : public runtime::Object {
  public:
-  /*! \brief The run time in seconds. If not None, error_msg should be None. */
+  /*! \brief The run time in seconds. */
   Optional<Array<FloatImm>> run_secs;
-  /*! \brief The error message, if any. If not None, run_secs should be None. */
+  /*! \brief The error message, if any. */
   Optional<String> error_msg;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
@@ -48,14 +85,134 @@ class RunnerResultNode : public runtime::Object {
 class RunnerResult : public runtime::ObjectRef {
  public:
   /*!
-   * \brief Constructor for RunnerResult.
-   * \param run_secs The run time in seconds.
-   * \param error_msg The error message, if any.
+   * \brief Constructor of RunnerResult
+   * \param run_secs The run time in seconds.
+   * \param error_msg The error message, if any.
    */
   TVM_DLL explicit RunnerResult(Optional<Array<FloatImm>> run_secs, Optional<String> error_msg);
   TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerResult, runtime::ObjectRef, RunnerResultNode);
 };
 
+/*!
+ * \brief A class to asynchronously fetch runner's output.
+ * \note The API design is consistent with python's concurrent.futures.Future:
+ * https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.Future
+ */
+class RunnerFutureNode : public runtime::Object {
+ public:
+  /*!
+   * \brief The function type to check whether the runner has finished.
+   * \return Whether the runner's output is ready.
+   */
+  using FDone = runtime::TypedPackedFunc<bool()>;
+  /*!
+   * \brief The function type to fetch runner output if it is ready.
+   * \return The runner's output.
+   */
+  using FResult = runtime::TypedPackedFunc<RunnerResult()>;
+
+  /*! \brief The packed function to check whether the runner has finished. */
+  FDone f_done;
+  /*! \brief The packed function to fetch runner output if it is ready. */
+  FResult f_result;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    // `f_done` is not visited
+    // `f_result` is not visited
+  }
+
+  /*!
+   * \brief Check whether the runner has finished.
+   * \return A boolean indicating whether the runner has finished.
+   */
+  bool Done() const { return f_done(); }
+  /*!
+   * \brief Fetch the runner's output if it is ready.
+   * \return The runner's output.
+   */
+  RunnerResult Result() const { return f_result(); }
+
+  static constexpr const char* _type_key = "meta_schedule.RunnerFuture";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RunnerFutureNode, runtime::Object);
+};
+
+/*!
+ * \brief Managed reference to RunnerFutureNode
+ * \sa RunnerFutureNode
+ */
+class RunnerFuture : public runtime::ObjectRef {
+ public:
+  using FDone = RunnerFutureNode::FDone;
+  using FResult = RunnerFutureNode::FResult;
+
+  /*!
+   * \brief Constructor of RunnerFuture
+   * \param f_done The packed function to check whether the runner has finished.
+   * \param f_result The packed function to fetch runner output if it is ready.
+   */
+  TVM_DLL explicit RunnerFuture(FDone f_done, FResult f_result);
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerFuture, runtime::ObjectRef,
+                                                    RunnerFutureNode);
+};
+
+/*! \brief The abstract runner interface. */
+class RunnerNode : public runtime::Object {
+ public:
+  /*!
+   * \brief The function type to run the built artifacts and get runner futures.
+   * \param input The runner's inputs.
+   * \return The runner futures.
+   * \sa RunnerFuture
+   */
+  using FRun = runtime::TypedPackedFunc<Array<RunnerFuture>(Array<RunnerInput>)>;
+
+  /*! \brief Default destructor */
+  virtual ~RunnerNode() = default;
+
+  /*!
+   * \brief Run the built artifact and get runner futures.
+   * \param runner_inputs The runner's inputs.
+   * \return The runner futures.
+   */
+  virtual Array<RunnerFuture> Run(Array<RunnerInput> runner_inputs) = 0;
+
+  static constexpr const char* _type_key = "meta_schedule.Runner";
+  TVM_DECLARE_BASE_OBJECT_INFO(RunnerNode, runtime::Object);
+};
+
+/*!
+ * \brief Managed reference to RunnerNode
+ * \sa RunnerNode
+ */
+class Runner : public runtime::ObjectRef {
+ public:
+  using FRun = RunnerNode::FRun;
+
+  /*!
+   * \brief Create a runner with a customized run method on the python side.
+   * \param f_run The packed function to run the built artifacts and get runner futures.
+   * \return The runner created.
+   */
+  TVM_DLL static Runner PyRunner(FRun f_run);
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Runner, runtime::ObjectRef, RunnerNode);
+};
+
+/*! \brief An abstract runner with a customized run method on the python side. */
+class PyRunnerNode : public RunnerNode {
+ public:
+  /*! \brief The packed function to run the built artifacts and get runner futures. */
+  FRun f_run;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    // `f_run` is not visited
+  }
+
+  Array<RunnerFuture> Run(Array<RunnerInput> runner_inputs) final { return f_run(runner_inputs); }
+
+  static constexpr const char* _type_key = "meta_schedule.PyRunner";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PyRunnerNode, RunnerNode);
+};
+
 }  // namespace meta_schedule
 }  // namespace tvm
 
diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py
index c22cc205bf35..2e280ef20ac3 100644
--- a/python/tvm/meta_schedule/__init__.py
+++ b/python/tvm/meta_schedule/__init__.py
@@ -16,10 +16,9 @@
 # under the License.
 """Package `tvm.meta_schedule`. The meta schedule infrastructure."""
 from . import arg_info
-from . import builder
 from . import database
+from . import builder
+from . import runner
 from . import space_generator
 from . import search_strategy
-from . import runner
-from .database import TuningRecord
 from .tune_context import TuneContext
diff --git a/python/tvm/meta_schedule/builder/local_builder.py b/python/tvm/meta_schedule/builder/local_builder.py
index cefe5ec50cad..99dfaea56090 100644
--- a/python/tvm/meta_schedule/builder/local_builder.py
+++ b/python/tvm/meta_schedule/builder/local_builder.py
@@ -48,11 +48,20 @@ class LocalBuilder(PyBuilder):
     Attributes
     ----------
     T_BUILD : typing._GenericAlias
-        The signature of the build function `f_build`, which is
-        `Callable[[IRModule, Target], Module]`
+        The signature of the function `f_build`, which is:
+
+        .. code-block:: python
+
+            def default_build(mod: IRModule, target: Target) -> Module:
+                ...
+
     T_EXPORT : typing._GenericAlias
-        The signature of the build function `f_export`, which is
-        `Callable[[Module], str]`
+        The signature of the function `f_export`, which is:
+
+        .. code-block:: python
+
+            def default_export(mod: Module) -> str:
+                ...
 
     Note
     ----
diff --git a/python/tvm/meta_schedule/runner/__init__.py b/python/tvm/meta_schedule/runner/__init__.py
index 65d2ef04e04c..47f4557e1d3a 100644
--- a/python/tvm/meta_schedule/runner/__init__.py
+++ b/python/tvm/meta_schedule/runner/__init__.py
@@ -14,5 +14,10 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""meta_schedule.runner"""
-from .runner import RunnerResult
+"""
+The tvm.meta_schedule.runner package.
+Meta Schedule runners that runs an artifact either locally or through the RPC interface
+"""
+from .config import EvaluatorConfig, RPCConfig
+from .rpc_runner import RPCRunner
+from .runner import PyRunner, Runner, RunnerFuture, RunnerInput, RunnerResult
diff --git a/python/tvm/meta_schedule/runner/config.py b/python/tvm/meta_schedule/runner/config.py
new file mode 100644
index 000000000000..712766de99c1
--- /dev/null
+++ b/python/tvm/meta_schedule/runner/config.py
@@ -0,0 +1,190 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Configurations for measurements in the runner"""
+import os
+from threading import Thread
+from typing import NamedTuple, Optional, Union
+
+from tvm import rpc
+
+
+class EvaluatorConfig(NamedTuple):
+    """Config Details of Evaluator
+
+    Parameters
+    ----------
+    number: int
+        The number of times to run the artifact in each measurement; their
+        runtimes are averaged into one result.
+    repeat: int
+        The number of measurements to take, each consisting of `number` runs.
+    min_repeat_ms: int
+        Minimum wall-clock time of one measurement, in milliseconds. If a
+        measurement finishes faster, `number` is increased automatically to
+        reduce the measurement error.
+    enable_cpu_cache_flush: bool
+        Whether to flush the CPU cache before each run.
+
+    Note
+    ----
+    The total number of actual executions is 1 + number * repeat, because one
+    warm-up run is performed before the measured runs. The number of runs is
+    increased automatically when the measured run time falls below min_repeat_ms.
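+
+    For instance, with the defaults number=3 and repeat=1, the artifact is
+    executed 1 + 3 * 1 = 4 times per measurement.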
+    """
+
+    number: int = 3
+    repeat: int = 1
+    min_repeat_ms: int = 40
+    enable_cpu_cache_flush: bool = False
+
+    @staticmethod
+    def _normalized(config: Optional["EvaluatorConfig"]) -> "EvaluatorConfig":
+        if config is None:
+            return EvaluatorConfig()
+        config = EvaluatorConfig(
+            number=config.number,
+            repeat=config.repeat,
+            min_repeat_ms=config.min_repeat_ms,
+            enable_cpu_cache_flush=config.enable_cpu_cache_flush,
+        )
+        return config
+
+
+class RPCConfig(NamedTuple):
+    """RPC configuration
+
+    Parameters
+    ----------
+    tracker_host: str
+        Host of the RPC Tracker
+    tracker_port: int
+        Port of the RPC Tracker
+    tracker_key: str
+        Key of the Tracker
+    session_priority: int
+        Priority of the RPC session
+    session_timeout_sec: int
+        Timeout of the RPC session, in seconds
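+
+    Note
+    ----
+    Fields left as None fall back to the environment variables TVM_TRACKER_HOST,
+    TVM_TRACKER_PORT and TVM_TRACKER_KEY when the config is normalized. A sketch
+    (the tracker address is illustrative):
+
+    .. code-block:: python
+
+        config = RPCConfig(
+            tracker_host="127.0.0.1",
+            tracker_port=9190,
+            tracker_key="local",
+        )
+        session = config.connect_server()  # request a session via the tracker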
+    """
+
+    tracker_host: Optional[str] = None
+    tracker_port: Union[None, int, str] = None
+    tracker_key: Optional[str] = None
+    session_priority: int = 1
+    session_timeout_sec: int = 10
+
+    def _sanity_check(self) -> None:
+        err_str = (
+            "RPCConfig.{0} is not provided. Please provide it explicitly,"
+            "or set environment variable {1}"
+        )
+        if self.tracker_host is None:
+            raise ValueError(err_str.format("tracker_host", "TVM_TRACKER_HOST"))
+        if self.tracker_port is None:
+            raise ValueError(err_str.format("tracker_port", "TVM_TRACKER_PORT"))
+        if self.tracker_key is None:
+            raise ValueError(err_str.format("tracker_key", "TVM_TRACKER_KEY"))
+
+    @staticmethod
+    def _normalized(config: Optional["RPCConfig"]) -> "RPCConfig":
+        if config is None:
+            config = RPCConfig()
+        config = RPCConfig(
+            tracker_host=config.tracker_host or os.environ.get("TVM_TRACKER_HOST", None),
+            tracker_port=config.tracker_port or os.environ.get("TVM_TRACKER_PORT", None),
+            tracker_key=config.tracker_key or os.environ.get("TVM_TRACKER_KEY", None),
+            session_priority=config.session_priority,
+            session_timeout_sec=config.session_timeout_sec,
+        )
+        config._sanity_check()  # pylint: disable=protected-access
+        return config
+
+    def connect_tracker(self) -> rpc.TrackerSession:
+        """Connect to the tracker
+
+        Returns
+        -------
+        tracker : TrackerSession
+            The connected tracker session
+        """
+        tracker: Optional[rpc.TrackerSession] = None
+
+        def _connect():
+            nonlocal tracker
+            tracker = rpc.connect_tracker(self.tracker_host, self.tracker_port)
+
+        t = Thread(target=_connect)
+        t.start()
+        t.join(self.session_timeout_sec)
+        if t.is_alive() or tracker is None:
+            raise ValueError(
+                "Unable to connect to the tracker using the following configuration:\n"
+                f"    tracker host: {self.tracker_host}\n"
+                f"    tracker port: {self.tracker_port}\n"
+                f"    timeout (sec): {self.session_timeout_sec}\n"
+                "Please check the tracker status via the following command:\n"
+                "     python3 -m tvm.exec.query_rpc_tracker "
+                f"--host {self.tracker_host} --port {self.tracker_port}"
+            )
+        return tracker
+
+    def connect_server(self) -> rpc.RPCSession:
+        """Connect to the server
+
+        Returns
+        -------
+        session : RPCSession
+            The connected rpc session
+        """
+        tracker = self.connect_tracker()
+        session: rpc.RPCSession = tracker.request(
+            key=self.tracker_key,
+            priority=self.session_priority,
+            session_timeout=self.session_timeout_sec,
+        )
+        return session
+
+    def count_num_servers(self, allow_missing=True) -> int:
+        """Count the number of servers available in the tracker
+
+        Parameters
+        ----------
+        allow_missing : bool
+            Whether to allow no server to be found.
+
+        Returns
+        -------
+        num_servers : int
+            The number of servers
+        """
+        tracker = self.connect_tracker()
+        tracker_summary = tracker.summary()
+        result: int = 0
+        for item in tracker_summary["server_info"]:
+            _, item_key = item["key"].split(":")
+            if item_key == self.tracker_key:
+                result += 1
+        if result == 0 and not allow_missing:
+            raise ValueError(
+                "Unable to find servers with the specific key using the following configuration:\n"
+                f"    tracker host: {self.tracker_host}\n"
+                f"    tracker port: {self.tracker_port}\n"
+                f"    tracker key: {self.tracker_key}\n"
+                f"    timeout (sec): {self.session_timeout_sec}\n"
+                "Please check the tracker status via the following command:\n"
+                "     python3 -m tvm.exec.query_rpc_tracker "
+                f"--host {self.tracker_host} --port {self.tracker_port}\n"
+                f'and look for key: "{self.tracker_key}"'
+            )
+        return result
diff --git a/python/tvm/meta_schedule/runner/rpc_runner.py b/python/tvm/meta_schedule/runner/rpc_runner.py
new file mode 100644
index 000000000000..d20e1707fcec
--- /dev/null
+++ b/python/tvm/meta_schedule/runner/rpc_runner.py
@@ -0,0 +1,567 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""RPC Runner"""
+import concurrent.futures
+from contextlib import contextmanager
+import itertools
+import os.path as osp
+from typing import Any, Callable, Dict, List, Optional, Union
+
+from tvm.contrib.popen_pool import PopenPoolExecutor
+from tvm.rpc import RPCSession
+from tvm.runtime import Device, Module, ndarray
+
+from ..utils import (
+    get_global_func_on_rpc_session,
+    get_global_func_with_default_on_worker,
+)
+from .config import EvaluatorConfig, RPCConfig
+from .runner import PyRunner, RunnerFuture, RunnerInput, RunnerResult
+
+
+class RPCRunnerFuture(RunnerFuture):
+    """RPC based runner future
+
+    Parameters
+    ----------
+    future: concurrent.futures.Future
+        The concurrent future that wraps the remote run, used to check completion and fetch the result.
+    timeout_sec: float
+        The timeout in seconds.
+    """
+
+    future: concurrent.futures.Future
+    timeout_sec: float
+
+    def __init__(self, future: concurrent.futures.Future, timeout_sec: float) -> None:
+        """Constructor
+
+        Parameters
+        ----------
+        future: concurrent.futures.Future
+            The concurrent future that wraps the remote run, used to check completion and fetch the result.
+        timeout_sec: float
+            The timeout in seconds.
+        """
+        super().__init__()
+        self.future = future
+        self.timeout_sec = timeout_sec
+
+    def done(self) -> bool:
+        return self.future.done()
+
+    def result(self) -> RunnerResult:
+        try:
+            run_secs: List[float] = self.future.result()
+        except TimeoutError as exception:
+            return RunnerResult(
+                None,
+                error_msg=f"RPCRunner: Timeout, killed after {self.timeout_sec} seconds",
+            )
+        except Exception as exception:  # pylint: disable=broad-except
+            return RunnerResult(
+                None,
+                error_msg="RPCRunner: An exception occurred\n" + str(exception),
+            )
+        return RunnerResult(run_secs, None)
+
+
+T_ARG_INFO_JSON_OBJ = List[Any]  # pylint: disable=invalid-name
+T_ARG_INFO_JSON_OBJ_LIST = List[T_ARG_INFO_JSON_OBJ]  # pylint: disable=invalid-name
+T_ARGUMENT = Any  # pylint: disable=invalid-name
+T_ARGUMENT_LIST = List[T_ARGUMENT]  # pylint: disable=invalid-name
+
+
+class RPCRunner(PyRunner):
+    """RPC based runner
+
+    Parameters
+    ----------
+    rpc_config: RPCConfig
+        The rpc configuration.
+    evaluator_config: EvaluatorConfig
+        The evaluator configuration.
+    cooldown_sec: float
+        The cooldown in seconds. TODO(@junrushao1994,@zxybazh): This is not used yet.
+    alloc_repeat: int
+        The number of times to repeat the allocation.
+    f_create_session: Union[T_CREATE_SESSION, str, None]
+        The function name to create the session or the function itself.
+    f_upload_module: Union[T_UPLOAD_MODULE, str, None]
+        The function name to upload the module or the function itself.
+    f_alloc_argument: Union[T_ALLOC_ARGUMENT, str, None]
+        The function name to allocate the arguments or the function itself.
+    f_run_evaluator: Union[T_RUN_EVALUATOR, str, None]
+        The function name to run the evaluator or the function itself.
+    f_cleanup: Union[T_CLEANUP, str, None]
+        The function name to clean up the session or the function itself.
+    pool: PopenPoolExecutor
+        The popen pool executor.
+
+    Attributes
+    ----------
+    T_CREATE_SESSION : typing._GenericAlias
+        The signature of the function `f_create_session`, which is:
+
+        .. code-block:: python
+
+            def default_create_session(rpc_config: RPCConfig) -> RPCSession:
+                ...
+
+    T_UPLOAD_MODULE : typing._GenericAlias
+        The signature of the function `f_upload_module`, which is:
+
+        .. code-block:: python
+
+            def default_upload_module(
+                session: RPCSession,
+                local_path: str,
+                remote_path: str,
+            ) -> Module:
+                ...
+
+    T_ALLOC_ARGUMENT : typing._GenericAlias
+        The signature of the function `f_alloc_argument`, which is:
+
+        .. code-block:: python
+
+            def default_alloc_argument(
+                session: RPCSession,
+                device: Device,
+                args_info: T_ARG_INFO_JSON_OBJ_LIST,
+                alloc_repeat: int,
+            ) -> List[T_ARGUMENT_LIST]:
+                ...
+
+    T_RUN_EVALUATOR : typing._GenericAlias
+        The signature of the function `f_run_evaluator`, which is:
+
+        .. code-block:: python
+
+            def default_run_evaluator(
+                session: RPCSession,
+                rt_mod: Module,
+                device: Device,
+                evaluator_config: EvaluatorConfig,
+                repeated_args: List[T_ARGUMENT_LIST],
+            ) -> List[float]:
+                ...
+
+    T_CLEANUP : typing._GenericAlias
+        The signature of the function `f_cleanup`, which is:
+
+        .. code-block:: python
+
+            def default_cleanup(
+                session: Optional[RPCSession],
+                remote_path: Optional[str],
+            ) -> None:
+                ...
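+
+    Note
+    ----
+    Each `f_*` hook may also be given as the name of a function in the global
+    registry, resolved on the worker process. A sketch mirroring the tests in
+    this patch (the registered name is illustrative):
+
+    .. code-block:: python
+
+        from tvm._ffi import register_func
+
+        def initializer():
+            @register_func("meta_schedule.runner.my_create_session")
+            def my_create_session(rpc_config: RPCConfig) -> RPCSession:
+                return rpc_config.connect_server()
+
+        runner = RPCRunner(
+            f_create_session="meta_schedule.runner.my_create_session",
+            initializer=initializer,
+        )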
+    """
+
+    T_CREATE_SESSION = Callable[
+        [RPCConfig],  # The RPC configuration
+        RPCSession,  # The RPC Session
+    ]
+    T_UPLOAD_MODULE = Callable[
+        [
+            RPCSession,  # The RPC Session
+            str,  # local path to the artifact
+            str,  # remote path to the artifact
+        ],
+        Module,  # the Module opened on the remote
+    ]
+    T_ALLOC_ARGUMENT = Callable[
+        [
+            RPCSession,  # The RPC Session
+            Device,  # The device on the remote
+            T_ARG_INFO_JSON_OBJ_LIST,  # The metadata information of the arguments to be allocated
+            int,  # The number of repeated allocations to be done
+        ],
+        List[T_ARGUMENT_LIST],  # A list of argument lists
+    ]
+    T_RUN_EVALUATOR = Callable[
+        [
+            RPCSession,  # The RPC Session
+            Module,  # The Module opened on the remote
+            Device,  # The device on the remote
+            EvaluatorConfig,  # The evaluator configuration
+            List[T_ARGUMENT_LIST],  # A list of argument lists
+        ],
+        List[float],  # A list of running time
+    ]
+    T_CLEANUP = Callable[
+        [
+            Optional[RPCSession],  # The RPC Session to be cleaned up
+            Optional[str],  # remote path to the artifact
+        ],
+        None,
+    ]
+
+    rpc_config: RPCConfig
+    evaluator_config: EvaluatorConfig
+    cooldown_sec: float
+    alloc_repeat: int
+
+    f_create_session: Union[T_CREATE_SESSION, str, None]
+    f_upload_module: Union[T_UPLOAD_MODULE, str, None]
+    f_alloc_argument: Union[T_ALLOC_ARGUMENT, str, None]
+    f_run_evaluator: Union[T_RUN_EVALUATOR, str, None]
+    f_cleanup: Union[T_CLEANUP, str, None]
+
+    pool: PopenPoolExecutor
+
+    def __init__(
+        self,
+        rpc_config: Optional[RPCConfig] = None,
+        evaluator_config: Optional[EvaluatorConfig] = None,
+        cooldown_sec: float = 0.0,
+        alloc_repeat: int = 1,
+        f_create_session: Union[T_CREATE_SESSION, str, None] = None,
+        f_upload_module: Union[T_UPLOAD_MODULE, str, None] = None,
+        f_alloc_argument: Union[T_ALLOC_ARGUMENT, str, None] = None,
+        f_run_evaluator: Union[T_RUN_EVALUATOR, str, None] = None,
+        f_cleanup: Union[T_CLEANUP, str, None] = None,
+        max_connections: Optional[int] = None,
+        initializer: Optional[Callable[[], None]] = None,
+    ) -> None:
+        """Constructor
+
+        Parameters
+        ----------
+        rpc_config: RPCConfig
+            The rpc configuration.
+        evaluator_config: EvaluatorConfig
+            The evaluator configuration.
+        cooldown_sec: float
+            The cooldown in seconds.
+        alloc_repeat: int
+            The number of times to repeat the allocation, each filled with random data.
+        f_create_session: Union[T_CREATE_SESSION, str, None]
+            The function name to create the session or the function itself.
+        f_upload_module: Union[T_UPLOAD_MODULE, str, None]
+            The function name to upload the module or the function itself.
+        f_alloc_argument: Union[T_ALLOC_ARGUMENT, str, None]
+            The function name to allocate the arguments or the function itself.
+        f_run_evaluator: Union[T_RUN_EVALUATOR, str, None]
+            The function name to run the evaluator or the function itself.
+        f_cleanup: Union[T_CLEANUP, str, None]
+            The function name to clean up the session or the function itself.
+        max_connections: Optional[int]
+            The maximum number of connections.
+        initializer: Optional[Callable[[], None]]
+            The initializer function.
+        """
+        super().__init__()
+        self.rpc_config = RPCConfig._normalized(rpc_config)
+        self.evaluator_config = EvaluatorConfig._normalized(evaluator_config)
+        self.cooldown_sec = cooldown_sec
+        self.alloc_repeat = alloc_repeat
+        self.f_create_session = f_create_session
+        self.f_upload_module = f_upload_module
+        self.f_alloc_argument = f_alloc_argument
+        self.f_run_evaluator = f_run_evaluator
+        self.f_cleanup = f_cleanup
+
+        num_servers = self.rpc_config.count_num_servers(allow_missing=False)
+        if max_connections is None:
+            max_connections = num_servers
+        else:
+            max_connections = min(max_connections, num_servers)
+
+        self.pool = PopenPoolExecutor(
+            max_workers=max_connections,
+            timeout=self.rpc_config.session_timeout_sec,
+            initializer=initializer,
+        )
+        self._sanity_check()
+
+    def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
+        results: List[RunnerFuture] = []
+        for runner_input in runner_inputs:
+            future = RPCRunnerFuture(
+                future=self.pool.submit(
+                    RPCRunner._worker_func,
+                    self.f_create_session,
+                    self.f_upload_module,
+                    self.f_alloc_argument,
+                    self.f_run_evaluator,
+                    self.f_cleanup,
+                    self.rpc_config,
+                    self.evaluator_config,
+                    self.alloc_repeat,
+                    str(runner_input.artifact_path),
+                    str(runner_input.device_type),
+                    tuple(arg_info.as_json() for arg_info in runner_input.args_info),
+                ),
+                timeout_sec=self.rpc_config.session_timeout_sec,
+            )
+            results.append(future)
+        return results
+
+    def _sanity_check(self) -> None:
+        def _check(
+            f_create_session,
+            f_upload_module,
+            f_alloc_argument,
+            f_run_evaluator,
+            f_cleanup,
+        ) -> None:
+            get_global_func_with_default_on_worker(name=f_create_session, default=None)
+            get_global_func_with_default_on_worker(name=f_upload_module, default=None)
+            get_global_func_with_default_on_worker(name=f_alloc_argument, default=None)
+            get_global_func_with_default_on_worker(name=f_run_evaluator, default=None)
+            get_global_func_with_default_on_worker(name=f_cleanup, default=None)
+
+        value = self.pool.submit(
+            _check,
+            self.f_create_session,
+            self.f_upload_module,
+            self.f_alloc_argument,
+            self.f_run_evaluator,
+            self.f_cleanup,
+        )
+        value.result()
+
+    @staticmethod
+    def _worker_func(
+        _f_create_session: Union[T_CREATE_SESSION, str, None],
+        _f_upload_module: Union[T_UPLOAD_MODULE, str, None],
+        _f_alloc_argument: Union[T_ALLOC_ARGUMENT, str, None],
+        _f_run_evaluator: Union[T_RUN_EVALUATOR, str, None],
+        _f_cleanup: Union[T_CLEANUP, str, None],
+        rpc_config: RPCConfig,
+        evaluator_config: EvaluatorConfig,
+        alloc_repeat: int,
+        artifact_path: str,
+        device_type: str,
+        args_info: T_ARG_INFO_JSON_OBJ_LIST,
+    ) -> List[float]:
+        # Step 0. Get the registered functions
+        f_create_session: RPCRunner.T_CREATE_SESSION = get_global_func_with_default_on_worker(
+            _f_create_session, default_create_session
+        )
+        f_upload_module: RPCRunner.T_UPLOAD_MODULE = get_global_func_with_default_on_worker(
+            _f_upload_module, default_upload_module
+        )
+        f_alloc_argument: RPCRunner.T_ALLOC_ARGUMENT = get_global_func_with_default_on_worker(
+            _f_alloc_argument, default_alloc_argument
+        )
+        f_run_evaluator: RPCRunner.T_RUN_EVALUATOR = get_global_func_with_default_on_worker(
+            _f_run_evaluator, default_run_evaluator
+        )
+        f_cleanup: RPCRunner.T_CLEANUP = get_global_func_with_default_on_worker(
+            _f_cleanup, default_cleanup
+        )
+        # Managed resources
+        session: Optional[RPCSession] = None
+        remote_path: Optional[str] = None
+
+        @contextmanager
+        def resource_handler():
+            try:
+                yield
+            finally:
+                # Step 5. Clean up
+                f_cleanup(session, remote_path)
+
+        with resource_handler():
+            # Step 1. Create session
+            session = f_create_session(rpc_config)
+            device = session.device(dev_type=device_type, dev_id=0)
+            # Step 2. Upload the module
+            _, remote_path = osp.split(artifact_path)
+            local_path: str = artifact_path
+            rt_mod: Module = f_upload_module(session, local_path, remote_path)
+            # Step 3: Allocate input arguments
+            repeated_args: List[T_ARGUMENT_LIST] = f_alloc_argument(
+                session,
+                device,
+                args_info,
+                alloc_repeat,
+            )
+            # Step 4: Run time_evaluator
+            costs: List[float] = f_run_evaluator(
+                session,
+                rt_mod,
+                device,
+                evaluator_config,
+                repeated_args,
+            )
+        return costs
+
+
+def default_create_session(rpc_config: RPCConfig) -> RPCSession:
+    """Default function to create the session
+
+    Parameters
+    ----------
+    rpc_config : RPCConfig
+        The configuration of the RPC session
+
+    Returns
+    -------
+    session : RPCSession
+        The created rpc session
+    """
+    return rpc_config.connect_server()
+
+
+def default_upload_module(
+    session: RPCSession,
+    local_path: str,
+    remote_path: str,
+) -> Module:
+    """Default function to upload the module
+
+    Parameters
+    ----------
+    session: RPCSession
+        The session to upload the module
+    local_path: str
+        The local path of the module
+    remote_path: str
+        The remote path to place the module
+
+    Returns
+    -------
+    rt_mod : Module
+        The runtime module
+    """
+    session.upload(local_path, remote_path)
+    rt_mod: Module = session.load_module(remote_path)
+    return rt_mod
+
+
+def default_alloc_argument(
+    session: RPCSession,
+    device: Device,
+    args_info: T_ARG_INFO_JSON_OBJ_LIST,
+    alloc_repeat: int,
+) -> List[T_ARGUMENT_LIST]:
+    """Default function to allocate the arguments
+
+    Parameters
+    ----------
+    session: RPCSession
+        The session to allocate the arguments
+    device: Device
+        The device to allocate the arguments
+    args_info: T_ARG_INFO_JSON_OBJ_LIST
+        The metadata information of the arguments to be allocated
+    alloc_repeat: int
+        The number of times to repeat the allocation
+
+    Returns
+    -------
+    repeated_args: List[T_ARGUMENT_LIST]
+        The allocated arguments, repeated `alloc_repeat` times
+    """
+    f_random_fill = get_global_func_on_rpc_session(
+        session,
+        "tvm.contrib.random.random_fill",
+        "Please make sure 'USE_RANDOM' is turned ON in the config.cmake on the RPC server.",
+    )
+
+    def alloc_tensor(_, dtype, shape) -> ndarray.NDArray:
+        arg = ndarray.empty(shape=shape, dtype=dtype, device=device)
+        f_random_fill(arg)
+        return arg
+
+    def alloc_fail(*arg_info) -> None:
+        raise NotImplementedError(arg_info)
+
+    dispatcher: Dict[Any, Callable] = {
+        "TENSOR": alloc_tensor,
+        None: alloc_fail,
+    }
+
+    repeated_args: List[T_ARGUMENT_LIST] = []
+    for _ in range(alloc_repeat):
+        args: T_ARGUMENT_LIST = []
+        arg_info: T_ARG_INFO_JSON_OBJ
+        for arg_info in args_info:
+            arg_type = arg_info[0]
+            arg: Any = dispatcher.get(arg_type, alloc_fail)(*arg_info)
+            args.append(arg)
+        repeated_args.append(args)
+    return repeated_args
+
+
+def default_run_evaluator(
+    session: RPCSession,  # pylint: disable=unused-argument
+    rt_mod: Module,
+    device: Device,
+    evaluator_config: EvaluatorConfig,
+    repeated_args: List[T_ARGUMENT_LIST],
+) -> List[float]:
+    """Default function to run the evaluator
+
+    Parameters
+    ----------
+    session: RPCSession
+        The session to run the evaluator
+    rt_mod: Module
+        The runtime module
+    device: Device
+        The device to run the evaluator
+    evaluator_config: EvaluatorConfig
+        The evaluator config
+    repeated_args: List[T_ARGUMENT_LIST]
+        The repeated arguments to run the evaluator with
+
+    Returns
+    -------
+    costs: List[float]
+        The evaluator results
+    """
+    evaluator = rt_mod.time_evaluator(
+        func_name=rt_mod.entry_name,
+        dev=device,
+        number=evaluator_config.number,
+        repeat=evaluator_config.repeat,
+        min_repeat_ms=evaluator_config.min_repeat_ms,
+        f_preproc="cache_flush_cpu_non_first_arg"
+        if evaluator_config.enable_cpu_cache_flush
+        else "",
+    )
+    repeated_costs: List[List[float]] = []
+    for args in repeated_args:
+        device.sync()
+        profile_result = evaluator(*args)
+        repeated_costs.append(profile_result.results)
+    costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)]
+    return costs
+
+
+def default_cleanup(
+    session: Optional[RPCSession],
+    remote_path: Optional[str],
+) -> None:
+    """Default function to clean up the session
+
+    Parameters
+    ----------
+    session: RPCSession
+        The session to clean up
+    remote_path: str
+        The remote path to clean up
+    """
+    if session is not None and remote_path is not None:
+        session.remove(remote_path)
+        session.remove(remote_path + ".so")
+        session.remove("")
diff --git a/python/tvm/meta_schedule/runner/runner.py b/python/tvm/meta_schedule/runner/runner.py
index b756c6e6b011..9f7be8ea4af4 100644
--- a/python/tvm/meta_schedule/runner/runner.py
+++ b/python/tvm/meta_schedule/runner/runner.py
@@ -21,6 +21,50 @@
 from tvm.runtime import Object
 
 from .. import _ffi_api
+from ..arg_info import ArgInfo
+
+
+@register_object("meta_schedule.RunnerInput")
+class RunnerInput(Object):
+    """The runner's input
+
+    Parameters
+    ----------
+    artifact_path : str
+        The path to the built artifact.
+    device_type : str
+        The device type.
+    args_info : List[ArgInfo]
+        The argument information.
+    """
+
+    artifact_path: str
+    device_type: str
+    args_info: List[ArgInfo]
+
+    def __init__(
+        self,
+        artifact_path: str,
+        device_type: str,
+        args_info: List[ArgInfo],
+    ) -> None:
+        """Constructor
+
+        Parameters
+        ----------
+        artifact_path : str
+            The path to the built artifact.
+        device_type : str
+            The device type.
+        args_info : List[ArgInfo]
+            The argument information.
+        """
+        self.__init_handle_by_constructor__(
+            _ffi_api.RunnerInput,  # type: ignore # pylint: disable=no-member
+            artifact_path,
+            device_type,
+            args_info,
+        )
 
 
 @register_object("meta_schedule.RunnerResult")
@@ -57,3 +101,70 @@ def __init__(
             run_secs,
             error_msg,
         )
+
+
+@register_object("meta_schedule.RunnerFuture")
+class RunnerFuture(Object):
+    """A class to fetch asynchronous runner's output."""
+
+    def __init__(self) -> None:
+        """Constructor"""
+
+        def f_done():
+            return self.done()
+
+        def f_result():
+            return self.result()
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.RunnerFuture,  # type: ignore # pylint: disable=no-member
+            f_done,
+            f_result,
+        )
+
+    def done(self) -> bool:
+        """Check whether the runner has finished."""
+        raise NotImplementedError
+
+    def result(self) -> RunnerResult:
+        """Fetch the runner's output if it is ready."""
+        raise NotImplementedError
+
+
+@register_object("meta_schedule.Runner")
+class Runner(Object):
+    """The abstract runner interface"""
+
+    def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
+        """Run the built artifact and get runner futures.
+
+        Parameters
+        ----------
+        runner_inputs : List[RunnerInput]
+            The inputs to the runner.
+
+        Returns
+        -------
+        runner_futures: List[RunnerFuture]
+            The runner futures.
+        """
+        return _ffi_api.RunnerRun(self, runner_inputs)  # type: ignore # pylint: disable=no-member
+
+
+@register_object("meta_schedule.PyRunner")
+class PyRunner(Runner):
+    """An abstract runner with customized build method on the python-side."""
+
+    def __init__(self) -> None:
+        """Constructor"""
+
+        def f_run(runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
+            return self.run(runner_inputs)
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.RunnerPyRunner,  # type: ignore # pylint: disable=no-member
+            f_run,
+        )
+
+    def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
+        raise NotImplementedError
diff --git a/python/tvm/meta_schedule/testing.py b/python/tvm/meta_schedule/testing.py
new file mode 100644
index 000000000000..4caaeb7553cc
--- /dev/null
+++ b/python/tvm/meta_schedule/testing.py
@@ -0,0 +1,74 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Testing utilities in meta schedule"""
+import time
+
+from tvm.rpc.tracker import Tracker
+from tvm.rpc.server import Server
+
+
+class LocalRPC:
+    """A pair of RPC tracker/server running locally
+
+    Parameters
+    ----------
+    tracker_host : str
+        The host URL of the tracker
+    tracker_port : int
+        The port of the tracker
+    tracker_key: str
+        The key used in the tracker to refer to a worker
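+
+    Note
+    ----
+    Intended to be used as a context manager, as in the tests in this patch
+    (RPCConfig comes from tvm.meta_schedule.runner):
+
+    .. code-block:: python
+
+        with LocalRPC() as rpc:
+            config = RPCConfig(
+                tracker_host=rpc.tracker_host,
+                tracker_port=rpc.tracker_port,
+                tracker_key=rpc.tracker_key,
+            )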
+    """
+
+    tracker_host: str
+    tracker_port: int
+    tracker_key: str
+
+    def __init__(
+        self,
+        tracker_key: str = "key",
+        silent: bool = False,
+        no_fork: bool = False,
+    ) -> None:
+        self.tracker = Tracker(
+            silent=silent,
+            port=9190,
+            port_end=12345,
+        )
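+        # Give the tracker a moment to start before the server registers with it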
+        time.sleep(0.5)
+        self.server = Server(
+            host="0.0.0.0",
+            is_proxy=False,
+            tracker_addr=(self.tracker.host, self.tracker.port),
+            key=tracker_key,
+            silent=silent,
+            no_fork=no_fork,
+            port=9190,
+            port_end=12345,
+        )
+        self.tracker_host = self.tracker.host
+        self.tracker_port = self.tracker.port
+        self.tracker_key = tracker_key
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, _type, _value, _traceback):
+        if hasattr(self, "server"):
+            del self.server
+        if hasattr(self, "tracker"):
+            del self.tracker
diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py
index 4c83b9afa289..9c41b4d575da 100644
--- a/python/tvm/meta_schedule/tune_context.py
+++ b/python/tvm/meta_schedule/tune_context.py
@@ -19,10 +19,10 @@
 from typing import Optional, TYPE_CHECKING
 
 from tvm import IRModule
+from tvm._ffi import register_object
+from tvm.meta_schedule.utils import cpu_count
 from tvm.runtime import Object
 from tvm.target import Target
-from tvm.meta_schedule.utils import cpu_count
-from tvm._ffi import register_object
 
 from . import _ffi_api
 
diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py
index e710b0ed06f3..5f536994a9fd 100644
--- a/python/tvm/meta_schedule/utils.py
+++ b/python/tvm/meta_schedule/utils.py
@@ -18,14 +18,14 @@
 import json
 import os
 import shutil
-from typing import Any, Callable, List, Union
+from typing import Any, Callable, List, Optional, Union
 
 import psutil
-
 from tvm._ffi import get_global_func, register_func
 from tvm.error import TVMError
 from tvm.ir import Array, Map
-from tvm.runtime import String
+from tvm.rpc import RPCSession
+from tvm.runtime import PackedFunc, String
 from tvm.tir import FloatImm, IntImm
 
 
@@ -95,6 +95,37 @@ def get_global_func_with_default_on_worker(
         ) from error
 
 
+def get_global_func_on_rpc_session(
+    session: RPCSession,
+    name: str,
+    extra_error_msg: Optional[str] = None,
+) -> PackedFunc:
+    """Get a PackedFunc from the global registry from an RPCSession.
+
+    Parameters
+    ----------
+    session : RPCSession
+        The RPCSession to be retrieved from
+    name : str
+        The name of the PackedFunc
+    extra_error_msg : Optional[str]
+        Extra information to provide in the error message
+
+    Returns
+    -------
+    result : PackedFunc
+        The result
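+
+    Note
+    ----
+    For example, mirroring its use in rpc_runner.py:
+
+    .. code-block:: python
+
+        f_random_fill = get_global_func_on_rpc_session(
+            session,
+            "tvm.contrib.random.random_fill",
+            "Please make sure 'USE_RANDOM' is turned ON in the config.cmake on the RPC server.",
+        )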
+    """
+    try:
+        result = session.get_function(name)
+    except AttributeError as error:
+        error_msg = f'Unable to find function "{name}" on the remote RPC server.'
+        if extra_error_msg:
+            error_msg = f"{error_msg} {extra_error_msg}"
+        raise AttributeError(error_msg) from error
+    return result
+
+
 @register_func("meta_schedule.remove_build_dir")
 def remove_build_dir(artifact_path: str) -> None:
     """Clean up the build directory"""
diff --git a/src/meta_schedule/runner/runner.cc b/src/meta_schedule/runner/runner.cc
index 8f509bdd7b84..800a76f21e65 100644
--- a/src/meta_schedule/runner/runner.cc
+++ b/src/meta_schedule/runner/runner.cc
@@ -16,13 +16,19 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-#include <tvm/runtime/registry.h>
-
 #include "../utils.h"
 
 namespace tvm {
 namespace meta_schedule {
 
+RunnerInput::RunnerInput(String artifact_path, String device_type, Array<ArgInfo> args_info) {
+  ObjectPtr<RunnerInputNode> n = make_object<RunnerInputNode>();
+  n->artifact_path = artifact_path;
+  n->device_type = device_type;
+  n->args_info = args_info;
+  this->data_ = n;
+}
+
 RunnerResult::RunnerResult(Optional<Array<FloatImm>> run_secs, Optional<String> error_msg) {
   ObjectPtr<RunnerResultNode> n = make_object<RunnerResultNode>();
   n->run_secs = run_secs;
@@ -30,12 +36,45 @@ RunnerResult::RunnerResult(Optional<Array<FloatImm>> run_secs, Optional<String>
   this->data_ = n;
 }
 
-TVM_REGISTER_NODE_TYPE(RunnerResultNode);
+RunnerFuture::RunnerFuture(RunnerFuture::FDone f_done, RunnerFuture::FResult f_result) {
+  ObjectPtr<RunnerFutureNode> n = make_object<RunnerFutureNode>();
+  n->f_done = f_done;
+  n->f_result = f_result;
+  this->data_ = n;
+}
 
+Runner Runner::PyRunner(Runner::FRun f_run) {
+  ObjectPtr<PyRunnerNode> n = make_object<PyRunnerNode>();
+  n->f_run = f_run;
+  return Runner(n);
+}
+
+/******** FFI ********/
+
+TVM_REGISTER_NODE_TYPE(RunnerInputNode);
+TVM_REGISTER_NODE_TYPE(RunnerResultNode);
+TVM_REGISTER_NODE_TYPE(RunnerFutureNode);
+TVM_REGISTER_OBJECT_TYPE(RunnerNode);
+TVM_REGISTER_NODE_TYPE(PyRunnerNode);
+TVM_REGISTER_GLOBAL("meta_schedule.RunnerInput")
+    .set_body_typed([](String artifact_path, String device_type,
+                       Array<ArgInfo> args_info) -> RunnerInput {
+      return RunnerInput(artifact_path, device_type, args_info);
+    });
 TVM_REGISTER_GLOBAL("meta_schedule.RunnerResult")
     .set_body_typed([](Array<FloatImm> run_secs, Optional<String> error_msg) -> RunnerResult {
       return RunnerResult(run_secs, error_msg);
     });
+TVM_REGISTER_GLOBAL("meta_schedule.RunnerFuture")
+    .set_body_typed([](RunnerFuture::FDone f_done, RunnerFuture::FResult f_result) -> RunnerFuture {
+      return RunnerFuture(f_done, f_result);
+    });
+TVM_REGISTER_GLOBAL("meta_schedule.RunnerFutureDone")
+    .set_body_method<RunnerFuture>(&RunnerFutureNode::Done);
+TVM_REGISTER_GLOBAL("meta_schedule.RunnerFutureResult")
+    .set_body_method<RunnerFuture>(&RunnerFutureNode::Result);
+TVM_REGISTER_GLOBAL("meta_schedule.RunnerRun").set_body_method<Runner>(&RunnerNode::Run);
+TVM_REGISTER_GLOBAL("meta_schedule.RunnerPyRunner").set_body_typed(Runner::PyRunner);
 
 }  // namespace meta_schedule
 }  // namespace tvm
diff --git a/tests/python/unittest/test_meta_schedule_runner.py b/tests/python/unittest/test_meta_schedule_runner.py
new file mode 100644
index 000000000000..3c8aee0c6d58
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_runner.py
@@ -0,0 +1,571 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Test Meta Schedule Runner """
+
+import itertools
+import sys
+import time
+from typing import Any, List
+
+import numpy as np
+import pytest
+
+import tvm
+from tvm import tir
+from tvm._ffi import register_func
+from tvm.meta_schedule.arg_info import TensorInfo
+from tvm.meta_schedule.builder import BuilderInput, LocalBuilder
+from tvm.meta_schedule.runner import (
+    EvaluatorConfig,
+    PyRunner,
+    RPCConfig,
+    RPCRunner,
+    RunnerFuture,
+    RunnerInput,
+)
+from tvm.meta_schedule.runner.rpc_runner import (
+    default_alloc_argument as rpc_default_alloc_argument,
+)
+from tvm.meta_schedule.testing import LocalRPC
+from tvm.meta_schedule.utils import get_global_func_with_default_on_worker
+from tvm.rpc import RPCSession
+from tvm.runtime import Device, Module
+from tvm.script import ty
+from tvm.target import Target
+import tvm.testing
+from tvm.tir import FloatImm
+
+MATMUL_N = 16
+MATMUL_M = 32
+
+# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring,unbalanced-tuple-unpacking
+
+
+@tvm.script.tir
+class MatmulModule:
+    def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None:  # pylint: disable=no-self-argument
+        tir.func_attr({"global_symbol": "main", "tir.noalias": True})
+        A = tir.match_buffer(a, (16, 16), "float32")
+        B = tir.match_buffer(b, (16, 16), "float32")
+        C = tir.match_buffer(c, (16, 16), "float32")
+        with tir.block([16, 16, tir.reduce_axis(0, 16)], "matmul") as [vi, vj, vk]:
+            with tir.init():
+                C[vi, vj] = 0.0
+            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+
+@tvm.script.tir
+class MatmulReluModule:
+    def main(a: ty.handle, b: ty.handle, d: ty.handle) -> None:  # pylint: disable=no-self-argument
+        tir.func_attr({"global_symbol": "main", "tir.noalias": True})
+        A = tir.match_buffer(a, (16, 16), "float32")
+        B = tir.match_buffer(b, (16, 16), "float32")
+        D = tir.match_buffer(d, (16, 16), "float32")
+        C = tir.alloc_buffer((16, 16), "float32")
+        with tir.block([16, 16, tir.reduce_axis(0, 16)], "matmul") as [vi, vj, vk]:
+            with tir.init():
+                C[vi, vj] = 0.0
+            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+        with tir.block([16, 16], "relu") as [vi, vj]:
+            D[vi, vj] = tir.max(C[vi, vj], 0.0)
+
+
+@tvm.script.tir
+class BatchMatmulModule:
+    def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None:  # pylint: disable=no-self-argument
+        tir.func_attr({"global_symbol": "main", "tir.noalias": True})
+        A = tir.match_buffer(a, [16, 32, 32])
+        B = tir.match_buffer(b, [16, 32, 32])
+        C = tir.match_buffer(c, [16, 32, 32])
+        with tir.block([16, 32, 32, tir.reduce_axis(0, 32)], "update") as [vn, vi, vj, vk]:
+            with tir.init():
+                C[vn, vi, vj] = 0.0
+            C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk]
+
+
+@tvm.script.tir
+class AddModule:
+    def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None:  # pylint: disable=no-self-argument
+        tir.func_attr({"global_symbol": "main", "tir.noalias": True})
+        A = tir.match_buffer(a, [32], "float32")
+        B = tir.match_buffer(b, [32], "float32")
+        C = tir.match_buffer(c, [32], "float32")
+        with tir.block([32], "add") as [vi]:
+            C[vi] = A[vi] + B[vi]
+
+
+# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring
+
+
+def _clean_build(artifact_path: str) -> None:
+    f_clean_build = get_global_func_with_default_on_worker("meta_schedule.remove_build_dir", None)
+    if f_clean_build is not None:
+        f_clean_build(artifact_path)
+    else:
+        raise RuntimeError("Unable to find remove_build_dir function.")
+
+
+def test_meta_schedule_rpc_single_run():
+    """Test meta schedule rpc runner for a single run"""
+    # Build the module
+    mod = MatmulModule()
+    builder = LocalBuilder()
+    (builder_result,) = builder.build([BuilderInput(mod, Target("llvm"))])
+    assert builder_result.artifact_path is not None
+    assert builder_result.error_msg is None
+
+    runner_input = RunnerInput(
+        builder_result.artifact_path,
+        "llvm",
+        [
+            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
+            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
+            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
+        ],
+    )
+
+    with LocalRPC() as rpc:
+        rpc_config = RPCConfig(
+            tracker_host=rpc.tracker_host,
+            tracker_port=rpc.tracker_port,
+            tracker_key=rpc.tracker_key,
+            session_priority=1,
+            session_timeout_sec=100,
+        )
+        evaluator_config = EvaluatorConfig(
+            number=1,
+            repeat=1,
+            min_repeat_ms=0,
+            enable_cpu_cache_flush=False,
+        )
+        runner = RPCRunner(rpc_config, evaluator_config)
+        # Run the module
+        (runner_future,) = runner.run([runner_input])
+        runner_result = runner_future.result()
+    assert runner_result.error_msg is None
+    for result in runner_result.run_secs:
+        if isinstance(result, FloatImm):
+            result = result.value
+        assert isinstance(result, float)
+        assert result >= 0.0
+    _clean_build(builder_result.artifact_path)
+
+
+def test_meta_schedule_rpc_multiple_runs():
+    """Test meta schedule rpc runner for multiple runs"""
+    # Build the module
+    mods = [
+        MatmulModule(),
+        MatmulReluModule(),
+        BatchMatmulModule(),
+    ]
+    builder = LocalBuilder()
+    builder_inputs = [BuilderInput(mod, Target("llvm")) for mod in mods]
+    builder_results = builder.build(builder_inputs)
+    for builder_result in builder_results:
+        assert builder_result.artifact_path is not None
+        assert builder_result.error_msg is None
+
+    args_infos = [
+        [
+            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
+            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
+            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
+        ],
+        [
+            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
+            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
+            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
+        ],
+        [
+            TensorInfo("float32", [16, MATMUL_M, MATMUL_M]),
+            TensorInfo("float32", [16, MATMUL_M, MATMUL_M]),
+            TensorInfo("float32", [16, MATMUL_M, MATMUL_M]),
+        ],
+    ]
+
+    runner_inputs = [
+        RunnerInput(builder_result.artifact_path, "llvm", args_info)
+        for builder_result, args_info in zip(builder_results, args_infos)
+    ]
+
+    with LocalRPC() as rpc:
+        rpc_config = RPCConfig(
+            tracker_host=rpc.tracker_host,
+            tracker_port=rpc.tracker_port,
+            tracker_key=rpc.tracker_key,
+            session_priority=1,
+            session_timeout_sec=100,
+        )
+        evaluator_config = EvaluatorConfig(
+            number=1,
+            repeat=1,
+            min_repeat_ms=0,
+            enable_cpu_cache_flush=False,
+        )
+        runner = RPCRunner(rpc_config, evaluator_config)
+        # Run the modules
+        runner_futures = runner.run(runner_inputs)
+        runner_results = [runner_future.result() for runner_future in runner_futures]
+
+    for runner_result in runner_results:
+        assert runner_result.error_msg is None
+        for result in runner_result.run_secs:
+            if isinstance(result, FloatImm):
+                result = result.value
+            assert isinstance(result, float)
+            assert result >= 0.0
+
+    for builder_result in builder_results:
+        _clean_build(builder_result.artifact_path)
+
+
+def test_meta_schedule_py_runner():
+    """Test meta schedule PyRunner"""
+
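+    # A PyRunner subclass implements the runner interface in Python;
+    # raising in run() proves the Python override is dispatched.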
+    class TestRunner(PyRunner):
+        def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
+            raise ValueError("TestRunner")
+
+    runner = TestRunner()
+    with pytest.raises(ValueError, match="TestRunner"):
+        runner.run([])
+
+
+def test_meta_schedule_rpc_runner_time_out():
+    """Test meta schedule RPC Runner time out"""
+
+    def initializer():
+        @register_func("meta_schedule.runner.test_time_out")
+        def timeout_session_creator(  # pylint: disable=unused-variable
+            rpc_config: RPCConfig,  # pylint: disable=unused-argument
+        ) -> RPCSession:
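+            # Sleep longer than the 1-second session timeout configured
+            # below, forcing the runner to time out.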
+            time.sleep(2)
+
+    runner_input = RunnerInput(
+        "test",
+        "llvm",
+        [
+            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
+            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
+            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
+        ],
+    )
+
+    with LocalRPC() as rpc:
+        rpc_config = RPCConfig(
+            tracker_host=rpc.tracker_host,
+            tracker_port=rpc.tracker_port,
+            tracker_key=rpc.tracker_key,
+            session_priority=1,
+            session_timeout_sec=1,
+        )
+        evaluator_config = EvaluatorConfig(
+            number=1,
+            repeat=1,
+            min_repeat_ms=0,
+            enable_cpu_cache_flush=False,
+        )
+        runner = RPCRunner(
+            rpc_config,
+            evaluator_config,
+            initializer=initializer,
+            f_create_session="meta_schedule.runner.test_time_out",
+        )
+        # Run the module
+        (runner_future,) = runner.run([runner_input])
+        runner_result = runner_future.result()
+
+    assert runner_result.error_msg is not None and runner_result.error_msg.startswith(
+        "RPCRunner: Timeout, killed after"
+    )
+    assert runner_result.run_secs is None
+
+
+def test_meta_schedule_rpc_runner_exception():
+    """Test meta schedule RPC Runner exception"""
+
+    def initializer():
+        @register_func("meta_schedule.runner.test_exception")
+        def exception_session_creator(  # pylint: disable=unused-variable
+            rpc_config: RPCConfig,  # pylint: disable=unused-argument
+        ) -> RPCSession:
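+            # Fail during session creation so the error surfaces in
+            # RunnerResult.error_msg.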
+            raise Exception("Test")
+
+    runner_input = RunnerInput(
+        "test",
+        "llvm",
+        [
+            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
+            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
+            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
+        ],
+    )
+
+    with LocalRPC() as rpc:
+        rpc_config = RPCConfig(
+            tracker_host=rpc.tracker_host,
+            tracker_port=rpc.tracker_port,
+            tracker_key=rpc.tracker_key,
+            session_priority=1,
+            session_timeout_sec=100,
+        )
+        evaluator_config = EvaluatorConfig(
+            number=1,
+            repeat=1,
+            min_repeat_ms=0,
+            enable_cpu_cache_flush=False,
+        )
+        runner = RPCRunner(
+            rpc_config,
+            evaluator_config,
+            initializer=initializer,
+            f_create_session="meta_schedule.runner.test_exception",
+        )
+        (runner_future,) = runner.run([runner_input])
+        runner_result = runner_future.result()
+
+    assert runner_result.error_msg is not None and runner_result.error_msg.startswith(
+        "RPCRunner: An exception occurred\n"
+    )
+    assert runner_result.run_secs is None
+
+
+def test_meta_schedule_runner_matmul_test():
+    """Test meta schedule runner with add module"""
+
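+    # Swap in custom f_alloc_argument / f_run_evaluator hooks to capture the
+    # arguments before and after execution and verify numerical correctness.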
+    def _check_correct_matmul(
+        args_before: List[np.ndarray],
+        args_after: List[np.ndarray],
+    ) -> None:
+        a_before, b_before, _ = args_before  # the output buffer's initial contents are irrelevant
+        a_after, b_after, c_after = args_after
+        # Inputs must be unchanged, and the output must match a NumPy reference matmul.
+        c_expected = np.matmul(a_before, b_before)
+        assert (a_before == a_after).all()
+        assert (b_before == b_after).all()
+        tvm.testing.assert_allclose(c_expected, c_after, rtol=1e-5)
+
+    def test_alloc_argument(
+        session: RPCSession,
+        device: Device,
+        args_info: Any,
+        alloc_repeat: int,
+    ) -> List[Any]:
+        global repeated_args_before  # pylint: disable=global-variable-undefined, invalid-name
+        repeated_args_before = []  # type: ignore
+        repeated_args = rpc_default_alloc_argument(session, device, args_info, alloc_repeat)
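+        # Snapshot the freshly allocated arguments so the evaluator can
+        # verify correctness after the run.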
+        for args in repeated_args:
+            repeated_args_before.append([arg.numpy() for arg in args])  # type: ignore
+        return repeated_args
+
+    def test_run_evaluator(
+        session: RPCSession,  # pylint: disable=unused-argument
+        rt_mod: Module,
+        device: Device,
+        evaluator_config: EvaluatorConfig,
+        repeated_args: List[Any],
+    ) -> List[float]:
+        global repeated_args_before  # pylint: disable=global-variable-undefined, invalid-name
+        repeated_args_after = []
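+        # Build a remote time evaluator; flush the CPU cache before each
+        # measurement only when requested.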
+        evaluator = rt_mod.time_evaluator(
+            func_name=rt_mod.entry_name,
+            dev=device,
+            number=evaluator_config.number,
+            repeat=evaluator_config.repeat,
+            min_repeat_ms=evaluator_config.min_repeat_ms,
+            f_preproc="cache_flush_cpu_non_first_arg"
+            if evaluator_config.enable_cpu_cache_flush
+            else "",
+        )
+        repeated_costs: List[List[float]] = []
+        for args in repeated_args:
+            device.sync()
+            profile_result = evaluator(*args)
+            repeated_costs.append(profile_result.results)
+            repeated_args_after.append([arg.numpy() for arg in args])
+        costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)]
+        for args_before, args_after in zip(
+            repeated_args_before,  # type: ignore
+            repeated_args_after,
+        ):
+            _check_correct_matmul(args_before, args_after)
+        del repeated_args_before  # type: ignore
+        return costs
+
+    # Build the module
+    mod = MatmulModule()
+    builder = LocalBuilder()
+    (builder_result,) = builder.build([BuilderInput(mod, Target("llvm"))])
+    assert builder_result.artifact_path is not None
+    assert builder_result.error_msg is None
+
+    runner_input = RunnerInput(
+        builder_result.artifact_path,
+        "llvm",
+        [
+            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
+            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
+            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
+        ],
+    )
+
+    with LocalRPC() as rpc:
+        rpc_config = RPCConfig(
+            tracker_host=rpc.tracker_host,
+            tracker_port=rpc.tracker_port,
+            tracker_key=rpc.tracker_key,
+            session_priority=1,
+            session_timeout_sec=100,
+        )
+        evaluator_config = EvaluatorConfig(
+            number=1,
+            repeat=1,
+            min_repeat_ms=0,
+            enable_cpu_cache_flush=False,
+        )
+        runner = RPCRunner(
+            rpc_config,
+            evaluator_config,
+            f_alloc_argument=test_alloc_argument,
+            f_run_evaluator=test_run_evaluator,
+        )
+        # Run the module
+        (runner_future,) = runner.run([runner_input])
+        runner_result = runner_future.result()
+    assert runner_result.error_msg is None
+    for result in runner_result.run_secs:
+        if isinstance(result, FloatImm):
+            result = result.value
+        assert isinstance(result, float)
+        assert result >= 0.0
+    _clean_build(builder_result.artifact_path)
+
+
+def test_meta_schedule_runner_add_test():
+    """Test meta schedule runner with add module"""
+
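+    # Same customized-hook strategy as the matmul test above, applied to
+    # 1-D vector addition.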
+    def _check_correct_add(args_before: List[np.ndarray], args_after: List[np.ndarray]) -> None:
+        a_before, b_before, _ = args_before  # the output buffer's initial contents are irrelevant
+        a_after, b_after, c_after = args_after
+        # Inputs must be unchanged, and the output must match a NumPy reference add.
+        c_expected = a_before + b_before
+        assert (a_before == a_after).all()
+        assert (b_before == b_after).all()
+        assert (c_expected == c_after).all()
+
+    def test_alloc_argument(
+        session: RPCSession,
+        device: Device,
+        args_info: Any,
+        alloc_repeat: int,
+    ) -> List[Any]:
+        global repeated_args_before  # pylint: disable=global-variable-undefined, invalid-name
+        repeated_args_before = []  # type: ignore
+        repeated_args = rpc_default_alloc_argument(
+            session,
+            device,
+            args_info,
+            alloc_repeat,
+        )
+        for args in repeated_args:
+            repeated_args_before.append([arg.numpy() for arg in args])  # type: ignore
+        return repeated_args
+
+    def test_run_evaluator(
+        session: RPCSession,  # pylint: disable=unused-argument
+        rt_mod: Module,
+        device: Device,
+        evaluator_config: EvaluatorConfig,
+        repeated_args: List[Any],
+    ) -> List[float]:
+        global repeated_args_before  # pylint: disable=global-variable-undefined, invalid-name
+        repeated_args_after = []
+        evaluator = rt_mod.time_evaluator(
+            func_name=rt_mod.entry_name,
+            dev=device,
+            number=evaluator_config.number,
+            repeat=evaluator_config.repeat,
+            min_repeat_ms=evaluator_config.min_repeat_ms,
+            f_preproc="cache_flush_cpu_non_first_arg"
+            if evaluator_config.enable_cpu_cache_flush
+            else "",
+        )
+        repeated_costs: List[List[float]] = []
+        for args in repeated_args:
+            device.sync()
+            profile_result = evaluator(*args)
+            repeated_costs.append(profile_result.results)
+            repeated_args_after.append([arg.numpy() for arg in args])
+        costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)]
+        for args_before, args_after in zip(
+            repeated_args_before,  # type: ignore
+            repeated_args_after,
+        ):
+            _check_correct_add(args_before, args_after)
+        del repeated_args_before  # type: ignore
+        return costs
+
+    # Build the module
+    mod = AddModule()
+    builder = LocalBuilder()
+    (builder_result,) = builder.build([BuilderInput(mod, Target("llvm"))])
+    assert builder_result.artifact_path is not None
+    assert builder_result.error_msg is None
+
+    runner_input = RunnerInput(
+        builder_result.artifact_path,
+        "llvm",
+        [
+            TensorInfo("float32", [MATMUL_M]),
+            TensorInfo("float32", [MATMUL_M]),
+            TensorInfo("float32", [MATMUL_M]),
+        ],
+    )
+
+    with LocalRPC() as rpc:
+        rpc_config = RPCConfig(
+            tracker_host=rpc.tracker_host,
+            tracker_port=rpc.tracker_port,
+            tracker_key=rpc.tracker_key,
+            session_priority=1,
+            session_timeout_sec=100,
+        )
+        evaluator_config = EvaluatorConfig(
+            number=1,
+            repeat=1,
+            min_repeat_ms=0,
+            enable_cpu_cache_flush=False,
+        )
+        runner = RPCRunner(
+            rpc_config,
+            evaluator_config,
+            f_alloc_argument=test_alloc_argument,
+            f_run_evaluator=test_run_evaluator,
+        )
+        # Run the module
+        (runner_future,) = runner.run([runner_input])
+        runner_result = runner_future.result()
+    assert runner_result.error_msg is None
+    for result in runner_result.run_secs:
+        if isinstance(result, FloatImm):
+            result = result.value
+        assert isinstance(result, float)
+        assert result >= 0.0
+    _clean_build(builder_result.artifact_path)
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))