From a6337ca14fbbd17d2d7dc9ccafb1989791c15bc3 Mon Sep 17 00:00:00 2001
From: krishnaraj36 <45380557+krishnaraj36@users.noreply.github.com>
Date: Wed, 28 Dec 2022 11:24:11 +0530
Subject: [PATCH] [CLML][RELAY] Enable Pad and Conv2d layer fusion (#13649)

* [CLML][RELAY] Enable Pad and Conv2d layer fusion

Enabled the CLML-supported nn.pad + nn.conv2d fusion pattern in the CLML
pattern table.

* Fix pad testcase attributes

* Fix the lint error

* Fix the lint error

* Removed redundant check in clml pattern

* Fix the lint error

Co-authored-by: kvegiraj
---
 python/tvm/relay/op/contrib/clml.py        | 21 +++++++++++++++++++++
 src/relay/backend/contrib/clml/codegen.cc  |  2 +-
 tests/python/contrib/test_clml/test_ops.py |  4 ++--
 3 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/op/contrib/clml.py b/python/tvm/relay/op/contrib/clml.py
index c3d4eb84700d..6453b8a06c9f 100644
--- a/python/tvm/relay/op/contrib/clml.py
+++ b/python/tvm/relay/op/contrib/clml.py
@@ -147,6 +147,23 @@ def conv_pattern():
         pattern = pattern.optional(is_op("clip"))
         return pattern
 
+    def pad_conv_pattern():
+        """Create a pad with convolution pattern."""
+        pattern = is_op("nn.pad")(wildcard(), is_constant())
+        pattern = is_op("nn.conv2d")(pattern, is_constant())
+        pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
+        pattern = pattern.optional(lambda x: is_op("add")(x, is_constant()))
+        pattern = pattern.optional(
+            lambda x: is_tuple_get_item(
+                is_op("nn.batch_norm")(
+                    x, is_constant(), is_constant(), is_constant(), is_constant()
+                )
+            )
+        )
+        pattern = pattern.optional(is_op("nn.relu"))
+        pattern = pattern.optional(is_op("clip"))
+        return pattern
+
     def batch_norm_pattern():
         """Create a batch norm pattern."""
         pattern = is_op("nn.batch_norm")(
@@ -200,9 +217,11 @@ def check_conv(extract):
 
         while call.op.name != "nn.conv2d":
             call = call.args[0]
+
         attrs, args = call.attrs, call.args
         if attrs.data_layout != "NCHW":
             return False
+
         if (
             (not clip_found)
             and (attrs.kernel_size[0] == 3)
@@ -211,6 +230,7 @@
             and (attrs.kernel_size[1] == 3)
             and (attrs.channels == attrs.groups)
         ):
             return False
+
         data_typ = args[0].checked_type
         kernel_typ = args[1].checked_type
         is_depthwise = is_depthwise_conv2d(
@@ -246,6 +266,7 @@ def check_default_op(extract):
         return True
 
     return [
+        ("clml.pad_conv2d", pad_conv_pattern(), check_conv),
         ("clml.conv2d", conv_pattern(), check_conv),
         ("clml.dense", dense_pattern(), check_default_op),
         ("clml.pad", pad_pattern(), check_pad_op),
diff --git a/src/relay/backend/contrib/clml/codegen.cc b/src/relay/backend/contrib/clml/codegen.cc
index 9ecec0c4531f..167c48e1baf5 100644
--- a/src/relay/backend/contrib/clml/codegen.cc
+++ b/src/relay/backend/contrib/clml/codegen.cc
@@ -83,7 +83,7 @@ class CLMLJSONSerializer : public backend::contrib::JSONSerializer {
     ICHECK(comp.defined()) << "CLML JSON runtime only supports composite functions.";
     const std::string name = comp.value();
     std::shared_ptr<JSONGraphNode> json_node;
-    if (name == "clml.conv2d") {
+    if (name == "clml.conv2d" || name == "clml.pad_conv2d") {
      json_node = CreateCompositeConvJSONNode(cn);
    } else if (name == "clml.batch_norm") {
      json_node = CreateBatchNormJSONNode(cn);
diff --git a/tests/python/contrib/test_clml/test_ops.py b/tests/python/contrib/test_clml/test_ops.py
index d2431d2dfd3b..da09715fbe4c 100644
--- a/tests/python/contrib/test_clml/test_ops.py
+++ b/tests/python/contrib/test_clml/test_ops.py
@@ -45,7 +45,7 @@ def _get_conv_model(
     a = relay.var(next(iter(var)), shape=shape, dtype=dtype)
     input_arr = var[next(iter(var))]
     if has_pad:
-        p = ((0, 0), (padding[0], padding[0]), (padding[1], padding[1]), (0, 0))
+        p = ((0, 0), (0, 0), (padding[0], padding[0]), (padding[1], padding[1]))
         a = relay.nn.pad(a, pad_width=p)
         padding = (0, 0, 0, 0)
     else:
@@ -97,7 +97,7 @@ def test_conv2d(device, dtype):
     trials = [
         # Normal convolution
         [3, 3, (1, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, False, False)],
-        [2, 1, (2, 2), (1, 1), (1, 1), 7, (15, 16, 12), (False, False, True)],
+        [2, 1, (2, 2), (1, 1), (1, 1), 7, (15, 16, 12), (True, False, True)],
         [3, 3, (2, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, True, False)],
         [3, 3, (2, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, True, True)],
         # Normal convolution