[Nuphar] added Gemm-to-MatMul conversion in model editor #4691

Merged 2 commits on Aug 5, 2020
@@ -52,8 +52,10 @@ void CountGemmOp(const onnxruntime::Node& node,

auto inputs = node.InputDefs();
CountMatrixArgs(inputs[0], inputs[1], node, graph_inputs, shape_func, node_use_counts);
// C's use cnt is fixed.
CountNodeArg(inputs[2], node, graph_inputs, node_use_counts, 1);
if (inputs.size() > 2) {
// C's use cnt is fixed.
CountNodeArg(inputs[2], node, graph_inputs, node_use_counts, 1);
}
}

void CountMatMulOp(const onnxruntime::Node& node,
34 changes: 22 additions & 12 deletions onnxruntime/core/providers/nuphar/compiler/func_info.cc
@@ -484,30 +484,40 @@ static void FillScanExecInfo(NupharFuncInfo* func_info,

int ort_arg_index = gsl::narrow_cast<int>(ort_output_idx);
if (ort_output_idx < gsl::narrow<size_t>(num_state_variables)) {
auto key_iter = visited_output_def_indices.find(key);
// if ort_output_idx is a state output
if (visited_output_def_indices.count(key) != 0) {
if (key_iter != visited_output_def_indices.end()) {
// If state output is an alias
// record i_output for the lookup of the aliased output later
visited_output_state_func_indices.insert(std::make_pair(key, gsl::narrow<int>(func_info->func_input_count + tvm_output_idx)));

auto output_tvm_idx = key_iter->second - gsl::narrow_cast<int>(num_state_variables);

// also record ort_aliased_output_to_func_indices
func_info->ort_aliased_output_to_func_indices.push_back(std::make_pair(gsl::narrow<int>(ort_output_idx),
func_info->func_input_count + tvm_output_idx));
func_info->ort_aliased_output_to_func_indices.push_back(
std::make_pair(gsl::narrow<int>(ort_output_idx), func_info->func_input_count + output_tvm_idx));

scan_info->state_to_output_indices.push_back(output_tvm_idx);

if (visited_output_state_func_indices.count(key) != 0) {
// We could have multiple states that alias to the same output.
// We only record the first one and skip the rest.
continue;
} else {
// record i_output for the lookup of the aliased output later
visited_output_state_func_indices.insert(
std::make_pair(key, gsl::narrow<int>(func_info->func_input_count + output_tvm_idx)));

scan_info->state_to_output_indices.push_back(visited_output_def_indices[key] - gsl::narrow_cast<int>(num_state_variables));
// override ort_arg_index using the output index
ort_arg_index = visited_output_def_indices[key];
// override ort_arg_index using the output index
ort_arg_index = visited_output_def_indices[key];
}
} else {
// the state output is not aliased (no scan output shares it)
scan_info->state_to_output_indices.push_back(NupharFuncInfo::Index_NonAliasedOutput);
}
} else {
// if ort_output_idx is an output
if (visited_output_state_func_indices.count(key) != 0) {
if (source_def != nullptr) {
// skip a duplicated output, since it was counted in the duplicated state output previously
continue;
}
// skip a duplicated output, since it was counted in the duplicated state output previously
continue;
}
}

@@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "core/codegen/mti/math/binary_ops.h"
#include "core/codegen/mti/math/gemm.h"
#include "core/codegen/mti/mti_tvm_utils.h"
#include "core/framework/op_kernel_info.h"
#include "core/providers/common.h"
#include "core/providers/nuphar/compiler/x86/op_ir_creator/all_ops.h"
@@ -22,21 +23,33 @@ Status NUPHAR_TVM_X86_OP_IR_CREATOR_CLASS(Gemm)::Evaluate(
tvm::Tensor Y;
auto& A = inputs[0];
auto& B = inputs[1];
auto& C = inputs[2];
tvm::Tensor C;

int64_t trans_a, trans_b;
float alpha, beta;
ORT_RETURN_IF_ERROR(info.GetAttr<int64_t>("transA", &trans_a));
ORT_RETURN_IF_ERROR(info.GetAttr<int64_t>("transB", &trans_b));
ORT_RETURN_IF_ERROR(info.GetAttr<float>("alpha", &alpha));
ORT_RETURN_IF_ERROR(info.GetAttr<float>("beta", &beta));

// bias is optional
if (inputs.size() < 3) {
beta = 0;
C = tvm_codegen::MakeZeroTensor({1}, A->dtype, node.Name() + "_zero");
} else {
C = inputs[2];
}

// use native sgemm for floating point
if (A->dtype == HalideIR::Float(32) &&
B->dtype == HalideIR::Float(32) &&
GemmExternCpu(A, B, Y, !!trans_a, !!trans_b, node.Name() + "_gemm")) {
if (beta != 0) {
tvm::Tensor beta_bias = (beta == 1) ? C : tvm_codegen::Mul(tvm::make_const(tvm::Float(32), beta), C);
Y = tvm_codegen::Add((alpha == 1) ? Y : tvm_codegen::Mul(tvm::make_const(tvm::Float(32), alpha), Y), beta_bias, node.Name() + "_add_bias");
Y = tvm_codegen::Add((alpha == 1) ? Y : tvm_codegen::Mul(tvm::make_const(tvm::Float(32), alpha), Y),
beta_bias, node.Name() + "_add_bias");
} else {
Y = (alpha == 1) ? Y : tvm_codegen::Mul(tvm::make_const(tvm::Float(32), alpha), Y);
}
outputs.push_back(Y);
return Status::OK();
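
The net effect of the change above is that the Gemm IR creator no longer assumes a third input: when C is absent, beta is forced to 0 and a placeholder zero tensor is used. Roughly, as a NumPy sketch of the intended semantics (not the TVM code itself):

import numpy as np

# Reference semantics of Gemm with an optional bias, as the IR creator above now handles it.
def gemm_reference(A, B, C=None, alpha=1.0, beta=1.0, trans_a=0, trans_b=0):
    A = A.T if trans_a else A
    B = B.T if trans_b else B
    Y = alpha * (A @ B)
    if C is not None and beta != 0.0:
        Y = Y + beta * C  # C broadcasts against the (M, N) result
    return Y
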
@@ -134,6 +134,7 @@ static bool MatMulF32ExternCPU(
const std::vector<int32_t>* p_permute_B = nullptr;
tvm::Tensor root_A = find_transposed_input(A, permute_A);
tvm::Tensor root_B = find_transposed_input(B, permute_B);
bool transA = false;
if (A->shape.size() == B->shape.size() && A->shape.size() >= 2) {
// currently only fuse Transpose into MatMul when rank(A) == rank(B)
// make sure no broadcasting in MatMul
@@ -146,6 +147,8 @@
}
if (no_broadcast) {
if (CanPermuteBeFusedInMatMul(permute_A)) {
if (A != root_A)
transA = true;
A = root_A;
p_permute_A = &permute_A;
}
@@ -161,7 +164,7 @@
// matmul with initializer, using transpose weights
auto layout_key = tvm_codegen::WeightLayoutTranspose2D::GetKey(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
auto actual_B = ctx_nuphar->ApplyWeightLayout(layout_key, B_name, B, true);
return nuphar::GemmExternCpu(A, actual_B, Y, false, true, B_name);
return nuphar::GemmExternCpu(A, actual_B, Y, transA, true, B_name);
} else {
return nuphar::MatMulExternCpu(A, B, Y, p_permute_A, p_permute_B, node.Name() + "_matmul_extern");
}
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/nuphar/kernel.h
@@ -101,7 +101,8 @@ class NupharKernelState {
NUPHAR_OP(Flatten, 11, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_OP(Floor, 6, DataTypeImpl::AllIEEEFloatTensorTypes()) \
NUPHAR_VERSIONED_OP(Gemm, 7, 8, DataTypeImpl::AllIEEEFloatTensorExceptHalfTypes()) \
NUPHAR_OP(Gemm, 9, DataTypeImpl::AllIEEEFloatTensorExceptHalfTypes()) \
NUPHAR_VERSIONED_OP(Gemm, 9, 10, DataTypeImpl::AllIEEEFloatTensorExceptHalfTypes()) \
NUPHAR_OP(Gemm, 11, DataTypeImpl::AllIEEEFloatTensorExceptHalfTypes()) \
NUPHAR_OP(GlobalAveragePool, 1, DataTypeImpl::AllIEEEFloatTensorExceptHalfTypes()) \
NUPHAR_OP(GlobalMaxPool, 1, DataTypeImpl::AllIEEEFloatTensorExceptHalfTypes()) \
NUPHAR_OP(Greater, 9, DataTypeImpl::AllFixedSizeTensorTypes()) \
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/nuphar/mti_x86/math/matmul_ops.cc
@@ -236,7 +236,7 @@ bool GemmExternCpu(
trans_a,
trans_b});
},
name, "", {})[0];
name + "_sgemm_cpu", "", {})[0];

return true;
}
@@ -303,7 +303,7 @@ bool MatMulExternCpu(
}
return topi::detail::call_packed(extern_args);
},
name, "", {})[0];
name + "_batched_matmul_cpu", "", {})[0];

return true;
}
97 changes: 97 additions & 0 deletions onnxruntime/core/providers/nuphar/scripts/model_editor.py
@@ -590,6 +590,99 @@ def convert_to_scan_model(input_model, output_model):

onnx.save(out_mp, output_model)

def gemm_to_matmul(node, nf, converted_initializers):
assert node.op_type == 'Gemm'

alpha = NodeFactory.get_attribute(node, 'alpha', 1.0)
beta = NodeFactory.get_attribute(node, 'beta', 1.0)
transA = NodeFactory.get_attribute(node, 'transA', 0)
transB = NodeFactory.get_attribute(node, 'transB', 0)

A = node.input[0]
B = node.input[1]
Y = node.output[0]

with nf.scoped_prefix(node.name) as scoped_prefix:
if alpha != 1.0:
alpha_name = node.name + '_Const_alpha'
nf.make_initializer(np.full((), alpha, dtype=np.float32), alpha_name)
alpha_A = nf.make_node('Mul', [alpha_name, A])
A = alpha_A.name

if transA:
if A in converted_initializers:
A = converted_initializers[A]
else:
A_initializer = nf.get_initializer(A)
# A is an initializer
if A_initializer is not None:
new_A = A + '_trans'
converted_initializers[A] = new_A
nf.make_initializer(np.transpose(A_initializer), new_A, in_main_graph=True)
nf.remove_initializer(A)
A = new_A
else:
A = nf.make_node('Transpose', A)
if transB:
if B in converted_initializers:
B = converted_initializers[B]
else:
B_initializer = nf.get_initializer(B)
# B is an initializer
if B_initializer is not None:
new_B = B + '_trans'
converted_initializers[B] = new_B
nf.make_initializer(np.transpose(B_initializer), new_B, in_main_graph=True)
nf.remove_initializer(B)
B = new_B
else:
B = nf.make_node('Transpose', B)

if len(node.input) != 3 or beta == 0.0:
nf.make_node('MatMul', [A, B], output_names=Y)
else:
AB = nf.make_node('MatMul', [A, B])
C = node.input[2]
if beta != 1.0:
beta_name = node.name + '_Const_beta'
nf.make_initializer(np.full((), beta, dtype=np.float32), beta_name)
C = nf.make_node('Mul', [beta_name, C])
nf.make_node('Add', [AB, C], output_names=Y)

def convert_gemm_to_matmul(input_model, output_model):
in_mp = onnx.load(input_model)
out_mp = onnx.ModelProto()
out_mp.CopyFrom(in_mp)
out_mp.ir_version = 5 # update ir version to avoid requirement of initializer in graph input
out_mp.graph.ClearField('node')
nf = NodeFactory(out_mp.graph)
# gemm_to_matmul will generate transposed weights if the corresponding input
# comes from initializer. We keep a map between the original and converted
# ones in case the original initializer is shared between Gemm ops
converted_initializers = {}

for in_n in in_mp.graph.node:
if in_n.op_type == 'Gemm':
gemm_to_matmul(in_n, nf, converted_initializers)
continue

out_n = out_mp.graph.node.add()
out_n.CopyFrom(in_n)
if in_n.op_type == 'Scan' or in_n.op_type == 'Loop':
in_subgraph = NodeFactory.get_attribute(in_n, 'body')
out_subgraph = NodeFactory.get_attribute(out_n, 'body')
out_subgraph.ClearField('node')
scan_nf = NodeFactory(out_mp.graph, out_subgraph)

for in_sn in in_subgraph.node:
if in_sn.op_type == 'Gemm':
gemm_to_matmul(in_sn, scan_nf, converted_initializers)
continue
out_sn = out_subgraph.node.add()
out_sn.CopyFrom(in_sn)

onnx.save(out_mp, output_model)

# Old models (ir_version < 4) are required to have initializers in graph inputs
# This is optional for ir_version >= 4
def remove_initializers_from_inputs(input_model, output_model, remain_inputs=[]):
@@ -713,6 +806,7 @@ def parse_arguments():
parser.add_argument('--mode', help='The modification mode',
choices=['to_scan',
'opt_inproj',
'gemm_to_matmul',
'remove_initializers_from_inputs'])
parser.add_argument('--input', help='The input model file', default=None)
parser.add_argument('--output', help='The output model file', default=None)
@@ -725,6 +819,9 @@
if args.mode == 'to_scan':
print('Convert LSTM/GRU/RNN to Scan...')
convert_to_scan_model(args.input, args.output)
elif args.mode == 'gemm_to_matmul':
print('Convert Gemm to MatMul')
convert_gemm_to_matmul(args.input, args.output)
elif args.mode == 'opt_inproj':
print('Optimize input projection in Scan...')
optimize_input_projection(args.input, args.output)
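
The new mode decomposes each Gemm (Y = alpha * A' * B' + beta * C) into Mul, Transpose, MatMul, and Add nodes, folding the Transpose into the initializer when the transposed input is a weight, and it is driven from the same command line as the existing to_scan and opt_inproj modes. A usage sketch (model paths are placeholders):

# Command-line invocation of the mode added above (paths are placeholders):
#   python onnxruntime/core/providers/nuphar/scripts/model_editor.py \
#       --mode gemm_to_matmul --input model.onnx --output model_matmul.onnx
#
# Or programmatically, assuming the script is importable as a module:
from model_editor import convert_gemm_to_matmul

convert_gemm_to_matmul('model.onnx', 'model_matmul.onnx')
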
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/nuphar/scripts/model_quantizer.py
@@ -258,7 +258,7 @@ def convert_matmul_model(input_model, output_model, only_for_scan=False, share_i

out_n = out_mp.graph.node.add()
out_n.CopyFrom(in_n)
if in_n.op_type == 'Scan':
if in_n.op_type == 'Scan' or in_n.op_type == 'Loop':
in_subgraph = NodeFactory.get_attribute(in_n, 'body')
out_subgraph = NodeFactory.get_attribute(out_n, 'body')
out_subgraph.ClearField('node')
@@ -298,4 +298,4 @@ def parse_arguments():
print('Quantize MatMul to MatMulInteger...')
assert not args.export_qcfg_json or args.qcfg_json, "--qcfg_json must be specified when --export_qcfg_json is used"
convert_matmul_model(args.input, args.output, args.only_for_scan, args.share_input_quantization, args.default_qcfg, args.qcfg_json, args.export_qcfg_json)
print('Done!')
print('Done!')
14 changes: 14 additions & 0 deletions onnxruntime/test/providers/cpu/math/gemm_test.cc
@@ -302,5 +302,19 @@ TEST(GemmOpTest, GemmNoBiasOpset11) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kNGraphExecutionProvider, kTensorrtExecutionProvider});
}

TEST(GemmOpTest, GemmWithAlphaOpset11) {
OpTester test("Gemm", 11);

test.AddAttribute("alpha", 2.0f);

test.AddInput<float>("A", {2, 2},
{1.0f, 2.0f, 3.0f, 4.0f});
test.AddInput<float>("B", {2, 2}, std::vector<float>(4, 1.0f));
test.AddOutput<float>("Y", {2, 2},
{6.0f, 6.0f, 14.0f, 14.0f});
// NGraph and tensorRT don't seem to support missing bias
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kNGraphExecutionProvider, kTensorrtExecutionProvider});
}

} // namespace test
} // namespace onnxruntime
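
As a quick sanity check on the expected values in GemmWithAlphaOpset11 above (a NumPy sketch, independent of the test framework):

import numpy as np

# alpha = 2, no bias: Y = 2 * (A @ B)
A = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
B = np.ones((2, 2), dtype=np.float32)
print(2.0 * (A @ B))  # [[ 6.  6.]
                      #  [14. 14.]]  -- matches the expected Y above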