From c9430d3c1b982e6de8289b88efb918d04f8328ee Mon Sep 17 00:00:00 2001
From: alexey-varyzgin
Date: Thu, 27 Jan 2022 13:22:42 +0300
Subject: [PATCH] gemm_convolution: memory access fix

---
 src/cpu/x64/jit_gemm_convolution_utils.cpp | 95 ++++++++++++++++++++--
 1 file changed, 86 insertions(+), 9 deletions(-)

diff --git a/src/cpu/x64/jit_gemm_convolution_utils.cpp b/src/cpu/x64/jit_gemm_convolution_utils.cpp
index 542f0e565d8..03d91598d64 100644
--- a/src/cpu/x64/jit_gemm_convolution_utils.cpp
+++ b/src/cpu/x64/jit_gemm_convolution_utils.cpp
@@ -84,6 +84,8 @@ struct jit_pp_kernel_t : pp_kernel_t, public jit_generator {
 
 private:
     void generate() override;
+    void copy_elems(const Xbyak::Reg64 &dst, const Xbyak::Reg64 &src, const Xbyak::Reg64 &size, const int elemSize);
+    void foreach (const Xbyak::Reg64 &idx, size_t step, const Xbyak::Reg64 &end, std::function<void(const Xbyak::Reg64 &)> &&fn);
 
     struct ker_args_t {
         float *dst;
@@ -142,6 +144,47 @@ struct jit_pp_kernel_t : pp_kernel_t, public jit_generator {
     Vmm vreg_bias(int idx) { return Vmm(idx_vreg_bias(idx)); };
 };
 
+template <cpu_isa_t isa>
+void jit_pp_kernel_t<isa>::foreach (const Xbyak::Reg64 &idx, size_t step,
+        const Xbyak::Reg64 &end, std::function<void(const Xbyak::Reg64 &)> &&fn)
+{
+    Xbyak::Label loop, exit;
+
+    L(loop);
+    cmp(idx, end);
+    jge(exit);
+
+    fn(idx);
+
+    add(idx, step);
+    jmp(loop);
+    L(exit);
+}
+
+template <cpu_isa_t isa>
+void jit_pp_kernel_t<isa>::copy_elems(const Xbyak::Reg64 &dst,
+        const Xbyak::Reg64 &src, const Xbyak::Reg64 &size, const int elemSize) {
+    push(rsi);
+    push(r13);
+
+    xor_(rsi, rsi);
+
+    if (elemSize == 1) {
+        foreach(rsi, 1, size, [&, this](const Xbyak::Reg64 &idx) {
+            mov(r13b, byte[src + idx * elemSize]);
+            mov(byte[dst + idx * elemSize], r13b);
+        });
+    } else if (elemSize == 4) {
+        foreach(rsi, 1, size, [&, this](const Xbyak::Reg64 &idx) {
+            mov(r13d, dword[src + idx * elemSize]);
+            mov(dword[dst + idx * elemSize], r13d);
+        });
+    }
+
+    pop(r13);
+    pop(rsi);
+}
+
 template <cpu_isa_t isa>
 void jit_pp_kernel_t<isa>::generate() {
     using namespace Xbyak;
@@ -161,7 +204,18 @@ void jit_pp_kernel_t<isa>::generate() {
         mov(reg_table, l_table);
     }
 
-    auto apply_post_ops = [&]() {
+    auto store_to_stack = [&](const Reg64 &from, const Reg64 &size) {
+        sub(rsp, vlen * sizeof(float));
+        mov(r8, rsp);
+        copy_elems(r8, from, size, sizeof(float));
+    };
+
+    auto load_from_stack = [&](const Vmm &to) {
+        uni_vmovups(to, ptr[rsp]);
+        add(rsp, vlen * sizeof(float));
+    };
+
+    auto apply_post_ops = [&](bool apply_mask) {
         int eltwise_inj_idx = 0;
         int depthwise_inj_idx = 0;
         auto vreg_dst_ = vreg_dst(0);
@@ -176,8 +230,20 @@ void jit_pp_kernel_t<isa>::generate() {
                 mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
                 lea(reg_d_weights, ptr[reg_d_weights + reg_oc_offset * sizeof(float)]);
                 lea(reg_d_bias, ptr[reg_d_bias + reg_oc_offset * sizeof(float)]);
+                if (apply_mask) {
+                    store_to_stack(reg_d_weights, reg_tmp);
+                    mov(reg_d_weights, rsp);
+
+                    if (post_op.depthwise.alg == dnnl_depthwise_scale_shift) {
+                        store_to_stack(reg_d_bias, reg_tmp);
+                        mov(reg_d_bias, rsp);
+                    }
+                }
                 jit_depthwise_injectors_[depthwise_inj_idx]->compute_vector_range(
                         vreg_dst_.getIdx(), vreg_dst_.getIdx() + 1, reg_d_weights, reg_d_bias, true);
+                if (apply_mask) {
+                    add(rsp, (post_op.depthwise.alg == dnnl_depthwise_scale_shift ? 2 : 1) * vlen * sizeof(float));
+                }
                 depthwise_inj_idx++;
             } else if (post_op.is_quantization()) {
                 bool do_dequantization = post_op.quantization.alg == alg_kind::quantization_quantize_dequantize;
@@ -243,6 +309,10 @@ void jit_pp_kernel_t<isa>::generate() {
     // Load accumulated value, convert to float, apply bias (if any), scaling,
     // and eltwise (if any); then convert to destination type and store
     auto compute = [&](bool apply_mask) {
+        if (apply_mask) {
+            push(r8);
+        }
+
         auto dst_addr = ptr[reg_dst];
         auto vreg_dst_ = vreg_dst(0);
         if (isa == avx512_common) {
@@ -251,11 +321,8 @@ void jit_pp_kernel_t<isa>::generate() {
             uni_vmovups(vreg_dst_, dst_addr);
         } else {
             if (apply_mask) {
-                if (isa != sse41) {
-                    uni_vblendvps(vreg_dst_, vreg_zero, dst_addr, vreg_mask);
-                } else {
-                    uni_vmovups(vreg_dst_, dst_addr);
-                }
+                store_to_stack(reg_dst, reg_tmp);
+                load_from_stack(vreg_dst_);
             } else {
                 uni_vmovups(vreg_dst_, dst_addr);
             }
@@ -270,7 +337,7 @@ void jit_pp_kernel_t<isa>::generate() {
             uni_vaddps(vreg_dst_, vreg_dst_, vreg_bias_);
         }
 
-        apply_post_ops();
+        apply_post_ops(apply_mask);
 
         if (isa == avx512_common) {
             uni_vmovups(dst_addr, vreg_dst_);
@@ -279,13 +346,20 @@ void jit_pp_kernel_t<isa>::generate() {
                 if (isa != sse41) {
                     vmaskmovps(dst_addr, vreg_mask, vreg_dst_);
                 } else {
-                    lea(reg_ptr_maskmovdqu_dst, dst_addr);
-                    maskmovdqu(vreg_dst_, vreg_mask);
+                    sub(rsp, vlen * sizeof(float));
+                    mov(r8, rsp);
+                    uni_vmovups(ptr[r8], vreg_dst_);
+                    copy_elems(reg_dst, r8, reg_tmp, sizeof(float));
+                    add(rsp, vlen * sizeof(float));
                 }
             } else {
                 uni_vmovups(dst_addr, vreg_dst_);
             }
         }
+
+        if (apply_mask) {
+            pop(r8);
+        }
     };
 
     Label loop_end;
@@ -303,6 +377,9 @@ void jit_pp_kernel_t<isa>::generate() {
         cmp(reg_len, vlen);
         jge(loop, T_NEAR);
     }
+
+    cmp(reg_tmp, 0);
+    je(loop_end, T_NEAR);
 
     L(loop_tail);
     mov(reg_tmp, reg_len); // reg_tmp is rcx, and we need cl for the shift