diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index b3ef478f201d..9b6b45240a50 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -51,6 +51,7 @@
 
 
 logger = logging.getLogger("DNNL")
+supported_post_elts = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", None]
 
 
 def _register_external_op_helper(op_name, supported=True):
@@ -120,6 +121,8 @@ def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
     conv_out : CallPattern
         Call node sequence.
     """
+    if with_eltwise not in supported_post_elts:
+        raise ValueError("Unsupported eltwise post-op: %s" % with_eltwise)
     data = wildcard()
     weight = wildcard()
     bias = wildcard()
@@ -128,8 +131,11 @@ def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
         conv_out = is_op("add")(conv, bias)
     else:
         conv_out = conv
-    if with_eltwise:
-        return is_op(with_eltwise)(conv_out)
+    if with_eltwise == "swish":
+        sig_out = is_op("sigmoid")(conv_out)
+        conv_out = is_op("multiply")(conv_out, sig_out)
+    elif with_eltwise:
+        conv_out = is_op(with_eltwise)(conv_out)
     return conv_out
 
 
@@ -147,6 +153,8 @@ def make_dense_pattern(with_bias=True, with_eltwise=None):
     dense_out : CallPattern
         Call node sequence.
    """
+    if with_eltwise not in supported_post_elts:
+        raise ValueError("Unsupported eltwise post-op: %s" % with_eltwise)
     data = wildcard()
     weight = wildcard()
     bias = wildcard()
@@ -165,6 +173,9 @@ def make_dense_pattern(with_bias=True, with_eltwise=None):
         added_erf_val = is_op("add")(erf_val, const2)
         mul_val = is_op("multiply")(dense_out, added_erf_val)
         dense_out = is_op("multiply")(mul_val, const3)
+    elif with_eltwise == "swish":
+        sig_out = is_op("sigmoid")(dense_out)
+        dense_out = is_op("multiply")(dense_out, sig_out)
     elif with_eltwise:
         dense_out = is_op(with_eltwise)(dense_out)
     return dense_out
@@ -191,6 +202,7 @@ def make_dnnl_pattern(op_name, with_bias, with_eltwise):
         pat_name = "dnnl.deconv" + op_name.split("_")[0][-2::]
     pat_name += "_bias" if with_bias else ""
     pat_name += ("_" + with_eltwise.split(".")[-1]) if with_eltwise else ""
+    pat_name = pat_name.replace("_swish", "_sigmoid_mul")
     if "conv" in op_name:
         dnnl_pattern = (pat_name, make_conv_pattern(op_name, with_bias, with_eltwise))
     elif op_name == "nn.dense":
@@ -282,7 +294,7 @@ def pattern_table():
     dnnl_patterns.append(make_qnn_conv2d_pattern())
     dnnl_patterns.append(make_qnn_dense_pattern())
 
-    elt_list = ["nn.relu", "tanh", "sigmoid", "gelu", None]
+    elt_list = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", None]
     for with_bias in [True, False]:
         for elt in elt_list:
             if not with_bias and not elt:
@@ -380,6 +392,8 @@ def get_shape(tensor):
     if isinstance(tensor, tvm.ir.container.Array):
         return tensor[-1].shape
     if isinstance(tensor, relay.expr.Call):
+        if tensor.op.name == "multiply":
+            return tensor.type_args[0].shape
         return tensor.checked_type.shape
     raise TypeError("Unsupport data type: %s" % type(tensor))
 
@@ -395,6 +409,8 @@ def get_dtype(tensor):
     if isinstance(tensor, tvm.ir.container.Array):
         return tensor[-1].dtype
     if isinstance(tensor, relay.expr.Call):
+        if tensor.op.name == "multiply":
+            return tensor.type_args[0].dtype
         return tensor.checked_type.dtype
     raise TypeError("Unsupport data type: %s" % type(tensor))
 
diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc
index 2f47c23a7cf9..4abfc9d9b136 100644
--- a/src/relay/backend/contrib/dnnl/codegen.cc
+++ b/src/relay/backend/contrib/dnnl/codegen.cc
@@ -470,6 +470,8 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
       {"relu", "nn.relu"},
       {"tanh", "tanh"},
       {"sigmoid", "sigmoid"},
+      {"clip", "clip"},
+      {"mul", "multiply"},
       {"nn.deconv2d", "nn.conv2d_transpose"},
       {"nn.deconv3d", "nn.conv3d_transpose"},
   };
@@ -566,6 +568,13 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
                                                 "kernel", /* op_type_ */ inputs, 1 /* num_outputs_ */);
     SetCallNodeAttribute(node, call);
 
