Skip to content

Commit

Permalink
Fix error in creation of gemm impl
Browse files Browse the repository at this point in the history
(fused inputs should not be added to "input" of kernel arg)
  • Loading branch information
yeonbok committed Aug 9, 2022
1 parent 3d8b166 commit 88b1e16
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/plugins/intel_gpu/src/graph/impls/ocl/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,18 @@ struct gemm_impl : typed_primitive_impl_ocl<gemm> {
};
const auto input_layouts = get_gemm_input_layouts(impl_param->input_layouts, impl_param->output_layout);
const auto output_layout = get_gemm_output_layout(input_layouts, impl_param->output_layout);

auto first_fused_input_idx = input_layouts.size();
const auto fused_descs = impl_param->fused_desc;
if (fused_descs.size() > 0) {
first_fused_input_idx = fused_descs[0].dep_start_idx;
}
auto gemm_params = get_default_params<kernel_selector::gemm_params>(*impl_param, 1);
auto gemm_optional_params =
get_default_optional_params<kernel_selector::gemm_optional_params>(arg.get_program());

gemm_params.inputs.clear();
for (size_t i = 0; i < input_layouts.size(); i++) {
for (size_t i = 0; i < std::min(input_layouts.size(), first_fused_input_idx); i++) {
gemm_params.inputs.push_back(convert_data_tensor(input_layouts[i]));
}
gemm_params.outputs[0] = convert_data_tensor(output_layout);
Expand Down

0 comments on commit 88b1e16

Please sign in to comment.