
Commit

[FoldScaleAxis] Support dense and bias_add op in fold scale axis (apache#9838)

* [FoldScaleAxis] Support dense and bias_add op in fold scale axis

* fix lint
shengxinhu authored and ylc committed Feb 16, 2022
1 parent e1b404f commit c7d6e57
Showing 2 changed files with 226 additions and 0 deletions.
83 changes: 83 additions & 0 deletions src/relay/transforms/fold_scale_axis.cc
@@ -589,6 +589,28 @@ RELAY_REGISTER_OP("nn.conv2d").set_attr<FForwardPrep>("FScaleAxisForwardPrep", C
RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", Conv2DForwardRewrite);

// Dense sends out the requirement of axis folding.
Array<Message> DenseForwardPrep(const Call& call, const Message& out_message) {
return {Message({1}, false), NullValue<Message>()};
}

// Dense consumes the scale axis during transformation.
Expr DenseForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
const Message& message) {
const auto* sdata = new_args[0].as<ScaledExprNode>();
const auto* sweight = new_args[1].as<ScaledExprNode>();
if (sdata == nullptr) return Expr();
if (sweight != nullptr) return Expr();

Expr weight = Multiply(new_args[1], sdata->scale);
return Call(ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args);
}

RELAY_REGISTER_OP("nn.dense").set_attr<FForwardPrep>("FScaleAxisForwardPrep", DenseForwardPrep);

RELAY_REGISTER_OP("nn.dense")
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", DenseForwardRewrite);

Expr ForwardFoldScaleAxis(const Expr& data) {
auto message = ForwardPrep().Prepare(data);
for (const auto& m : message) {
@@ -996,6 +1018,67 @@ RELAY_REGISTER_OP("nn.conv2d")
RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", Conv2DBackwardTransform);

Message BiasAddBackwardPrep(const Call& call, const Array<Message>& in_messages) {
const BiasAddAttrs* attrs = call->attrs.as<BiasAddAttrs>();
ICHECK(attrs);
if (in_messages[0].defined() && in_messages[0]->axes.size() == 1 &&
attrs->axis == static_cast<int>(in_messages[0]->axes[0]->value)) {
return in_messages[0];
} else {
return NullValue<Message>();
}
}

Expr BiasAddBackwardTransform(const Call& call, const Message& message, const Expr& scale,
const BackwardTransformer& transformer) {
if (!message.defined()) {
return transformer->NormalCallTransform(call.operator->());
}
Message lhs_message = transformer->GetMessage(call->args[0]);
Message rhs_message = transformer->GetMessage(call->args[1]);
StructuralEqual equal;

if (lhs_message.defined()) {
ICHECK(equal(message->axes, lhs_message->axes));
Expr lhs = transformer->Transform(call->args[0], message, scale);
Expr rhs = transformer->Transform(call->args[1], NullValue<Message>(), NullValue<Expr>());
rhs = Multiply(rhs, scale);
return Call(call->op, {lhs, rhs}, call->attrs, call->type_args);
} else {
LOG(FATAL) << "outstanding scale";
return Expr();
}
}

RELAY_REGISTER_OP("nn.bias_add")
.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", BiasAddBackwardPrep);

RELAY_REGISTER_OP("nn.bias_add")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", BiasAddBackwardTransform);

// Dense sends out the requirement of axis folding.
Message DenseBackwardPrep(const Call& call, const Array<Message>& in_messages) {
return Message({1}, false);
}

// Dense consumes the scale axis during transformation.
Expr DenseBackwardTransform(const Call& call, const Message& message, const Expr& scale,
const BackwardTransformer& transformer) {
if (!message.defined()) {
return transformer->NormalCallTransform(call.operator->());
}
Expr data = transformer->Transform(call->args[0], NullValue<Message>(), NullValue<Expr>());
Expr weight = transformer->Transform(call->args[1], NullValue<Message>(), NullValue<Expr>());
Expr wscale = ExpandBiasToMatchAxis(scale, 2, {0});
weight = Multiply(weight, wscale);
return Call(call->op, {data, weight}, call->attrs, call->type_args);
}

RELAY_REGISTER_OP("nn.dense").set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", DenseBackwardPrep);

RELAY_REGISTER_OP("nn.dense")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", DenseBackwardTransform);

Expr BackwardFoldScaleAxis(const Expr& data) {
return make_object<BackwardTransformerNode>()->Fold(data);
}
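For the forward direction, DenseForwardRewrite relies on the fact that a per-input-feature scale on the data can be folded into the corresponding columns of the nn.dense weight. A minimal numpy sketch of that identity (array names and shapes here are illustrative, not taken from the patch):

import numpy as np

# nn.dense computes x @ w.T, with x of shape (batch, in) and w of shape (out, in).
x = np.random.rand(2, 4)
w = np.random.rand(3, 4)
s = np.random.rand(4)  # hypothetical per-input-feature scale

folded_into_data = (x * s) @ w.T    # scale applied to the data, as in the original graph
folded_into_weight = x @ (w * s).T  # scale folded into the weight columns instead
assert np.allclose(folded_into_data, folded_into_weight)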
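For the backward direction, DenseBackwardTransform and BiasAddBackwardTransform use the complementary identities: a per-output-feature scale applied after nn.dense can be folded into the weight rows, and a scale can be pushed back through nn.bias_add by also scaling the bias. A rough numpy check of both, again with made-up shapes:

import numpy as np

x = np.random.rand(2, 4)
w = np.random.rand(3, 4)
b = np.random.rand(3)
s = np.random.rand(3)  # hypothetical per-output-feature scale

# dense(x, w) * s == dense(x, w * s[:, None]); s[:, None] plays roughly the role
# that ExpandBiasToMatchAxis(scale, 2, {0}) plays on the Relay side.
assert np.allclose((x @ w.T) * s, x @ (w * s[:, None]).T)

# bias_add(y, b) * s == bias_add(y * s, b * s) when the scale sits on the bias axis,
# which is why the bias is multiplied by the scale in BiasAddBackwardTransform.
y = x @ w.T
assert np.allclose((y + b) * s, y * s + b * s)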
143 changes: 143 additions & 0 deletions tests/python/relay/test_pass_fold_scale_axis.py
@@ -413,6 +413,46 @@ def check(shape, channels, blocking):
check((2, 2, 10, 10, 2), 8, (2, 2))


def test_fold_fwd_dense():
"""dense testcase."""

def before(x, weight, in_bias, in_scale):
args = [x, weight, in_bias]
x = relay.multiply(x, in_scale)
x = relay.nn.relu(x)
x = relay.add(x, in_bias)
y = relay.nn.dense(x, weight)
return relay.Function(args, y)

def expected(x, weight, in_bias, in_scale):
# use a fixed order of args so alpha equal check can pass
args = [x, weight, in_bias]
x = relay.nn.relu(x)
in_bias = relay.divide(in_bias, in_scale)
x = relay.add(x, in_bias)
weight = relay.multiply(weight, in_scale)
y = relay.nn.dense(x, weight)
return relay.Function(args, y)

def check(data_shape, weight_shape):
x = relay.var("x", shape=data_shape)
weight = relay.var("weight", shape=weight_shape)
in_channels = data_shape[1]
in_bias = relay.var("in_bias", shape=(in_channels,))
in_scale = relay.const(_get_positive_scale((in_channels,)))
y1 = before(x, weight, in_bias, in_scale)
y1 = run_opt_pass(y1, transform.InferType())
y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
y1_expected = expected(x, weight, in_bias, in_scale)

y1_folded = run_opt_pass(y1_folded, transform.InferType())
y1_expected = run_opt_pass(y1_expected, transform.InferType())
assert tvm.ir.structural_equal(y1_folded, y1_expected)

check((2, 4), (3, 4))
check((3, 5), (4, 5))


def test_fold_bwd_simple():
"""Simple testcase."""

@@ -888,15 +928,118 @@ def check(shape, channels, blocking):
check((2, 2, 10, 10, 2), 8, (2, 2))


def test_fold_bwd_dense():
"""dense testcase."""

def before(x, weight, in_bias, in_scale):
args = [x, weight, in_bias]
x = relay.nn.dense(x, weight)
x = relay.add(x, in_bias)
x = relay.nn.relu(x)
y = relay.multiply(x, in_scale)
return relay.Function(args, y)

def expected(x, weight, in_bias, in_scale):
# use a fixed order of args so alpha equal check can pass
args = [x, weight, in_bias]
scale = relay.expand_dims(in_scale, axis=1)
weight = relay.multiply(weight, scale)
x = relay.nn.dense(x, weight)
bias = relay.multiply(in_bias, in_scale)
x = relay.add(x, bias)
y = relay.nn.relu(x)
return relay.Function(args, y)

def check(data_shape, weight_shape):
x = relay.var("x", shape=data_shape)
weight = relay.var("weight", shape=weight_shape)
out_channels = weight_shape[0]
in_bias = relay.var("in_bias", shape=(out_channels,))
in_scale = relay.const(_get_positive_scale((out_channels,)))
y1 = before(x, weight, in_bias, in_scale)
y1 = run_opt_pass(y1, transform.InferType())
y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
y1_expected = expected(x, weight, in_bias, in_scale)

y1_folded = run_opt_pass(y1_folded, transform.InferType())
y1_expected = run_opt_pass(y1_expected, transform.InferType())
assert tvm.ir.structural_equal(y1_folded, y1_expected)

check((2, 4), (3, 4))
check((3, 5), (4, 5))


def test_fold_bwd_bias_add():
"""bias add testcase."""

def before(x, conv_weight, out_bias, out_scale, channels):
args = [x, conv_weight, out_bias]
y = relay.nn.conv2d(
x,
conv_weight,
channels=channels,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
)
y = relay.nn.bias_add(y, out_bias)
y = relay.nn.relu(y)
y = relay.multiply(y, out_scale)
return relay.Function(args, y)

def expected(x, conv_weight, out_bias, out_scale, channels):
# use a fixed order of args so alpha equal check can pass
args = [x, conv_weight, out_bias]
squeezed_scale = relay.squeeze(out_scale, axis=[1, 2])
conv_weight = relay.multiply(
conv_weight, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)
)

y = relay.nn.conv2d(
x,
conv_weight,
channels=channels,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
)

out_bias = relay.multiply(out_bias, squeezed_scale)
y = relay.nn.bias_add(y, out_bias)
y = relay.nn.relu(y)
return relay.Function(args, y)

def check(shape, channels):
x = relay.var("x", shape=shape)
weight = relay.var("weight")
out_bias = relay.var("out_bias", shape=(channels,))
out_scale = relay.const(_get_positive_scale((channels, 1, 1)))
y1 = before(x, weight, out_bias, out_scale, channels)
y1 = run_opt_pass(y1, transform.InferType())
type_dict = {x.name_hint: x.checked_type for x in y1.params}
weight = relay.var("weight", type_dict["weight"])
y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
y1_expected = expected(x, weight, out_bias, out_scale, channels)
y1_expected = run_opt_pass(y1_expected, transform.InferType())
assert tvm.ir.structural_equal(y1_folded, y1_expected)

check((2, 4, 10, 10), 4)


if __name__ == "__main__":
test_fold_fwd_simple()
test_fold_fwd_dual_path()
test_fold_fwd_fail()
test_fold_fwd_relu_fail()
test_fold_fwd_negative_scale()
test_fold_fwd_dense()
test_fold_bwd_simple()
test_fold_bwd_dual_path()
test_fold_bwd_dual_consumer()
test_fold_bwd_fail()
test_fold_bwd_relu_fail()
test_fold_bwd_negative_scale()
test_fold_bwd_dense()
test_fold_bwd_bias_add()

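The tests above drive ForwardFoldScaleAxis and BackwardFoldScaleAxis directly through run_opt_pass. In ordinary use the folding is applied at the module level; a minimal sketch of how the new dense support might be exercised that way, assuming the standard TVM Python API (the variable names and shapes are made up):

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(2, 4))
w = relay.var("w", shape=(3, 4))
scale = relay.const(np.array([1.5, 2.0, 0.5], dtype="float32"))

# dense followed by a per-output-channel multiply, the pattern handled by the
# new backward dense support.
y = relay.multiply(relay.nn.dense(x, w), scale)
mod = tvm.IRModule.from_expr(relay.Function([x, w], y))

mod = relay.transform.InferType()(mod)
mod = relay.transform.FoldScaleAxis()(mod)  # runs backward and forward folding
print(mod)  # the multiply should now be folded into the dense weight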