Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy committed Nov 29, 2020
1 parent 54b9e51 commit 8129203
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 9 deletions.
4 changes: 0 additions & 4 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,6 @@ struct Conv2DWinogradAttrs : public tvm::AttrsNode<Conv2DWinogradAttrs> {
std::string data_layout;
std::string kernel_layout;
std::string out_layout;
std::string auto_scheduler_rewritten_layout;
DataType out_dtype;

TVM_DECLARE_ATTRS(Conv2DWinogradAttrs, "relay.attrs.Conv2DWinogradAttrs") {
Expand Down Expand Up @@ -269,9 +268,6 @@ struct Conv2DWinogradAttrs : public tvm::AttrsNode<Conv2DWinogradAttrs> {
"Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Default to be same as input layout.");
TVM_ATTR_FIELD(auto_scheduler_rewritten_layout)
.set_default("")
.describe("New kernel layout after auto-scheduler's layout rewrite.");

// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
Expand Down
9 changes: 7 additions & 2 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,11 @@ def schedule_bitpack(attrs, outs, target):

# conv2d
def wrap_compute_conv2d(
topi_compute, need_data_layout=False, need_out_layout=False, has_groups=False
topi_compute,
need_data_layout=False,
need_out_layout=False,
has_groups=False,
need_auto_scheduler_layout=False,
):
"""Wrap conv2d topi compute"""

Expand All @@ -189,7 +193,8 @@ def _compute_conv2d(attrs, inputs, out_type):
if need_out_layout:
args.append(out_layout)
args.append(out_dtype)
args.append(auto_scheduler_rewritten_layout)
if need_auto_scheduler_layout:
args.append(auto_scheduler_rewritten_layout)
return [topi_compute(*args)]

return _compute_conv2d
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
elif layout == "NHWC":
assert kernel_layout == "HWIO"
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_nhwc),
wrap_compute_conv2d(topi.nn.conv2d_nhwc, need_auto_scheduler_layout=True),
wrap_topi_schedule(topi.x86.schedule_conv2d_nhwc),
name="conv2d_nhwc.x86",
)
Expand Down
3 changes: 1 addition & 2 deletions src/relay/transforms/auto_scheduler_layout_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,7 @@ class FuncMutator : public ExprMutator {
std::deque<std::string> ori_layouts_queue_;
std::deque<std::string> new_layouts_queue_;

std::vector<std::string> target_ops_{"nn.contrib_conv2d_winograd_without_weight_transform",
"nn.conv2d"};
std::vector<std::string> target_ops_{"nn.conv2d"};
};

Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) {
Expand Down

0 comments on commit 8129203

Please sign in to comment.