From 88b1e16dc726b21ab23f896a02ee29ad1f07b334 Mon Sep 17 00:00:00 2001 From: Taylor Yeonbok Lee Date: Tue, 9 Aug 2022 17:37:14 +0900 Subject: [PATCH] Fix error in creation of gemm impl (fused inputs should not be added to "input" of kernel arg) --- src/plugins/intel_gpu/src/graph/impls/ocl/gemm.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/gemm.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/gemm.cpp index 1a83e0be0aa64a..c642c50d7d070d 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/gemm.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/gemm.cpp @@ -115,12 +115,18 @@ struct gemm_impl : typed_primitive_impl_ocl { }; const auto input_layouts = get_gemm_input_layouts(impl_param->input_layouts, impl_param->output_layout); const auto output_layout = get_gemm_output_layout(input_layouts, impl_param->output_layout); + + auto first_fused_input_idx = input_layouts.size(); + const auto fused_descs = impl_param->fused_desc; + if (fused_descs.size() > 0) { + first_fused_input_idx = fused_descs[0].dep_start_idx; + } auto gemm_params = get_default_params(*impl_param, 1); auto gemm_optional_params = get_default_optional_params(arg.get_program()); gemm_params.inputs.clear(); - for (size_t i = 0; i < input_layouts.size(); i++) { + for (size_t i = 0; i < std::min(input_layouts.size(), first_fused_input_idx); i++) { gemm_params.inputs.push_back(convert_data_tensor(input_layouts[i])); } gemm_params.outputs[0] = convert_data_tensor(output_layout);