+    // If has post-op `clip`. Assume the last op is clip, add clip's attrs to the pattern attrs.
+    if (name.find("_clip") != std::string::npos) {
+      auto clip_call = cn->op.as<FunctionNode>()->body.as<CallNode>();
+      ICHECK(IsOp(clip_call, "clip"));
+      SetCallNodeAttribute(node, clip_call);
+    }
+
     // For QNN.
     for (const auto& kvp : extra_attrs) node->SetAttr(kvp.first, kvp.second);
     return AddNode(node, GetRef<Expr>(cn));
diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h
index d6fae8c72b5e..57c066131181 100644
--- a/src/relay/backend/utils.h
+++ b/src/relay/backend/utils.h
@@ -470,6 +470,10 @@ inline const CallNode* GetRootCall(const CallNode* current_call, int depth,
          current_call->args[valid_node_idx].as<VarNode>()) {
     valid_node_idx++;
   }
+  while (valid_node_idx < current_call->args.size() &&
+         !(IsOp(current_call->args[valid_node_idx].as<CallNode>(), expected_op_names[depth - 1]))) {
+    valid_node_idx++;
+  }
   const auto* next_call = current_call->args[valid_node_idx].as<CallNode>();
   return GetRootCall(next_call, depth - 1, expected_op_names);
 }
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index a46f170fea94..6c0fd64066e5 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -189,6 +189,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     std::regex relu_pat(".*_relu.*");
     std::regex tanh_pat(".*_tanh.*");
     std::regex sigmoid_pat(".*_sigmoid.*");
+    std::regex clip_pat(".*_clip.*");
     std::regex gelu_pat(".*_gelu.*");
 
     // Parsing post-ops.
@@ -199,8 +200,17 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     if (std::regex_match(op_name, tanh_pat)) {
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_tanh, 0.f, 0.f);
     }
