Skip to content

Commit

Permalink
[AutoScheduler] Support layout rewrite for whole networks (apache#6987)
Browse files Browse the repository at this point in the history
* [AutoScheduler] Add layout rewrite pass in relay

* fix

* fix lint

* fix attrs

* trigger CI

* Apply suggestions from code review

* trigger CI

* Update python/tvm/auto_scheduler/relay_integration.py

* Update python/tvm/auto_scheduler/relay_integration.py

* Update python/tvm/auto_scheduler/compute_dag.py

* Trigger CI

* Apply suggestions from code review
  • Loading branch information
merrymercy authored and Trevor Morris committed Dec 3, 2020
1 parent 48230e7 commit 0a6ff8d
Show file tree
Hide file tree
Showing 26 changed files with 751 additions and 71 deletions.
7 changes: 7 additions & 0 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,13 @@ class PassContext : public ObjectRef {
*/
TVM_DLL void Trace(const IRModule& module, const PassInfo& info, bool is_before) const;

/*!
* \brief Check whether a pass is enabled.
* \param info The pass information.
* \return true if the pass is enabled. Otherwise, false.
*/
TVM_DLL bool PassEnabled(const PassInfo& info) const;

/*!
* \brief Register a valid configuration option and its ValueType for validation.
*
Expand Down
1 change: 1 addition & 0 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
tvm::String data_layout;
tvm::String kernel_layout;
tvm::String out_layout;
std::string auto_scheduler_rewritten_layout;
DataType out_dtype;

TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") {
Expand Down
14 changes: 14 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,20 @@ struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
}
};

/*! \brief Attributes for AutoSchedulerLayoutTransform operator */
struct AutoSchedulerLayoutTransformAttrs
: public tvm::AttrsNode<AutoSchedulerLayoutTransformAttrs> {
/*! \brief Layout string describing the source tensor (e.g. 1N32C112H112W). */
std::string src_layout;
/*! \brief Layout string describing the destination tensor (e.g. 1N2C112H112W16c). */
std::string dst_layout;

// Register both fields with the attribute system under the key
// "relay.attrs.AutoSchedulerLayoutTransformAttrs". Going by the examples in the
// descriptions, each axis name appears to be preceded by its extent.
TVM_DECLARE_ATTRS(AutoSchedulerLayoutTransformAttrs,
"relay.attrs.AutoSchedulerLayoutTransformAttrs") {
TVM_ATTR_FIELD(src_layout).describe("The source layout of the tensor. (e.g. 1N32C112H112W)");
TVM_ATTR_FIELD(dst_layout)
.describe("The destination layout of the tensor. (e.g. 1N2C112H112W16c)");
}
};

/*! \brief Attributes for ShapeOf operator */
struct ShapeOfAttrs : public tvm::AttrsNode<ShapeOfAttrs> {
DataType dtype;
Expand Down
14 changes: 14 additions & 0 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,14 @@ TVM_DLL Pass FoldConstant();
*/
TVM_DLL Pass FuseOps(int fuse_opt_level = -1);

/*!
* \brief The inverse operation of FuseOps. It transforms a fused program returned by
* FuseOps into the program before FuseOps. (i.e. x == DefuseOps(FuseOps(x)))
*
* \return The pass.
*/
TVM_DLL Pass DefuseOps();

/*!
* \brief Rewrite the annotated program.
*
Expand Down Expand Up @@ -315,6 +323,12 @@ TVM_DLL Pass CanonicalizeOps();
*/
TVM_DLL Pass AlterOpLayout();

/*!
* \brief Do layout rewrite according to the tile structure created by auto-scheduler.
* \return The pass
*/
TVM_DLL Pass AutoSchedulerLayoutRewrite();

/*!
* \brief Given a dest layout, this pass transforms the expr such that most of the ops input data
* layout is changed to the dest layout. In ideal situation, there are only 2 layout transforms, one
Expand Down
68 changes: 68 additions & 0 deletions include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -1400,6 +1400,74 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout,
name, tag);
}

/*!
 * \brief Utility function for auto_scheduler_layout_transform.
 *
 * Scans \p layout one character at a time. Consecutive digits are accumulated
 * into a decimal factor and consecutive axis characters into an axis name.
 * A factor is appended to \p shape when the next axis character is seen
 * (only if it is non-zero); an axis name is appended to \p axes when the next
 * digit is seen, or at the end of the string.
 *
 * \param layout The layout string to parse (e.g. 1N32C112H112W).
 * \param shape Output: the factors found, in order of appearance.
 * \param axes Output: the axis names found, in order of appearance.
 *
 * \note A factor still pending at the end of the string is not appended to \p shape.
 * \note NOTE(review): the axis-character test is the ASCII range 'A'..'z', which
 *       also admits the punctuation between 'Z' and 'a' ('[', '\\', ']', '^',
 *       '_', '`'). Confirm whether that is intended before tightening it.
 */
inline void parse_auto_scheduler_layout(const String& layout, Array<PrimExpr>* shape,
                                        std::vector<std::string>* axes) {
  std::string pending_axis;
  int32_t pending_factor = 0;

  const std::string layout_str(layout);
  for (const char ch : layout_str) {
    const bool is_axis_char = ('A' <= ch && ch <= 'z');
    const bool is_digit = ('0' <= ch && ch <= '9');

    if (is_axis_char) {
      pending_axis += ch;
      // The first axis character after a run of digits closes the factor.
      if (pending_factor != 0) {
        shape->push_back(pending_factor);
        pending_factor = 0;
      }
      continue;
    }

    if (is_digit) {
      pending_factor = pending_factor * 10 + ch - '0';
      // The first digit after a run of axis characters closes the axis name.
      if (!pending_axis.empty()) {
        axes->push_back(pending_axis);
        pending_axis.clear();
      }
      continue;
    }

    LOG(FATAL) << "Invalid layout " << layout;
  }

  // Flush the trailing axis name, if any.
  if (!pending_axis.empty()) {
    axes->push_back(pending_axis);
  }
}

/*!
 * \brief Build a tensor that re-lays-out \p src from \p src_layout to \p dst_layout,
 *  where both layouts use the auto-scheduler layout string format.
 *
 * For every source axis, the source index is rebuilt from all destination axes
 * carrying the same name: walking the destination axes in order, the running
 * index is multiplied by that destination axis' extent and the destination
 * index is added. A source axis with no matching destination axis is indexed at 0.
 *
 * \param src the source input.
 * \param src_layout the source layout.
 * \param dst_layout the destination layout.
 * \param name output tensor name.
 * \param tag output tensor tag.
 * \return A tensor with shape in \p dst_layout
 */
inline Tensor auto_scheduler_layout_transform(const Tensor& src, const String& src_layout,
                                              const String& dst_layout,
                                              const String name = "T_auto_scheduler_layout_trans",
                                              const String tag = kInjective) {
  // Extents and axis names parsed from the two layout strings.
  Array<PrimExpr> src_shape;
  Array<PrimExpr> dst_shape;
  std::vector<std::string> src_axes;
  std::vector<std::string> dst_axes;
  parse_auto_scheduler_layout(src_layout, &src_shape, &src_axes);
  parse_auto_scheduler_layout(dst_layout, &dst_shape, &dst_axes);

  // Maps one output (destination) coordinate to the source element it reads.
  auto fcompute = [&src, &src_axes, &dst_shape, &dst_axes](const Array<Var>& dst_indices) {
    Array<PrimExpr> out_idx(dst_indices.begin(), dst_indices.end());
    Array<PrimExpr> in_idx;
    for (const std::string& wanted_axis : src_axes) {
      PrimExpr flat_index = 0;
      CHECK_EQ(out_idx.size(), dst_axes.size());
      for (size_t d = 0; d < dst_axes.size(); ++d) {
        if (dst_axes[d] != wanted_axis) {
          continue;
        }
        // Fold this destination sub-axis into the source index.
        flat_index = flat_index * dst_shape[d] + out_idx[d];
      }
      in_idx.push_back(flat_index);
    }
    return src(in_idx);
  };

  return compute(dst_shape, fcompute, name, tag);
}

/*!
* \brief Get the shape of input tensor.
* \param src the input tensor.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/auto_scheduler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
LocalRPCMeasureContext,
)
from .measure_record import RecordToFile, RecordReader, load_best, load_records, save_records
from .relay_integration import extract_tasks
from .relay_integration import extract_tasks, remove_index_check, rewrite_compute_body
from .search_task import SearchTask
from .search_policy import EmptyPolicy, SketchPolicy, PreloadMeasuredStates
from .task_scheduler import TaskScheduler
Expand Down
17 changes: 17 additions & 0 deletions python/tvm/auto_scheduler/compute_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,23 @@ def infer_bound_from_state(self, state):
updated_state.stage_id_map[k] = v
return updated_state

def rewrite_layout_from_state(self, state):
    """Produce a layout-rewritten copy of this DAG from a state's transform steps.

    Parameters
    ----------
    state : Union[State, StateObject]
        The state from which we get transform steps. A `StateObject` is used
        as-is; anything else is unwrapped through its `state_object` attribute.

    Returns
    -------
    updated_dag : ComputeDAG
        The compute dag with rewritten layout.
    """
    if isinstance(state, StateObject):
        raw_state = state
    else:
        raw_state = state.state_object
    return _ffi_api.ComputeDAGRewriteLayoutFromState(self, raw_state)

def hash_key(self):
"""Return the hash key of this compute DAG.
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/auto_scheduler/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,9 @@ def _timed_func(inp_serialized, build_func, verbose):
args = []

try:
sch, args = task.compute_dag.apply_steps_from_state(inp.state, layout_rewrite=True)
sch, args = task.compute_dag.apply_steps_from_state(
inp.state, layout_rewrite=ComputeDAG.RewriteForPreTransformed
)
# pylint: disable=broad-except
except Exception:
error_no = MeasureErrorNo.INSTANTIATION_ERROR
Expand Down
103 changes: 96 additions & 7 deletions python/tvm/auto_scheduler/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,15 @@
"""

import logging
import json
import threading

import tvm
from tvm import autotvm, te, transform
from tvm.te.tensor import ComputeOp, PlaceholderOp
from tvm.runtime import convert_to_object
from tvm.te.tensor import ComputeOp, PlaceholderOp, Tensor
from tvm.tir import expr as _expr
from . import _ffi_api
from .compute_dag import ComputeDAG
from .dispatcher import DispatchContext
from .search_task import SearchTask
Expand All @@ -46,7 +50,11 @@ def call_all_topi_funcs(mod, params, target):
old_autotvm_silent = autotvm.GLOBAL_SCOPE.silent
autotvm.GLOBAL_SCOPE.silent = True

with transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}):
with transform.PassContext(
opt_level=3,
config={"relay.backend.use_auto_scheduler": True},
disabled_pass={"AutoSchedulerLayoutRewrite"},
):
opt_mod, _ = relay.optimize(mod, target, params)
grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
grc.codegen(opt_mod["main"])
Expand Down Expand Up @@ -158,6 +166,20 @@ def add_workload_key(self, workload_key, ccache_key):
self.wkl_key_to_ccache_key[workload_key] = ccache_key


@tvm._ffi.register_func("auto_scheduler.enter_layout_rewrite")
def enter_layout_rewrite():
    """Enter layout rewrite tracing environment.

    Registered as the global function ``auto_scheduler.enter_layout_rewrite``.
    Creates a `TracingEnvironment` in ``PREPARE_LAYOUT_REWRITE`` mode and enters
    it manually; the matching exit is done by ``auto_scheduler.exit_layout_rewrite``.
    """
    TracingEnvironment(TracingMode.PREPARE_LAYOUT_REWRITE).__enter__()


@tvm._ffi.register_func("auto_scheduler.exit_layout_rewrite")
def exit_layout_rewrite():
    """Exit layout rewrite tracing environment.

    Registered as the global function ``auto_scheduler.exit_layout_rewrite``.
    Calls ``__exit__`` (with no exception information) on whatever
    `TracingEnvironment.current` refers to at the time of the call.
    """
    active_env = TracingEnvironment.current
    active_env.__exit__(None, None, None)


def traverse_to_get_io_tensors(outs):
"""Traverse from a list of output tensors to get both input and output tensors
Expand Down Expand Up @@ -230,11 +252,13 @@ def auto_schedule_topi(outs, has_complex_op):
key = register_workload_tensors(dag.hash_key(), io_tensors)

# only enable layout rewrite for cpu backend
enable_layout_rewrite = "cpu" in tvm.target.Target.current().keys
target = tvm.target.Target.current()
enable_layout_rewrite = "cpu" in target.keys

env = TracingEnvironment.current
if env is None: # in the final build mode
state = DispatchContext.current.query(tvm.target.Target.current(), key, has_complex_op, dag)
if env is None:
# in the final build mode
state = DispatchContext.current.query(target, key, has_complex_op, dag)
if state is None:
return None

Expand All @@ -247,9 +271,74 @@ def auto_schedule_topi(outs, has_complex_op):
env.add_workload_key(key, ccache_key)
schedule = te.create_schedule([x.op for x in outs])
elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE:
# todo(merrymercy, minminsun): port layout rewrite
raise NotImplementedError
# in prepare_layout_rewrite mode
if enable_layout_rewrite and has_layout_free:
dispatch_ctx = DispatchContext.current
state = dispatch_ctx.query(target, key, has_complex_op, dag)
if state is None:
return None

# rewrite the layout and update the context for the new dag
dag = ComputeDAG(outs)
new_dag = dag.rewrite_layout_from_state(state)
new_key = json.dumps((new_dag.hash_key(),))
if new_key != key:
dispatch_ctx.update(target, new_key, state)
return te.create_schedule([x.op for x in outs])
else:
raise ValueError("Invalid tracing mode: " + env.tracing_mode)

return schedule


def tensor_no_check_call(self, *indices):
    """Index a tensor without validating the indices against it.

    This is the same as `tvm.te.Tensor::__call__` except that the safety
    check is removed.

    Each index is converted with `convert_to_object`; a `PrimExpr` is used
    directly and an `IterVar` contributes its `var`. Any other kind of index
    raises `ValueError`. The result is a `ProducerLoad` of `self` at those
    indices.
    """

    def _as_index_expr(idx):
        # Accept plain expressions as-is and unwrap iteration variables.
        if isinstance(idx, _expr.PrimExpr):
            return idx
        if isinstance(idx, _expr.IterVar):
            return idx.var
        raise ValueError("The indices must be expression")

    converted = convert_to_object(indices)
    load_args = [_as_index_expr(idx) for idx in converted]
    return _expr.ProducerLoad(self, load_args)


def remove_index_check(tensor):
    """Remove the safety check in the indexing function for a tensor.

    This is done by monkey patching its indexing function: the tensor
    instance gets its own `__call__` attribute bound to
    `tensor_no_check_call`. After removing the check, we are allowed to
    create a temporary wrong IR and fix it later in other places.

    Parameters
    ----------
    tensor: Tensor
        The tensor to remove index check.

    Notes
    -----
    NOTE(review): Python looks up special methods on the type for implicit
    calls, so an instance-level `__call__` is only used by code that calls
    `tensor.__call__(...)` explicitly, not by `tensor(...)`. Confirm the call
    sites rely on the explicit form.
    """
    # Bind the unchecked indexing function to this particular instance.
    unchecked_call = tensor_no_check_call.__get__(tensor, Tensor)
    setattr(tensor, "__call__", unchecked_call)


def rewrite_compute_body(compute_tensor, new_layout):
    """Rewrite the body of a ComputeOp according to a new layout of a placeholder.

    The op behind `compute_tensor` must list exactly one tensor under its
    ``layout_free_placeholders`` attribute. Every expression in the op body is
    passed through `RewriteIndexForNewLayout` for that placeholder, and a new
    ComputeOp is created with the same name, tag, attrs and axis.

    Returns the single output tensor when the new op has one output, otherwise
    a tuple of all its outputs.
    """
    op = compute_tensor.op

    # Locate the (single) layout-free placeholder this op reads from.
    free_placeholders = op.attrs["layout_free_placeholders"]
    assert len(free_placeholders) == 1, "Only support one layout free placeholder"
    placeholder_op = free_placeholders[0].op

    # Re-index every body expression for the new placeholder layout.
    new_body = [
        _ffi_api.RewriteIndexForNewLayout(placeholder_op, new_layout, expr) for expr in op.body
    ]
    new_op = tvm.te._ffi_api.ComputeOp(op.name, op.tag, op.attrs, op.axis, new_body)

    output_count = new_op.num_outputs
    if output_count == 1:
        return new_op.output(0)
    return tuple(new_op.output(i) for i in range(output_count))
2 changes: 2 additions & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ def compute_strided_set(attrs, inputs, output_type):
# layout_transform
_reg.register_injective_schedule("layout_transform")
_reg.register_pattern("layout_transform", OpPattern.INJECTIVE)
_reg.register_injective_schedule("auto_scheduler_layout_transform")
_reg.register_pattern("auto_scheduler_layout_transform", OpPattern.INJECTIVE)

# argwhere
@_reg.register_compute("argwhere")
Expand Down
15 changes: 13 additions & 2 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import logging

import re
from tvm import topi
from tvm import topi, _ffi
from tvm.topi.utils import get_const_int, get_const_float, get_const_tuple, get_float_tuple
from tvm.target import generic_func, override_native_generic_func
from .. import op as _op
Expand Down Expand Up @@ -166,9 +166,17 @@ def schedule_bitpack(attrs, outs, target):
return topi.generic.schedule_bitpack(outs)


get_auto_scheduler_rewritten_layout = _ffi.get_global_func(
"relay.attrs.get_auto_scheduler_rewritten_layout"
)

# conv2d
def wrap_compute_conv2d(
topi_compute, need_data_layout=False, need_out_layout=False, has_groups=False
topi_compute,
need_data_layout=False,
need_out_layout=False,
has_groups=False,
need_auto_scheduler_layout=False,
):
"""Wrap conv2d topi compute"""

Expand All @@ -179,6 +187,7 @@ def _compute_conv2d(attrs, inputs, out_type):
data_layout = attrs.get_str("data_layout")
out_layout = attrs.get_str("out_layout")
out_dtype = attrs.out_dtype
auto_scheduler_rewritten_layout = get_auto_scheduler_rewritten_layout(attrs)
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
args = [inputs[0], inputs[1], strides, padding, dilation]
if has_groups:
Expand All @@ -188,6 +197,8 @@ def _compute_conv2d(attrs, inputs, out_type):
if need_out_layout:
args.append(out_layout)
args.append(out_dtype)
if need_auto_scheduler_layout:
args.append(auto_scheduler_rewritten_layout)
return [topi_compute(*args)]

return _compute_conv2d
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/relay/op/strategy/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,8 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
return conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target)
elif layout == "NHWC":
assert kernel_layout == "HWIO"
logger.warning("For x86 target, NCHW layout is recommended for conv2d.")
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_nhwc),
wrap_compute_conv2d(topi.nn.conv2d_nhwc, need_auto_scheduler_layout=True),
wrap_topi_schedule(topi.x86.schedule_conv2d_nhwc),
name="conv2d_nhwc.x86",
)
Expand Down
Loading

0 comments on commit 0a6ff8d

Please sign in to comment.