Skip to content

Commit

Permalink
[MetaSchedule] Update scripts for subgraph tuning (apache#10501)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao authored and pfk-beta committed Apr 11, 2022
1 parent b2e6311 commit 284bcd7
Show file tree
Hide file tree
Showing 10 changed files with 310 additions and 27 deletions.
10 changes: 7 additions & 3 deletions python/tvm/auto_scheduler/workload_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,14 @@
When we need the dag, we decode the string and call the function, which will return the dag.
"""

import json
import logging
import pickle
import json

import tvm._ffi
from tvm.runtime._ffi_node_api import LoadJSON, SaveJSON
from .utils import serialize_args, deserialize_args, get_func_name

from .utils import deserialize_args, get_func_name, serialize_args

logger = logging.getLogger("auto_scheduler")

Expand Down Expand Up @@ -194,7 +195,10 @@ def workload_key_to_tensors(workload_key):
assert callable(value)

args = deserialize_args(workload[1:])
return value(*args)
result = value(*args)
if isinstance(result, tuple):
result = list(result)
return result


def serialize_workload_registry_entry(workload_key):
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/runner/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class EvaluatorConfig(NamedTuple):

number: int = 3
repeat: int = 1
min_repeat_ms: int = 40
min_repeat_ms: int = 100
enable_cpu_cache_flush: bool = False

@staticmethod
Expand Down
9 changes: 4 additions & 5 deletions python/tvm/meta_schedule/runner/rpc_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
# under the License.
"""RPC Runner"""
import concurrent.futures
from contextlib import contextmanager
import logging
import os.path as osp
from contextlib import contextmanager
from typing import Callable, List, Optional, Union

from tvm.contrib.popen_pool import PopenPoolExecutor
Expand All @@ -31,15 +31,14 @@
get_global_func_with_default_on_worker,
)
from .config import EvaluatorConfig, RPCConfig
from .runner import PyRunner, RunnerFuture, PyRunnerFuture, RunnerInput, RunnerResult
from .runner import PyRunner, PyRunnerFuture, RunnerFuture, RunnerInput, RunnerResult
from .utils import (
T_ARGUMENT_LIST,
T_ARG_INFO_JSON_OBJ_LIST,
T_ARGUMENT_LIST,
alloc_argument_common,
run_evaluator_common,
)


logger = logging.getLogger(__name__) # pylint: disable=invalid-name


Expand Down Expand Up @@ -118,7 +117,7 @@ def done(self) -> bool:
def result(self) -> RunnerResult:
try:
run_secs: List[float] = self.future.result()
except TimeoutError as exception:
except TimeoutError:
return RunnerResult(
None,
error_msg=f"RPCRunner: Timeout, killed after {self.timeout_sec} seconds",
Expand Down
137 changes: 137 additions & 0 deletions python/tvm/meta_schedule/testing/run_subgraph_auto_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring
import argparse
import os

import tvm
from tvm import auto_scheduler
from tvm.meta_schedule.runner import RPCConfig
from tvm.meta_schedule.testing.te_workload import CONFIGS


def _parse_args():
    """Parse the command-line flags and attach derived settings.

    All flags are mandatory. After parsing, the namespace is augmented in
    place: ``target`` is replaced by a ``tvm.target.Target`` built from the
    given string, and ``rpc_workers`` is set to the value returned by
    ``RPCConfig.count_num_servers(allow_missing=True)`` for the given tracker.

    Returns
    -------
    parsed : argparse.Namespace
        The parsed flags plus the derived ``target`` and ``rpc_workers``.
    """
    # (flag, type) pairs, registered in this order; every one is required.
    required_flags = (
        ("--workload", str),
        ("--target", str),
        ("--num-trials", int),
        ("--rpc-host", str),
        ("--rpc-port", int),
        ("--rpc-key", str),
        ("--log-dir", str),
    )
    parser = argparse.ArgumentParser()
    for flag, flag_type in required_flags:
        parser.add_argument(flag, type=flag_type, required=True)

    parsed = parser.parse_args()
    # Swap the raw target string for a Target object.
    parsed.target = tvm.target.Target(parsed.target)
    rpc_config = RPCConfig(
        tracker_host=parsed.rpc_host,
        tracker_port=parsed.rpc_port,
        tracker_key=parsed.rpc_key,
        session_timeout_sec=30,
    )
    parsed.rpc_workers = rpc_config.count_num_servers(allow_missing=True)
    return parsed


# Parsed once when the module is loaded; `main` reads its settings from here.
ARGS = _parse_args()


def main():
    """Tune one TE workload with auto_scheduler over RPC and print the result.

    The workload function and its first parameter set are looked up in
    ``CONFIGS`` by ``ARGS.workload``. Measurement records are written to
    ``<log-dir>/<workload>.json``; after tuning, the best record from that
    file is printed and applied, and the lowered TIR is printed.

    Raises
    ------
    NotImplementedError
        If the target kind is neither ``llvm`` nor ``cuda``.
    """
    # Make sure the log directory exists before tuning starts, so the record
    # file can be created. An empty --log-dir means the current directory,
    # which os.makedirs would reject, so it is skipped.
    # NOTE(review): RecordToFile is not shown here; it is assumed not to
    # create missing parent directories itself -- confirm.
    if ARGS.log_dir:
        os.makedirs(ARGS.log_dir, exist_ok=True)
    log_file = os.path.join(ARGS.log_dir, f"{ARGS.workload}.json")

    workload_func, params = CONFIGS[ARGS.workload]
    params = params[0]  # type: ignore
    workload_func = auto_scheduler.register_workload(workload_func)

    if ARGS.target.kind.name == "llvm":
        # The core count is read from the target's "num-cores" attribute.
        hardware_params = auto_scheduler.HardwareParams(
            num_cores=int(ARGS.target.attrs["num-cores"]),
            target=ARGS.target,
        )
    elif ARGS.target.kind.name == "cuda":
        # Block limits come from the target attributes; the rest are fixed.
        hardware_params = auto_scheduler.HardwareParams(
            num_cores=-1,
            vector_unit_bytes=16,
            cache_line_bytes=64,
            max_shared_memory_per_block=int(ARGS.target.attrs["max_shared_memory_per_block"]),
            max_threads_per_block=int(ARGS.target.attrs["max_threads_per_block"]),
            max_vthread_extent=8,
            warp_size=32,
        )
    else:
        raise NotImplementedError(f"Unsupported target {ARGS.target}")
    task = auto_scheduler.SearchTask(
        func=workload_func,
        args=params,
        target=ARGS.target,
        hardware_params=hardware_params,
    )
    runner = auto_scheduler.RPCRunner(
        key=ARGS.rpc_key,
        host=ARGS.rpc_host,
        port=ARGS.rpc_port,
        n_parallel=ARGS.rpc_workers,
        number=3,
        repeat=1,
        min_repeat_ms=100,
        enable_cpu_cache_flush=False,
    )

    # Inspect the computational graph
    print("Computational DAG:")
    print(task.compute_dag)
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=ARGS.num_trials,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        verbose=2,
        runner=runner,
    )
    print("Running AutoTuning:")
    task.tune(tune_option)
    print("History Best:")
    print(task.print_best(log_file))
    sch, args = task.apply_best(log_file)
    print("Lowered TIR:")
    print(tvm.lower(sch, args, simple_mode=True))


# Run the tuning flow only when this file is executed as a script.
if __name__ == "__main__":
    main()
120 changes: 120 additions & 0 deletions python/tvm/meta_schedule/testing/run_subgraph_meta_schedule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring
import argparse
import logging
from os import cpu_count
from typing import Optional

import tvm
from tvm import meta_schedule as ms
from tvm import tir
from tvm.meta_schedule.testing.te_workload import create_te_workload


def _parse_args():
    """Parse the command-line flags and attach derived settings.

    All flags are mandatory. After parsing, the namespace is augmented in
    place:

    - ``target`` is replaced by a ``tvm.target.Target`` built from the string;
    - ``alloc_repeat`` is 3 when the target's ``mtriple`` attribute equals
      ``"aarch64-linux-gnu"`` and 1 otherwise;
    - ``rpc_config`` holds the ``RPCConfig`` built from the tracker flags;
    - ``rpc_workers`` is the value returned by
      ``rpc_config.count_num_servers(allow_missing=False)``.

    Returns
    -------
    parsed : argparse.Namespace
        The parsed flags plus the derived attributes listed above.
    """
    # (flag, type) pairs, registered in this order; every one is required.
    required_flags = (
        ("--workload", str),
        ("--target", str),
        ("--num-trials", int),
        ("--work-dir", str),
        ("--rpc-host", str),
        ("--rpc-port", int),
        ("--rpc-key", str),
    )
    parser = argparse.ArgumentParser()
    for flag, flag_type in required_flags:
        parser.add_argument(flag, type=flag_type, required=True)

    parsed = parser.parse_args()
    parsed.target = tvm.target.Target(parsed.target)
    is_aarch64_linux = parsed.target.attrs.get("mtriple", None) == "aarch64-linux-gnu"
    parsed.alloc_repeat = 3 if is_aarch64_linux else 1
    parsed.rpc_config = ms.runner.RPCConfig(
        tracker_host=parsed.rpc_host,
        tracker_port=parsed.rpc_port,
        tracker_key=parsed.rpc_key,
        session_timeout_sec=30,
    )
    parsed.rpc_workers = parsed.rpc_config.count_num_servers(allow_missing=False)
    return parsed


# Module-level setup, executed when the module is loaded: install the default
# logging configuration, raise the "tvm.meta_schedule" logger to DEBUG, and
# parse the command line once so `main` can read its settings from ARGS.
logging.basicConfig()
logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
ARGS = _parse_args()


def main():
    """Tune one TE workload with meta schedule over RPC and print the result.

    Builds an RPC runner from the parsed flags, calls ``ms.tune_tir`` on the
    module produced by ``create_te_workload(ARGS.workload, 0)``, and prints
    the tuned module's script and trace. If ``tune_tir`` returns ``None``, a
    notice is printed instead.
    """
    evaluator_config = ms.runner.EvaluatorConfig(
        number=3,
        repeat=1,
        min_repeat_ms=100,
        enable_cpu_cache_flush=False,
    )
    rpc_runner = ms.runner.RPCRunner(
        rpc_config=ARGS.rpc_config,
        evaluator_config=evaluator_config,
        alloc_repeat=ARGS.alloc_repeat,
        max_workers=ARGS.rpc_workers,
    )

    workload_mod = create_te_workload(ARGS.workload, 0)
    search_config = ms.EvolutionarySearchConfig(
        num_trials_per_iter=64,
        num_trials_total=ARGS.num_trials,
        init_min_unmeasured=50,
    )
    sch: Optional[tir.Schedule] = ms.tune_tir(
        mod=workload_mod,
        target=ARGS.target,
        config=search_config,
        runner=rpc_runner,  # type: ignore
        task_name=ARGS.workload,
        work_dir=ARGS.work_dir,
        num_threads=cpu_count(),
    )

    # Nothing to show when tuning produced no schedule.
    if sch is None:
        print("No valid schedule found!")
        return
    print(sch.mod.script())
    print(sch.trace)


# Run the tuning flow only when this file is executed as a script.
if __name__ == "__main__":
    main()
3 changes: 1 addition & 2 deletions src/meta_schedule/postproc/verify_gpu_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,7 @@ class VerifyGPUCodeNode : public PostprocNode {
ICHECK(context->target.defined());
Target target = context->target.value();
this->target_constraints_ = Map<String, PrimExpr>{
{"max_shared_memory_per_block", Extract(target, "shared_memory_per_block")},
{"max_local_memory_per_block", Extract(target, "registers_per_block")},
{"max_shared_memory_per_block", Extract(target, "max_shared_memory_per_block")},
{"max_threads_per_block", Extract(target, "max_threads_per_block")},
{"max_vthread", Integer(8)},
{"max_vector_bytes", Integer(16)}};
Expand Down
29 changes: 26 additions & 3 deletions src/target/tag.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,38 @@ Target TargetTag::AddTag(String name, Map<String, ObjectRef> config, bool overri

/********** Register Target tags **********/

// Target tag "raspberry-pi/4b-aarch64": an llvm target for aarch64-linux-gnu
// with mcpu cortex-a72, +neon and 4 cores. The nested "host" entry repeats
// the same llvm configuration.
TVM_REGISTER_TARGET_TAG("raspberry-pi/4b-aarch64")
.set_config({{"kind", String("llvm")},
{"mtriple", String("aarch64-linux-gnu")},
{"mcpu", String("cortex-a72")},
{"mattr", Array<String>{"+neon"}},
{"num-cores", Integer(4)},
{"host", Map<String, ObjectRef>{{"kind", String("llvm")},
{"mtriple", String("aarch64-linux-gnu")},
{"mcpu", String("cortex-a72")},
{"mattr", Array<String>{"+neon"}},
{"num-cores", Integer(4)}}}});

// Target tag "nvidia/jetson-agx-xavier": a cuda target (arch sm_72) with its
// per-block limits spelled out, paired with an llvm host for
// aarch64-linux-gnu (mcpu carmel, 4 cores).
TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-xavier")
.set_config({{"kind", String("cuda")},
{"arch", String("sm_72")},
{"max_shared_memory_per_block", Integer(49152)},
{"max_threads_per_block", Integer(1024)},
{"thread_warp_size", Integer(32)},
{"registers_per_block", Integer(65536)},
{"host", Map<String, ObjectRef>{{"kind", String("llvm")},
{"mtriple", String("aarch64-linux-gnu")},
{"mcpu", String("carmel")},
{"num-cores", Integer(4)}}}});

#define TVM_REGISTER_CUDA_TAG(Name, Arch, SharedMem, RegPerBlock) \
TVM_REGISTER_TARGET_TAG(Name).set_config({ \
{"kind", String("cuda")}, \
{"arch", String(Arch)}, \
{"shared_memory_per_block", Integer(SharedMem)}, \
{"registers_per_block", Integer(RegPerBlock)}, \
{"max_shared_memory_per_block", Integer(SharedMem)}, \
{"max_threads_per_block", Integer(1024)}, \
{"thread_warp_size", Integer(32)}, \
{"registers_per_block", Integer(RegPerBlock)}, \
});

TVM_REGISTER_CUDA_TAG("nvidia/tesla-k80", "sm_37", 49152, 65536);
Expand Down Expand Up @@ -318,7 +342,6 @@ TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-415m", "sm_21", 49152, 32768);
TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-480m", "sm_20", 49152, 32768);
TVM_REGISTER_CUDA_TAG("nvidia/geforce-710m", "sm_21", 49152, 32768);
TVM_REGISTER_CUDA_TAG("nvidia/geforce-410m", "sm_21", 49152, 32768);
TVM_REGISTER_CUDA_TAG("nvidia/jetson-agx-xavier", "sm_72", 49152, 65536);
TVM_REGISTER_CUDA_TAG("nvidia/jetson-nano", "sm_53", 49152, 32768);
TVM_REGISTER_CUDA_TAG("nvidia/jetson-tx2", "sm_62", 49152, 32768);
TVM_REGISTER_CUDA_TAG("nvidia/jetson-tx1", "sm_53", 49152, 32768);
Expand Down
Loading

0 comments on commit 284bcd7

Please sign in to comment.