+    if (std::regex_match(op_name, clip_pat)) {
+      float a_min = GetNodeAttr<float>(nodes_[nid], "a_min");
+      float a_max = GetNodeAttr<float>(nodes_[nid], "a_max");
+      ops.append_eltwise(1.f, dnnl::algorithm::eltwise_clip, a_min, a_max);
+    }
     if (std::regex_match(op_name, sigmoid_pat)) {
-      ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
+      if (op_name.find("_sigmoid_mul") != std::string::npos) {
+        ops.append_eltwise(1.f, dnnl::algorithm::eltwise_swish, 1.f, 1.f);
+      } else {
+        ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
+      }
     }
     if (std::regex_match(op_name, gelu_pat)) {
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py
index 078483798c6d..6c7034741a37 100755
--- a/tests/python/contrib/test_dnnl.py
+++ b/tests/python/contrib/test_dnnl.py
@@ -192,7 +192,6 @@ def run_and_verify(mod, input, params, target, run_module, subgraph_num=None, te
         if use_dnnl:
             processed_mod = partition_for_dnnl(processed_mod, params, alter_layout)
             check_dnnl_used(processed_mod)
-
         with tvm.transform.PassContext(opt_level=3):
             func = relay.create_executor(
                 mode, mod=processed_mod, device=dev, target=target
@@ -237,6 +236,23 @@ def run_and_verify_func(
         )
 
 
+def add_activation(activation, out, dic, param_lst):
+    if activation == "relu":
+        return relay.nn.relu(out), dic, param_lst
+    elif activation == "tanh":
+        return relay.tanh(out), dic, param_lst
+    elif activation == "sigmoid":
+        return relay.sigmoid(out), dic, param_lst
+    elif activation == "clip":
+        return relay.clip(out, 0.0, 6.0), dic, param_lst
+    elif activation == "swish":
+        sig_out = relay.sigmoid(out)
+        out = relay.multiply(out, sig_out)
+        return out, dic, param_lst
+    else:
+        return out, dic, param_lst
+
+
 def get_conv1d(
     x_shape=((1, 3, 224)),
     k_shape=(16, 3, 3),
@@ -262,15 +278,7 @@ def get_conv1d(
     )
     dic = {"x": x_shape, "kernel": k_shape}
     param_lst = ["kernel"]
-
-    if activation == "relu":
-        return relay.nn.relu(out), dic, param_lst
-    elif activation == "tanh":
-        return relay.tanh(out), dic, param_lst
-    elif activation == "sigmoid":
-        return relay.sigmoid(out), dic, param_lst
-    else:
-        return out, dic, param_lst
+    return add_activation(activation, out, dic, param_lst)
 
 
 def get_conv1d_bias(x_shape=(1, 3, 224), k_shape=(10, 3, 3), activation=None, dtype="float32"):
@@ -279,15 +287,7 @@ def get_conv1d_bias(x_shape=(1, 3, 224), k_shape=(10, 3, 3), activation=None, dt
     out = relay.nn.bias_add(conv, bias)
     dic["bias"] = (k_shape[0],)
     param_lst += ["bias"]
-
-    if activation == "relu":
-        return relay.nn.relu(out), dic, param_lst
-    elif activation == "tanh":
-        return relay.tanh(out), dic, param_lst
-    elif activation == "sigmoid":
-        return relay.sigmoid(out), dic, param_lst
-    else:
-        return out, dic, param_lst
+    return add_activation(activation, out, dic, param_lst)
 
 
 def get_conv1d_bias_bn_relu(x_shape=(1, 3, 224), k_shape=(10, 3, 3), dtype="float32"):
@@ -334,15 +334,7 @@ def get_conv2d(
     )
     dic = {"x": x_shape, "kernel": k_shape}
     param_lst = ["kernel"]
-
-    if activation == "relu":
-        return relay.nn.relu(out), dic, param_lst
-    elif activation == "tanh":
-        return relay.tanh(out), dic, param_lst
-    elif activation == "sigmoid":
-        return relay.sigmoid(out), dic, param_lst
-    else:
-        return out, dic, param_lst
+    return add_activation(activation, out, dic, param_lst)
 
 
 def get_conv2d_transpose(
@@ -367,15 +359,7 @@ def get_conv2d_transpose(
     )
     dic = {"x": x_shape, "kernel": k_shape}
     param_lst = ["kernel"]
-
-    if activation == "relu":
-        return relay.nn.relu(out), dic, param_lst
-    elif activation == "tanh":
-        return relay.tanh(out), dic, param_lst
-    elif activation == "sigmoid":
-        return relay.sigmoid(out), dic, param_lst
-    else:
-        return out, dic, param_lst
+    return add_activation(activation, out, dic, param_lst)
 
 
 def get_conv2d_weights_const(
@@ -412,15 +396,7 @@ def get_conv2d_bias(
     out = relay.nn.bias_add(conv, bias)
     dic["bias"] = (k_shape[0],)
     param_lst += ["bias"]
-
-    if activation == "relu":
-        return relay.nn.relu(out), dic, param_lst
-    elif activation == "tanh":
-        return relay.tanh(out), dic, param_lst
-    elif activation == "sigmoid":
-        return relay.sigmoid(out), dic, param_lst
-    else:
-        return out, dic, param_lst
+    return add_activation(activation, out, dic, param_lst)
 
 
 def get_conv2d_transpose_bias(
@@ -431,15 +407,7 @@ def get_conv2d_transpose_bias(
     out = relay.nn.bias_add(conv, bias)
     dic["bias"] = (k_shape[1],)
     param_lst += ["bias"]
-
-    if activation == "relu":
-        return relay.nn.relu(out), dic, param_lst
-    elif activation == "tanh":
-        return relay.tanh(out), dic, param_lst
-    elif activation == "sigmoid":
-        return relay.sigmoid(out), dic, param_lst
-    else:
-        return out, dic, param_lst
+    return add_activation(activation, out, dic, param_lst)
 
 
 def get_conv2d_bias_bn_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype="float32"):
@@ -503,15 +471,7 @@ def get_conv3d(
     )
     dic = {"x": x_shape, "kernel": k_shape}
     param_lst = ["kernel"]
-
-    if activation == "relu":
-        return relay.nn.relu(out), dic, param_lst
-    elif activation == "tanh":
-        return relay.tanh(out), dic, param_lst
-    elif activation == "sigmoid":
-        return relay.sigmoid(out), dic, param_lst
-    else:
-        return out, dic, param_lst
+    return add_activation(activation, out, dic, param_lst)
 
 
 def get_conv3d_transpose(
@@ -542,15 +502,7 @@ def get_conv3d_transpose(
     )
     dic = {"x": x_shape, "kernel": k_shape}
     param_lst = ["kernel"]
-
-    if activation == "relu":
-        return relay.nn.relu(out), dic, param_lst
-    elif activation == "tanh":
-        return relay.tanh(out), dic, param_lst
-    elif activation == "sigmoid":
-        return relay.sigmoid(out), dic, param_lst
-    else:
-        return out, dic, param_lst
+    return add_activation(activation, out, dic, param_lst)
 
 
 def get_conv3d_bias(
@@ -561,15 +513,7 @@ def get_conv3d_bias(
     out = relay.nn.bias_add(conv, bias)
     dic["bias"] = (k_shape[0],)
     param_lst += ["bias"]
-
-    if activation == "relu":
-        return relay.nn.relu(out), dic, param_lst
-    elif activation == "tanh":
-        return relay.tanh(out), dic, param_lst
-    elif activation == "sigmoid":
-        return relay.sigmoid(out), dic, param_lst
-    else:
-        return out, dic, param_lst
+    return add_activation(activation, out, dic, param_lst)
 
 
 def get_conv3d_transpose_bias(
@@ -580,15 +524,7 @@ def get_conv3d_transpose_bias(
     out = relay.nn.bias_add(conv, bias)
     dic["bias"] = (k_shape[1],)
     param_lst += ["bias"]
-
-    if activation == "relu":
-        return relay.nn.relu(out), dic, param_lst
-    elif activation == "tanh":
-        return relay.tanh(out), dic, param_lst
-    elif activation == "sigmoid":
-        return relay.sigmoid(out), dic, param_lst
-    else:
-        return out, dic, param_lst
+    return add_activation(activation, out, dic, param_lst)
 
 
 def gelu_helper(data):
@@ -797,7 +733,7 @@ def test_conv2d_weights_const(run_module, dtype="float32"):
 def test_conv2d_pattern(run_module, dtype="float32"):
     x_shape = (1, 32, 8, 8)
     k_shape = (16, 32, 3, 3)
-    activation_lst = [None, "relu", "tanh", "sigmoid"]
+    activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
     for a in activation_lst:
         conv2d, dic, param_lst = get_conv2d(x_shape, k_shape, activation=a, dtype=dtype)
         conv2d = tvm.IRModule.from_expr(conv2d)
@@ -839,7 +775,7 @@ def test_conv2d_transpose(run_module, dtype="float32"):
 
 
 def test_conv2d_transpose_pattern(run_module, dtype="float32"):
-    activation_lst = [None, "relu", "tanh", "sigmoid"]
+    activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
     for a in activation_lst:
         conv2d, dic, param_lst = get_conv2d_transpose(activation=a, dtype=dtype)
         conv2d = tvm.IRModule.from_expr(conv2d)
@@ -872,7 +808,7 @@ def test_conv3d(run_module, dtype="float32"):
 
 
 def test_conv3d_pattern(run_module, dtype="float32"):
-    activation_lst = [None, "relu", "tanh", "sigmoid"]
+    activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
     for a in activation_lst:
         conv3d, dic, param_lst = get_conv3d(activation=a, dtype=dtype)
         conv3d = tvm.IRModule.from_expr(conv3d)
@@ -905,7 +841,7 @@ def test_conv3d_transpose(run_module, dtype="float32"):
 
 
 def test_conv3d_transpose_pattern(run_module, dtype="float32"):
-    activation_lst = [None, "relu", "tanh", "sigmoid"]
+    activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
     for a in activation_lst:
         conv3d, dic, param_lst = get_conv3d_transpose(activation=a, dtype=dtype)
         conv3d = tvm.IRModule.from_expr(conv3d)
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 58b41189a0f0..4b7ac92136e9 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -919,6 +919,7 @@ def expected():
 
 def test_dnnl_fuse():
     dnnl_patterns = get_pattern_table("dnnl")
+    dnnl_pat_dic = dict(dnnl_patterns)
     (
         conv2d_bias_relu_pat,
         conv2d_bias_sigmoid_pat,
@@ -926,11 +927,26 @@ def test_dnnl_fuse():
         conv2d_relu_pat,
         conv2d_sigmoid_pat,
     ) = (
-        dnnl_patterns[3],
-        dnnl_patterns[15],
-        dnnl_patterns[22],
-        dnnl_patterns[28],
-        dnnl_patterns[40],
+        (
+            "dnnl.conv2d_bias_relu",
+            dnnl_pat_dic["dnnl.conv2d_bias_relu"],
+        ),
+        (
+            "dnnl.conv2d_bias_sigmoid",
+            dnnl_pat_dic["dnnl.conv2d_bias_sigmoid"],
+        ),
+        (
+            "dnnl.conv2d_bias",
+            dnnl_pat_dic["dnnl.conv2d_bias"],
+        ),
+        (
+            "dnnl.conv2d_relu",
+            dnnl_pat_dic["dnnl.conv2d_relu"],
+        ),
+        (
+            "dnnl.conv2d_sigmoid",
+            dnnl_pat_dic["dnnl.conv2d_sigmoid"],
+        ),
     )
 
     def get_blocks(