diff --git a/src/cpu/x64/brgemm/jit_brdgmm_kernel.cpp b/src/cpu/x64/brgemm/jit_brdgmm_kernel.cpp index 136f476880a..93e438c04ab 100644 --- a/src/cpu/x64/brgemm/jit_brdgmm_kernel.cpp +++ b/src/cpu/x64/brgemm/jit_brdgmm_kernel.cpp @@ -51,10 +51,10 @@ jit_brdgmm_kernel_base_t::jit_brdgmm_kernel_base_t(const brgemm_t &abrd) = {broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc}; const binary_injector::rhs_arg_static_params_t rhs_sp { - static_cast(vmm_b().getIdx()), r14, r15, preserve_gpr, - preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec), - GET_OFF(data_C_ptr_), dst_md_wrapper, - static_cast(n_vlen_tail()), k_mask, + static_cast(vmm_b().getIdx()), r14, r15, r13, + preserve_gpr, preserve_vmm, + GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(data_C_ptr_), + dst_md_wrapper, static_cast(n_vlen_tail()), k_mask, use_exact_tail_scalar_bcast}; const binary_injector::static_params_t bsp { this->param1, enabled_bcast_strategy, rhs_sp}; diff --git a/src/cpu/x64/brgemm/jit_brgemm_amx_uker.cpp b/src/cpu/x64/brgemm/jit_brgemm_amx_uker.cpp index 820ec8c3794..b18faeebff8 100644 --- a/src/cpu/x64/brgemm/jit_brgemm_amx_uker.cpp +++ b/src/cpu/x64/brgemm/jit_brgemm_amx_uker.cpp @@ -62,7 +62,7 @@ struct jit_brgemm_amx_uker_base_t : public jit_generator { broadcasting_strategy_t::no_broadcast}; const binary_injector::rhs_arg_static_params_t rhs_sp { static_cast(Xbyak::Zmm(1).getIdx()), this->r14, - this->r15, preserve_gpr, preserve_vmm, + this->r15, this->r13, preserve_gpr, preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(data_C_ptr_), dst_md_wrapper, static_cast(brg.ldb_tail), ld_tail_mask, use_exact_tail_scalar_bcast}; diff --git a/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp b/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp index 25eb484e3cb..67cc005eb0e 100644 --- a/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp +++ b/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp @@ -68,7 +68,7 @@ struct jit_brgemm_kernel_t : public jit_generator { broadcasting_strategy_t::no_broadcast}; const binary_injector::rhs_arg_static_params_t rhs_sp { static_cast(Vmm(1).getIdx()), this->r14, this->r15, - preserve_gpr, preserve_vmm, + this->r13, preserve_gpr, preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(data_C_ptr_), dst_md_wrapper, static_cast(brg.ldb_tail), ld_tail_mask, use_exact_tail_scalar_bcast}; diff --git a/src/cpu/x64/gemm_bf16_convolution.cpp b/src/cpu/x64/gemm_bf16_convolution.cpp index 1a4c171f208..f067e310be6 100644 --- a/src/cpu/x64/gemm_bf16_convolution.cpp +++ b/src/cpu/x64/gemm_bf16_convolution.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2021 Intel Corporation +* Copyright 2019-2022 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -111,7 +111,7 @@ gemm_bf16_convolution_fwd_t::pp_ker_t::pp_ker_t(const pd_t *pd) static constexpr size_t tail_size = 0; static constexpr bool use_exact_tail_scalar_bcast = false; const binary_injector::rhs_arg_static_params_t rhs_sp { - helper_vmm_idx, r13, r14, preserve_gpr, + helper_vmm_idx, r13, r14, r15, preserve_gpr, preserve_vmm, PARAM_OFF(post_ops_binary_rhs_arg_vec), PARAM_OFF(dst_orig), memory_desc_wrapper(pd->dst_md()), tail_size, kreg_rem_mask, use_exact_tail_scalar_bcast}; diff --git a/src/cpu/x64/injectors/jit_uni_binary_injector.cpp b/src/cpu/x64/injectors/jit_uni_binary_injector.cpp index d345074f775..e65887971b5 100644 --- a/src/cpu/x64/injectors/jit_uni_binary_injector.cpp +++ b/src/cpu/x64/injectors/jit_uni_binary_injector.cpp @@ -216,98 +216,62 @@ static_params_t::static_params_t(const Xbyak::Reg64 ¶m1, rhs_arg_static_params_t::rhs_arg_static_params_t( std::size_t rhs_dt_helper_vmm_idx, const Xbyak::Reg64 &rhs_addr_reg, - const Xbyak::Reg64 &rhs_helper_reg, bool preserve_gpr_helpers, - bool preserve_vmm_helper, std::size_t abi_param_offset, - const memory_desc_wrapper &dst_d, std::size_t tail_size, - bool use_exact_tail_scalar_bcast) - : rhs_arg_static_params_t(rhs_dt_helper_vmm_idx, rhs_addr_reg, - rhs_helper_reg, preserve_gpr_helpers, preserve_vmm_helper, - abi_param_offset, 0, dst_d, tail_size, Xbyak::Opmask(2), - use_exact_tail_scalar_bcast, rhs_helper_reg, - false /*is_opmask_set*/, false /*is_dst_orig_set*/) {} - -rhs_arg_static_params_t::rhs_arg_static_params_t( - std::size_t rhs_dt_helper_vmm_idx, const Xbyak::Reg64 &rhs_addr_reg, - const Xbyak::Reg64 &rhs_helper_reg, bool preserve_gpr_helpers, + const Xbyak::Reg64 &rhs_helper_reg, + const Xbyak::Reg64 &rhs_addr_cache_reg, bool preserve_gpr_helpers, bool preserve_vmm_helper, std::size_t abi_param_offset, std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d, std::size_t tail_size, bool use_exact_tail_scalar_bcast) : rhs_arg_static_params_t(rhs_dt_helper_vmm_idx, rhs_addr_reg, - rhs_helper_reg, preserve_gpr_helpers, preserve_vmm_helper, - abi_param_offset, dst_orig_offset, dst_d, tail_size, - Xbyak::Opmask(2), use_exact_tail_scalar_bcast, rhs_helper_reg, - false /*is_opmask_set*/, true /*is_dst_orig_set*/) {} - -rhs_arg_static_params_t::rhs_arg_static_params_t( - std::size_t rhs_dt_helper_vmm_idx, const Xbyak::Reg64 &rhs_addr_reg, - const Xbyak::Reg64 &rhs_helper_reg, bool preserve_gpr_helpers, - bool preserve_vmm_helper, std::size_t abi_param_offset, - const memory_desc_wrapper &dst_d, std::size_t tail_size, - const Xbyak::Opmask &tail_opmask, bool use_exact_tail_scalar_bcast, std::size_t rhs_prelu_helper_vmm_idx) - : rhs_arg_static_params_t(rhs_dt_helper_vmm_idx, rhs_addr_reg, - rhs_helper_reg, preserve_gpr_helpers, preserve_vmm_helper, - abi_param_offset, 0, dst_d, tail_size, tail_opmask, - use_exact_tail_scalar_bcast, rhs_helper_reg, true /*is_opmask_set*/, - false /*is_dst_orig_set*/) { - this->rhs_prelu_helper_vmm_idx = rhs_prelu_helper_vmm_idx; -} + rhs_helper_reg, rhs_addr_cache_reg, preserve_gpr_helpers, + preserve_vmm_helper, abi_param_offset, dst_orig_offset, dst_d, + tail_size, Xbyak::Opmask(2), use_exact_tail_scalar_bcast, + rhs_helper_reg, false /*is_opmask_set*/) {} rhs_arg_static_params_t::rhs_arg_static_params_t( std::size_t rhs_dt_helper_vmm_idx, const Xbyak::Reg64 &rhs_addr_reg, - const Xbyak::Reg64 &rhs_helper_reg, bool preserve_gpr_helpers, + const Xbyak::Reg64 &rhs_helper_reg, + const Xbyak::Reg64 &rhs_addr_cache_reg, bool preserve_gpr_helpers, bool preserve_vmm_helper, std::size_t abi_param_offset, std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d, std::size_t tail_size, const Xbyak::Opmask &tail_opmask, bool use_exact_tail_scalar_bcast, std::size_t rhs_prelu_helper_vmm_idx) : rhs_arg_static_params_t(rhs_dt_helper_vmm_idx, rhs_addr_reg, - rhs_helper_reg, preserve_gpr_helpers, preserve_vmm_helper, - abi_param_offset, dst_orig_offset, dst_d, tail_size, tail_opmask, - use_exact_tail_scalar_bcast, rhs_helper_reg, true /*is_opmask_set*/, - true /*is_dst_orig_set*/) { - this->rhs_prelu_helper_vmm_idx = rhs_prelu_helper_vmm_idx; -} - -rhs_arg_static_params_t::rhs_arg_static_params_t( - std::size_t rhs_dt_helper_vmm_idx, const Xbyak::Reg64 &rhs_addr_reg, - const Xbyak::Reg64 &rhs_helper_reg, bool preserve_gpr_helpers, - bool preserve_vmm_helper, std::size_t abi_param_offset, - const memory_desc_wrapper &dst_d, std::size_t tail_size, - const Xbyak::Opmask &tail_opmask, const Xbyak::Reg64 ®_tail_size, - bool use_exact_tail_scalar_bcast, std::size_t rhs_prelu_helper_vmm_idx) - : rhs_arg_static_params_t(rhs_dt_helper_vmm_idx, rhs_addr_reg, - rhs_helper_reg, preserve_gpr_helpers, preserve_vmm_helper, - abi_param_offset, 0, dst_d, tail_size, tail_opmask, - use_exact_tail_scalar_bcast, reg_tail_size, true /*is_opmask_set*/, - false /*is_dst_orig_set*/) { + rhs_helper_reg, rhs_addr_cache_reg, preserve_gpr_helpers, + preserve_vmm_helper, abi_param_offset, dst_orig_offset, dst_d, + tail_size, tail_opmask, use_exact_tail_scalar_bcast, rhs_helper_reg, + true /*is_opmask_set*/) { this->rhs_prelu_helper_vmm_idx = rhs_prelu_helper_vmm_idx; } rhs_arg_static_params_t::rhs_arg_static_params_t( std::size_t rhs_dt_helper_vmm_idx, const Xbyak::Reg64 &rhs_addr_reg, - const Xbyak::Reg64 &rhs_helper_reg, bool preserve_gpr_helpers, + const Xbyak::Reg64 &rhs_helper_reg, + const Xbyak::Reg64 &rhs_addr_cache_reg, bool preserve_gpr_helpers, bool preserve_vmm_helper, std::size_t abi_param_offset, std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d, std::size_t tail_size, const Xbyak::Opmask &tail_opmask, const Xbyak::Reg64 ®_tail_size, bool use_exact_tail_scalar_bcast, std::size_t rhs_prelu_helper_vmm_idx) : rhs_arg_static_params_t(rhs_dt_helper_vmm_idx, rhs_addr_reg, - rhs_helper_reg, preserve_gpr_helpers, preserve_vmm_helper, - abi_param_offset, dst_orig_offset, dst_d, tail_size, tail_opmask, - use_exact_tail_scalar_bcast, reg_tail_size, true /*is_opmask_set*/, - true /*is_dst_orig_set*/) { + rhs_helper_reg, rhs_addr_cache_reg, preserve_gpr_helpers, + preserve_vmm_helper, abi_param_offset, dst_orig_offset, dst_d, + tail_size, tail_opmask, use_exact_tail_scalar_bcast, reg_tail_size, + true /*is_opmask_set*/) { this->rhs_prelu_helper_vmm_idx = rhs_prelu_helper_vmm_idx; } rhs_arg_static_params_t::rhs_arg_static_params_t( std::size_t rhs_dt_helper_vmm_idx, const Xbyak::Reg64 &rhs_addr_reg, - const Xbyak::Reg64 &rhs_helper_reg, bool preserve_gpr_helpers, + const Xbyak::Reg64 &rhs_helper_reg, + const Xbyak::Reg64 &rhs_addr_cache_reg, bool preserve_gpr_helpers, bool preserve_vmm_helper, std::size_t abi_param_offset, std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d, std::size_t tail_size, const Xbyak::Opmask &tail_opmask, bool use_exact_tail_scalar_bcast, const Xbyak::Reg64 ®_tail_size, - bool is_opmask_set, bool is_dst_orig_set) + bool is_opmask_set) : rhs_dt_helper_vmm_idx(rhs_dt_helper_vmm_idx) , rhs_addr_reg(rhs_addr_reg) , rhs_helper_reg(rhs_helper_reg) + , rhs_addr_cache_reg(rhs_addr_cache_reg) , preserve_gpr_helpers(preserve_gpr_helpers) , preserve_vmm_helper(preserve_vmm_helper) , abi_param_offset(abi_param_offset) @@ -318,8 +282,7 @@ rhs_arg_static_params_t::rhs_arg_static_params_t( , use_exact_tail_scalar_bcast(use_exact_tail_scalar_bcast) , reg_tail_size(reg_tail_size) , is_tail(tail_size) - , is_opmask_set_(is_opmask_set) - , is_dst_orig_set_(is_dst_orig_set) {} + , is_opmask_set_(is_opmask_set) {} template jit_uni_binary_injector_t::jit_uni_binary_injector_t( @@ -345,45 +308,14 @@ static bool rhs_arg_params_differ(size_t vmm_idx1, size_t vmm_idx2, const auto &out_addr = rhs_arg_params.vmm_idx_to_out_addr; const auto &out_reg = rhs_arg_params.vmm_idx_to_out_reg; - - const auto &out_elem_off_addr = rhs_arg_params.vmm_idx_to_out_elem_off_addr; const auto &out_elem_off_val = rhs_arg_params.vmm_idx_to_out_elem_off_val; - const auto &out_off_oprnd = rhs_arg_params.vmm_idx_to_out_off_oprnd; - const auto &oc_off_addr = rhs_arg_params.vmm_idx_to_oc_elem_off_addr; - const auto &oc_off_val = rhs_arg_params.vmm_idx_to_oc_elem_off_val; - const auto &oc_off_oprnd = rhs_arg_params.vmm_idx_to_oc_off_oprnd; - const auto &sp_off_addr = rhs_arg_params.vmm_idx_to_sp_elem_off_addr; - const auto &sp_off_val = rhs_arg_params.vmm_idx_to_sp_elem_off_val; - const auto &sp_off_oprnd = rhs_arg_params.vmm_idx_to_sp_off_oprnd; - - if (rhs_broadcasting_strategy == broadcasting_strategy_t::scalar) { - return false; - } else if (rhs_broadcasting_strategy - == broadcasting_strategy_t::no_broadcast) { - return params_differ(out_addr, vmm_idx1, vmm_idx2) - || params_differ(out_reg, vmm_idx1, vmm_idx2) - || params_differ(out_elem_off_addr, vmm_idx1, vmm_idx2) - || params_differ(out_elem_off_val, vmm_idx1, vmm_idx2) - || params_differ(out_off_oprnd, vmm_idx1, vmm_idx2); - } else if (rhs_broadcasting_strategy == broadcasting_strategy_t::per_oc - || rhs_broadcasting_strategy - == broadcasting_strategy_t::per_oc_spatial) { - return params_differ(out_addr, vmm_idx1, vmm_idx2) - || params_differ(out_reg, vmm_idx1, vmm_idx2) - || params_differ(out_elem_off_val, vmm_idx1, vmm_idx2) - || params_differ(oc_off_addr, vmm_idx1, vmm_idx2) - || params_differ(oc_off_val, vmm_idx1, vmm_idx2) - || params_differ(oc_off_oprnd, vmm_idx1, vmm_idx2); - } else if (rhs_broadcasting_strategy - == broadcasting_strategy_t::per_mb_spatial) { + + if (rhs_broadcasting_strategy != broadcasting_strategy_t::scalar) { return params_differ(out_addr, vmm_idx1, vmm_idx2) || params_differ(out_reg, vmm_idx1, vmm_idx2) - || params_differ(out_elem_off_val, vmm_idx1, vmm_idx2) - || params_differ(sp_off_addr, vmm_idx1, vmm_idx2) - || params_differ(sp_off_val, vmm_idx1, vmm_idx2) - || params_differ(sp_off_oprnd, vmm_idx1, vmm_idx2); + || params_differ(out_elem_off_val, vmm_idx1, vmm_idx2); } - return true; + return false; } template @@ -511,8 +443,7 @@ void jit_uni_binary_injector_t::compute_vector_range( const int blk_size = dst_d.blocking_desc().inner_blks[0]; const bool use_offset_conversions = (!rhs_arg_params.vmm_idx_to_out_addr.empty() - || !rhs_arg_params.vmm_idx_to_out_reg.empty()) - && rhs_arg_static_params_.is_dst_orig_set(); + || !rhs_arg_params.vmm_idx_to_out_reg.empty()); const bool should_preserve_oc_offset_conversion_regs = use_offset_conversions && utils::one_of(rhs_broadcasting_strategy, @@ -538,6 +469,8 @@ void jit_uni_binary_injector_t::compute_vector_range( {rhs_arg_static_params_.rhs_addr_reg, rhs_arg_static_params_ .rhs_helper_reg, + rhs_arg_static_params_ + .rhs_addr_cache_reg, host_->rax, host_->rdx, host_->r8}) : rhs_arg_static_params_.preserve_gpr_helpers && should_preserve_mb_sp_offset_conversion_regs @@ -546,18 +479,20 @@ void jit_uni_binary_injector_t::compute_vector_range( .rhs_addr_reg, rhs_arg_static_params_ .rhs_helper_reg, + rhs_arg_static_params_ + .rhs_addr_cache_reg, host_->rax, host_->rdx, host_->r8, host_->r9}) : rhs_arg_static_params_ .preserve_gpr_helpers ? std::initializer_list< - Xbyak::Reg64>( - {rhs_arg_static_params_ - .rhs_addr_reg, - rhs_arg_static_params_ - .rhs_helper_reg, - host_->rax, - host_->rdx}) + Xbyak::Reg64>({rhs_arg_static_params_ + .rhs_addr_reg, + rhs_arg_static_params_ + .rhs_helper_reg, + rhs_arg_static_params_ + .rhs_addr_cache_reg, + host_->rax, host_->rdx}) : should_preserve_w_or_oc_offset_conversion_regs ? std::initializer_list< Xbyak::Reg64>( @@ -588,11 +523,12 @@ void jit_uni_binary_injector_t::compute_vector_range( // Phase 3 Apply binary post-op over all vmms. for (const auto vmm_idx : vmm_idxs) { - if (vmm_idx == start_idx + const bool is_start_idx = vmm_idx == start_idx; + if (is_start_idx || rhs_arg_params_differ(vmm_idx, vmm_idx - 1, rhs_arg_params, rhs_broadcasting_strategy)) { rhs_arg_addr = prepare_rhs_arg_addr(vmm_idx, rhs_arg_idx, post_op, - rhs_arg_params, rhs_broadcasting_strategy); + rhs_arg_params, rhs_broadcasting_strategy, is_start_idx); } const auto local_vmm_preservation = should_preserve_vmm( @@ -629,7 +565,8 @@ Xbyak::Address jit_uni_binary_injector_t::prepare_rhs_arg_addr( std::size_t vmm_idx, std::size_t rhs_arg_idx, const dnnl_post_ops::entry_t &post_op, const rhs_arg_dynamic_params_t &rhs_arg_params, - const broadcasting_strategy_t rhs_broadcasting_strategy) const { + const broadcasting_strategy_t rhs_broadcasting_strategy, + bool is_first) const { static constexpr auto rhs_arg_ptr_size = sizeof(const void *); const auto &rhs_addr_reg = rhs_arg_static_params_.rhs_addr_reg; @@ -638,40 +575,28 @@ Xbyak::Address jit_uni_binary_injector_t::prepare_rhs_arg_addr( const auto rhs_arg_elem_size = types::data_type_size(post_op.binary.src1_desc.data_type); - host_->mov(rhs_addr_reg, host_->ptr[param1_ + abi_param_offset]); - host_->mov(rhs_addr_reg, - host_->ptr[rhs_addr_reg + rhs_arg_idx * rhs_arg_ptr_size]); + if (is_first) { + host_->mov(rhs_addr_reg, host_->ptr[param1_ + abi_param_offset]); + host_->mov(rhs_addr_reg, + host_->ptr[rhs_addr_reg + rhs_arg_idx * rhs_arg_ptr_size]); + } switch (rhs_broadcasting_strategy) { case broadcasting_strategy_t::scalar: return host_->ptr_b[rhs_addr_reg]; case broadcasting_strategy_t::no_broadcast: { - append_offset_from_operand(rhs_arg_params.vmm_idx_to_out_off_oprnd, - vmm_idx, rhs_addr_reg, rhs_helper_reg, rhs_arg_elem_size); - append_offset_under_mem_addr( - rhs_arg_params.vmm_idx_to_out_elem_off_addr, vmm_idx, - rhs_addr_reg, rhs_helper_reg, rhs_arg_elem_size); - append_value_offset(rhs_arg_params.vmm_idx_to_out_elem_off_val, - vmm_idx, rhs_addr_reg, rhs_arg_elem_size); append_no_broadcast_offset(rhs_arg_params.vmm_idx_to_out_addr, rhs_arg_params.vmm_idx_to_out_reg, rhs_arg_params.vmm_idx_to_out_elem_off_val, vmm_idx, - rhs_addr_reg, rhs_helper_reg, rhs_arg_elem_size); + rhs_addr_reg, rhs_helper_reg, rhs_arg_elem_size, is_first); return host_->ptr[rhs_addr_reg]; } case broadcasting_strategy_t::per_oc: case broadcasting_strategy_t::per_oc_spatial: { - append_offset_from_operand(rhs_arg_params.vmm_idx_to_oc_off_oprnd, - vmm_idx, rhs_addr_reg, rhs_helper_reg, rhs_arg_elem_size); - append_offset_under_mem_addr( - rhs_arg_params.vmm_idx_to_oc_elem_off_addr, vmm_idx, - rhs_addr_reg, rhs_helper_reg, rhs_arg_elem_size); - append_value_offset(rhs_arg_params.vmm_idx_to_oc_elem_off_val, - vmm_idx, rhs_addr_reg, rhs_arg_elem_size); append_oc_offset(rhs_arg_params.vmm_idx_to_out_addr, rhs_arg_params.vmm_idx_to_out_reg, rhs_arg_params.vmm_idx_to_out_elem_off_val, vmm_idx, - rhs_addr_reg, rhs_helper_reg, rhs_arg_elem_size); + rhs_addr_reg, rhs_helper_reg, rhs_arg_elem_size, is_first); return rhs_broadcasting_strategy == broadcasting_strategy_t::per_oc_spatial @@ -679,47 +604,26 @@ Xbyak::Address jit_uni_binary_injector_t::prepare_rhs_arg_addr( : host_->ptr[rhs_addr_reg]; } case broadcasting_strategy_t::per_mb_spatial: { - append_offset_from_operand(rhs_arg_params.vmm_idx_to_sp_off_oprnd, - vmm_idx, rhs_addr_reg, rhs_helper_reg, rhs_arg_elem_size); - append_offset_under_mem_addr( - rhs_arg_params.vmm_idx_to_sp_elem_off_addr, vmm_idx, - rhs_addr_reg, rhs_helper_reg, rhs_arg_elem_size); - append_value_offset(rhs_arg_params.vmm_idx_to_sp_elem_off_val, - vmm_idx, rhs_addr_reg, rhs_arg_elem_size); append_mb_sp_offset(rhs_arg_params.vmm_idx_to_out_addr, rhs_arg_params.vmm_idx_to_out_reg, rhs_arg_params.vmm_idx_to_out_elem_off_val, vmm_idx, - rhs_addr_reg, rhs_helper_reg, rhs_arg_elem_size); + rhs_addr_reg, rhs_helper_reg, rhs_arg_elem_size, is_first); return host_->ptr[rhs_addr_reg]; } case broadcasting_strategy_t::per_mb_w: { - append_offset_from_operand(rhs_arg_params.vmm_idx_to_mb_w_off_oprnd, - vmm_idx, rhs_addr_reg, rhs_helper_reg, rhs_arg_elem_size); - append_offset_under_mem_addr( - rhs_arg_params.vmm_idx_to_mb_w_elem_off_addr, vmm_idx, - rhs_addr_reg, rhs_helper_reg, rhs_arg_elem_size); - append_value_offset(rhs_arg_params.vmm_idx_to_mb_w_elem_off_val, - vmm_idx, rhs_addr_reg, rhs_arg_elem_size); append_mb_w_offset(rhs_arg_params.vmm_idx_to_out_addr, rhs_arg_params.vmm_idx_to_out_reg, rhs_arg_params.vmm_idx_to_out_elem_off_val, vmm_idx, - rhs_addr_reg, rhs_helper_reg, rhs_arg_elem_size); + rhs_addr_reg, rhs_helper_reg, rhs_arg_elem_size, is_first); return host_->ptr[rhs_addr_reg]; } case broadcasting_strategy_t::per_w: { - append_offset_from_operand(rhs_arg_params.vmm_idx_to_w_off_oprnd, - vmm_idx, rhs_addr_reg, rhs_helper_reg, rhs_arg_elem_size); - append_offset_under_mem_addr( - rhs_arg_params.vmm_idx_to_w_elem_off_addr, vmm_idx, - rhs_addr_reg, rhs_helper_reg, rhs_arg_elem_size); - append_value_offset(rhs_arg_params.vmm_idx_to_w_elem_off_val, - vmm_idx, rhs_addr_reg, rhs_arg_elem_size); append_w_offset(rhs_arg_params.vmm_idx_to_out_addr, rhs_arg_params.vmm_idx_to_out_reg, rhs_arg_params.vmm_idx_to_out_elem_off_val, vmm_idx, - rhs_addr_reg, rhs_helper_reg, rhs_arg_elem_size); + rhs_addr_reg, rhs_helper_reg, rhs_arg_elem_size, is_first); return host_->ptr[rhs_addr_reg]; } @@ -729,64 +633,13 @@ Xbyak::Address jit_uni_binary_injector_t::prepare_rhs_arg_addr( return host_->ptr[rhs_addr_reg]; } -template -void jit_uni_binary_injector_t::append_offset_from_operand( - const std::map &vmm_idx_to_elem_operand_off, - int vmm_idx, const Xbyak::Reg64 &addr_reg, const Xbyak::Reg64 &tmp_reg, - std::size_t elem_size_bytes) const { - - const auto it_operand_off = vmm_idx_to_elem_operand_off.find(vmm_idx); - if (it_operand_off != vmm_idx_to_elem_operand_off.end() - && !rhs_arg_static_params_.is_dst_orig_set()) { - if (elem_size_bytes == 1) { - host_->add(addr_reg, it_operand_off->second); - } else { - const int shift_val = std::log2(elem_size_bytes); - host_->mov(tmp_reg, it_operand_off->second); - host_->sal(tmp_reg, shift_val); - host_->add(addr_reg, tmp_reg); - } - } -} - -template -void jit_uni_binary_injector_t::append_offset_under_mem_addr( - const std::map &vmm_idx_to_elem_addr_off, - int vmm_idx, const Xbyak::Reg64 &addr_reg, const Xbyak::Reg64 &tmp_reg, - std::size_t elem_size_bytes) const { - - const auto it_off_addr = vmm_idx_to_elem_addr_off.find(vmm_idx); - if (it_off_addr != vmm_idx_to_elem_addr_off.end() - && !rhs_arg_static_params_.is_dst_orig_set()) { - if (elem_size_bytes == 1) { - host_->add(addr_reg, it_off_addr->second); - } else { - const int shift_val = std::log2(elem_size_bytes); - host_->mov(tmp_reg, it_off_addr->second); - host_->sal(tmp_reg, shift_val); - host_->add(addr_reg, tmp_reg); - } - } -} - -template -void jit_uni_binary_injector_t::append_value_offset( - const std::map &vmm_idx_to_elem_val_off, int vmm_idx, - const Xbyak::Reg64 &addr_reg, std::size_t elem_size_bytes) const { - - const auto it_off_val = vmm_idx_to_elem_val_off.find(vmm_idx); - if (it_off_val != vmm_idx_to_elem_val_off.end() - && !rhs_arg_static_params_.is_dst_orig_set()) - host_->add(addr_reg, it_off_val->second * elem_size_bytes); -} - template void jit_uni_binary_injector_t::append_no_broadcast_offset( const std::map &vmm_idx_to_out_addr, const std::map &vmm_idx_to_out_reg, const std::map &vmm_idx_to_out_elem_off_val, int vmm_idx, const Xbyak::Reg64 &addr_reg, const Xbyak::Reg64 &tmp_reg, - std::size_t elem_size_bytes) const { + std::size_t elem_size_bytes, bool is_first) const { const auto it_out_addr = vmm_idx_to_out_addr.find(vmm_idx); const auto it_out_reg = vmm_idx_to_out_reg.find(vmm_idx); @@ -794,31 +647,35 @@ void jit_uni_binary_injector_t::append_no_broadcast_offset( const bool is_out_addr = it_out_addr != vmm_idx_to_out_addr.end(); const bool is_out_reg = it_out_reg != vmm_idx_to_out_reg.end(); if (is_out_addr || is_out_reg) { - assert(rhs_arg_static_params_.is_dst_orig_set() - && "dst base addr offset not set"); Xbyak::Address out_addr = is_out_addr ? it_out_addr->second : host_->ptr[it_out_reg->second]; const auto it_off_val = vmm_idx_to_out_elem_off_val.find(vmm_idx); - calculate_no_broadcast(out_addr, - it_off_val != vmm_idx_to_out_elem_off_val.end() - ? it_off_val->second - : 0, - tmp_reg); - - if (elem_size_bytes > 1) { - const int shift_val = std::log2(elem_size_bytes); - host_->sal(tmp_reg, shift_val); + const auto &addr_cache_reg = rhs_arg_static_params_.rhs_addr_cache_reg; + + if (is_first) { + calculate_no_broadcast_base(out_addr, tmp_reg); + if (elem_size_bytes > 1) { + const int shift_val = std::log2(elem_size_bytes); + host_->sal(tmp_reg, shift_val); + } + host_->add(addr_reg, tmp_reg); + host_->mov(addr_cache_reg, addr_reg); + } else { + host_->mov(addr_reg, addr_cache_reg); + } + + if (it_off_val != vmm_idx_to_out_elem_off_val.end()) { + calculate_no_broadcast_partial( + it_off_val->second, tmp_reg, elem_size_bytes); + host_->add(addr_reg, tmp_reg); } - host_->add(addr_reg, tmp_reg); } } template -void jit_uni_binary_injector_t::calculate_no_broadcast( - Xbyak::Address addr, std::size_t offset, - const Xbyak::Reg64 &out_reg) const { +void jit_uni_binary_injector_t::calculate_no_broadcast_base( + Xbyak::Address addr, const Xbyak::Reg64 &out_reg) const { host_->lea(out_reg, addr); - if (offset > 0) host_->add(out_reg, offset); host_->sub(out_reg, host_->ptr[param1_ + rhs_arg_static_params_.dst_orig_offset]); host_->shr(out_reg, @@ -826,13 +683,24 @@ void jit_uni_binary_injector_t::calculate_no_broadcast( rhs_arg_static_params_.dst_d.data_type()))); } +template +void jit_uni_binary_injector_t::calculate_no_broadcast_partial( + const std::size_t offset, const Xbyak::Reg64 &out_reg, + std::size_t elem_size_bytes) const { + const auto offset_adj = offset >> math::ilog2q(types::data_type_size( + rhs_arg_static_params_.dst_d.data_type())); + host_->mov(out_reg, + elem_size_bytes > 1 ? offset_adj << math::ilog2q(elem_size_bytes) + : offset_adj); +} + template void jit_uni_binary_injector_t::append_oc_offset( const std::map &vmm_idx_to_out_addr, const std::map &vmm_idx_to_out_reg, const std::map &vmm_idx_to_out_elem_off_val, int vmm_idx, const Xbyak::Reg64 &addr_reg, const Xbyak::Reg64 &tmp_reg, - std::size_t elem_size_bytes) const { + std::size_t elem_size_bytes, bool is_first) const { const auto it_out_addr = vmm_idx_to_out_addr.find(vmm_idx); const auto it_out_reg = vmm_idx_to_out_reg.find(vmm_idx); @@ -841,60 +709,84 @@ void jit_uni_binary_injector_t::append_oc_offset( const bool is_out_reg = it_out_reg != vmm_idx_to_out_reg.end(); if (is_out_addr || is_out_reg) { - assert(rhs_arg_static_params_.is_dst_orig_set() - && "dst base addr offset not set"); Xbyak::Address out_addr = is_out_addr ? it_out_addr->second : host_->ptr[it_out_reg->second]; const auto it_off_val = vmm_idx_to_out_elem_off_val.find(vmm_idx); - calculate_no_broadcast(out_addr, - it_off_val != vmm_idx_to_out_elem_off_val.end() - ? it_off_val->second - : 0, - tmp_reg); - - const auto rax = host_->rax; - const auto rdx = host_->rdx; - const auto r8 = host_->r8; - - const injector_utils::conditional_register_preserve_guard_t - register_guard {is_out_reg ? utils::one_of( - it_out_reg->second, rax, rdx, r8) - : false, - host_, {it_out_reg->second}}; + const auto &addr_cache_reg = rhs_arg_static_params_.rhs_addr_cache_reg; const auto dst_d = rhs_arg_static_params_.dst_d; const auto strides = dst_d.blocking_desc().strides; const auto layout = injector_utils::get_layout_type(dst_d); - switch (layout) { - case injector_utils::layout_t::ncsp: - calculate_oc_ncsp(strides, tmp_reg); - break; - case injector_utils::layout_t::c_blocked: - calculate_oc_blocked(strides, tmp_reg); - break; - case injector_utils::layout_t::nspc: - calculate_oc_nspc(strides, tmp_reg); - break; - case injector_utils::layout_t::cspn: - calculate_oc_cspn(strides, tmp_reg); - break; - default: assert(!"Unknown layout"); - } + if (is_first) { + calculate_no_broadcast_base(out_addr, tmp_reg); + + const auto rax = host_->rax; + const auto rdx = host_->rdx; + const auto r8 = host_->r8; + + const injector_utils::conditional_register_preserve_guard_t + register_guard {is_out_reg ? utils::one_of( + it_out_reg->second, rax, rdx, r8) + : false, + host_, {it_out_reg->second}}; + + switch (layout) { + case injector_utils::layout_t::ncsp: + calculate_oc_ncsp_base(strides, tmp_reg); + break; + case injector_utils::layout_t::c_blocked: + calculate_oc_blocked_base(strides, tmp_reg); + break; + case injector_utils::layout_t::nspc: + calculate_oc_nspc_base(strides, tmp_reg); + break; + case injector_utils::layout_t::cspn: + calculate_oc_cspn_base(strides, tmp_reg); + break; + default: assert(!"Unknown layout"); + } - if (elem_size_bytes == 1) { - host_->add(addr_reg, rax); + if (elem_size_bytes == 1) { + host_->add(addr_reg, rax); + } else { + const int shift_val = std::log2(elem_size_bytes); + host_->mov(tmp_reg, rax); + host_->sal(tmp_reg, shift_val); + host_->add(addr_reg, tmp_reg); + } + host_->mov(addr_cache_reg, addr_reg); } else { - const int shift_val = std::log2(elem_size_bytes); - host_->mov(tmp_reg, rax); - host_->sal(tmp_reg, shift_val); + host_->mov(addr_reg, addr_cache_reg); + } + + if (it_off_val != vmm_idx_to_out_elem_off_val.end()) { + switch (layout) { + case injector_utils::layout_t::ncsp: + calculate_oc_ncsp_partial(strides, it_off_val->second, + tmp_reg, elem_size_bytes); + break; + case injector_utils::layout_t::c_blocked: + calculate_oc_blocked_partial(strides, it_off_val->second, + tmp_reg, elem_size_bytes); + break; + case injector_utils::layout_t::nspc: + calculate_oc_nspc_partial(strides, it_off_val->second, + tmp_reg, elem_size_bytes); + break; + case injector_utils::layout_t::cspn: + calculate_oc_cspn_partial(strides, it_off_val->second, + tmp_reg, elem_size_bytes); + break; + default: assert(!"Unknown layout"); + } host_->add(addr_reg, tmp_reg); } } } template -void jit_uni_binary_injector_t::calculate_oc_ncsp( +void jit_uni_binary_injector_t::calculate_oc_ncsp_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const { // c = (offset % strides[0]) / strides[1] // output = rax @@ -912,7 +804,22 @@ void jit_uni_binary_injector_t::calculate_oc_ncsp( } template -void jit_uni_binary_injector_t::calculate_oc_blocked( +void jit_uni_binary_injector_t::calculate_oc_ncsp_partial( + const dim_t *strides, const std::size_t offset, + const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const { + // c = (offset % strides[0]) / strides[1] + const auto offset_adj + = ((offset >> math::ilog2q(types::data_type_size( + rhs_arg_static_params_.dst_d.data_type()))) + % strides[0]) + / strides[1]; + host_->mov(tmp_reg, + elem_size_bytes > 1 ? offset_adj << math::ilog2q(elem_size_bytes) + : offset_adj); +} + +template +void jit_uni_binary_injector_t::calculate_oc_blocked_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const { // c = ((offset % strides[0]) / strides[1]) * strides[ndims - 1] + offset % blk_size // output = rax @@ -924,7 +831,7 @@ void jit_uni_binary_injector_t::calculate_oc_blocked( const auto rdx = host_->rdx; const auto r8 = host_->r8; - calculate_oc_ncsp(strides, tmp_reg); + calculate_oc_ncsp_base(strides, tmp_reg); if (blk_size > simd_w) { // extract c % blk_size @@ -943,7 +850,23 @@ void jit_uni_binary_injector_t::calculate_oc_blocked( } template -void jit_uni_binary_injector_t::calculate_oc_nspc( +void jit_uni_binary_injector_t::calculate_oc_blocked_partial( + const dim_t *strides, const std::size_t offset, + const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const { + // c = ((offset % strides[0]) / strides[1]) * strides[ndims - 1] + offset % blk_size + const auto dst_d = rhs_arg_static_params_.dst_d; + const int blk_size = dst_d.blocking_desc().inner_blks[0]; + const auto offset_shr = offset >> math::ilog2q(types::data_type_size( + rhs_arg_static_params_.dst_d.data_type())); + const auto offset_adj = ((offset_shr % strides[0]) / strides[1]) * blk_size + + offset_shr % blk_size; + host_->mov(tmp_reg, + elem_size_bytes > 1 ? offset_adj << math::ilog2q(elem_size_bytes) + : offset_adj); +} + +template +void jit_uni_binary_injector_t::calculate_oc_nspc_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const { // c = offset % C // output = rax @@ -959,7 +882,21 @@ void jit_uni_binary_injector_t::calculate_oc_nspc( } template -void jit_uni_binary_injector_t::calculate_oc_cspn( +void jit_uni_binary_injector_t::calculate_oc_nspc_partial( + const dim_t *strides, const std::size_t offset, + const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const { + // c = offset % C + const auto C = rhs_arg_static_params_.dst_d.dims()[1]; + const auto offset_adj = (offset >> math::ilog2q(types::data_type_size( + rhs_arg_static_params_.dst_d.data_type()))) + % C; + host_->mov(tmp_reg, + elem_size_bytes > 1 ? offset_adj << math::ilog2q(elem_size_bytes) + : offset_adj); +} + +template +void jit_uni_binary_injector_t::calculate_oc_cspn_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const { // c = offset / strides[1] // output = rax @@ -972,13 +909,26 @@ void jit_uni_binary_injector_t::calculate_oc_cspn( host_->div(tmp_reg); } +template +void jit_uni_binary_injector_t::calculate_oc_cspn_partial( + const dim_t *strides, const std::size_t offset, + const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const { + // c = offset / strides[1] + const auto offset_adj = (offset >> math::ilog2q(types::data_type_size( + rhs_arg_static_params_.dst_d.data_type()))) + / strides[1]; + host_->mov(tmp_reg, + elem_size_bytes > 1 ? offset_adj << math::ilog2q(elem_size_bytes) + : offset_adj); +} + template void jit_uni_binary_injector_t::append_mb_sp_offset( const std::map &vmm_idx_to_out_addr, const std::map &vmm_idx_to_out_reg, const std::map &vmm_idx_to_out_elem_off_val, int vmm_idx, const Xbyak::Reg64 &addr_reg, const Xbyak::Reg64 &tmp_reg, - std::size_t elem_size_bytes) const { + std::size_t elem_size_bytes, bool is_first) const { const auto it_out_addr = vmm_idx_to_out_addr.find(vmm_idx); const auto it_out_reg = vmm_idx_to_out_reg.find(vmm_idx); @@ -987,61 +937,86 @@ void jit_uni_binary_injector_t::append_mb_sp_offset( const bool is_out_reg = it_out_reg != vmm_idx_to_out_reg.end(); if (is_out_addr || is_out_reg) { - assert(rhs_arg_static_params_.is_dst_orig_set() - && "dst base addr offset not set"); Xbyak::Address out_addr = is_out_addr ? it_out_addr->second : host_->ptr[it_out_reg->second]; const auto it_off_val = vmm_idx_to_out_elem_off_val.find(vmm_idx); - calculate_no_broadcast(out_addr, - it_off_val != vmm_idx_to_out_elem_off_val.end() - ? it_off_val->second - : 0, - tmp_reg); - - const auto rax = host_->rax; - const auto rdx = host_->rdx; - const auto r8 = host_->r8; - const auto r9 = host_->r9; - - const injector_utils::conditional_register_preserve_guard_t - register_guard {is_out_reg ? utils::one_of( - it_out_reg->second, rax, rdx, r8, r9) - : false, - host_, {it_out_reg->second}}; + const auto &addr_cache_reg = rhs_arg_static_params_.rhs_addr_cache_reg; const auto dst_d = rhs_arg_static_params_.dst_d; const auto strides = dst_d.blocking_desc().strides; const auto layout = injector_utils::get_layout_type(dst_d); - switch (layout) { - case injector_utils::layout_t::ncsp: - calculate_mb_sp_ncsp(strides, tmp_reg); - break; - case injector_utils::layout_t::c_blocked: - calculate_mb_sp_blocked(strides, tmp_reg); - break; - case injector_utils::layout_t::nspc: - calculate_mb_sp_nspc(strides, tmp_reg); - break; - case injector_utils::layout_t::cspn: - calculate_mb_sp_cspn(strides, tmp_reg); - break; - default: assert(!"Unknown layout"); - } + if (is_first) { + calculate_no_broadcast_base(out_addr, tmp_reg); + + const auto rax = host_->rax; + const auto rdx = host_->rdx; + const auto r8 = host_->r8; + const auto r9 = host_->r9; + + const injector_utils::conditional_register_preserve_guard_t + register_guard {is_out_reg + ? utils::one_of(it_out_reg->second, rax, + rdx, r8, r9) + : false, + host_, {it_out_reg->second}}; + + switch (layout) { + case injector_utils::layout_t::ncsp: + calculate_mb_sp_ncsp_base(strides, tmp_reg); + break; + case injector_utils::layout_t::c_blocked: + calculate_mb_sp_blocked_base(strides, tmp_reg); + break; + case injector_utils::layout_t::nspc: + calculate_mb_sp_nspc_base(strides, tmp_reg); + break; + case injector_utils::layout_t::cspn: + calculate_mb_sp_cspn_base(strides, tmp_reg); + break; + default: assert(!"Unknown layout"); + } - if (elem_size_bytes == 1) { - host_->add(addr_reg, rax); + if (elem_size_bytes == 1) { + host_->add(addr_reg, rax); + } else { + const int shift_val = std::log2(elem_size_bytes); + host_->mov(tmp_reg, rax); + host_->sal(tmp_reg, shift_val); + host_->add(addr_reg, tmp_reg); + } + host_->mov(addr_cache_reg, addr_reg); } else { - const int shift_val = std::log2(elem_size_bytes); - host_->mov(tmp_reg, rax); - host_->sal(tmp_reg, shift_val); + host_->mov(addr_reg, addr_cache_reg); + } + + if (it_off_val != vmm_idx_to_out_elem_off_val.end()) { + switch (layout) { + case injector_utils::layout_t::ncsp: + calculate_mb_sp_ncsp_partial(strides, it_off_val->second, + tmp_reg, elem_size_bytes); + break; + case injector_utils::layout_t::c_blocked: + calculate_mb_sp_blocked_partial(strides, it_off_val->second, + tmp_reg, elem_size_bytes); + break; + case injector_utils::layout_t::nspc: + calculate_mb_sp_nspc_partial(strides, it_off_val->second, + tmp_reg, elem_size_bytes); + break; + case injector_utils::layout_t::cspn: + calculate_mb_sp_cspn_partial(strides, it_off_val->second, + tmp_reg, elem_size_bytes); + break; + default: assert(!"Unknown layout"); + } host_->add(addr_reg, tmp_reg); } } } template -void jit_uni_binary_injector_t::calculate_mb_sp_ncsp( +void jit_uni_binary_injector_t::calculate_mb_sp_ncsp_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const { // offset = (n * stride_n) + (c * stride_c) + (d * stride_d) + (h * stride_h) + (w * stride_w) // mb_sp_off = (n * (stride_n/C)) + (d * stride_d) + (h * stride_h) + (w * stride_w) @@ -1085,7 +1060,33 @@ void jit_uni_binary_injector_t::calculate_mb_sp_ncsp( } template -void jit_uni_binary_injector_t::calculate_mb_sp_blocked( +void jit_uni_binary_injector_t::calculate_mb_sp_ncsp_partial( + const dim_t *strides, const std::size_t offset, + const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const { + // offset = (n * stride_n) + (c * stride_c) + (d * stride_d) + (h * stride_h) + (w * stride_w) + // mb_sp_off = (n * (stride_n/C)) + (d * stride_d) + (h * stride_h) + (w * stride_w) + // mb_sp_off = offset - (c * stride_c) - (n * (C - 1)DHW) + + const auto dst_d = rhs_arg_static_params_.dst_d; + const auto ndims = dst_d.ndims(); + const auto C_padded = dst_d.padded_dims()[1]; + const auto D = (ndims >= 5) ? dst_d.dims()[ndims - 3] : 1; + const auto H = (ndims >= 4) ? dst_d.dims()[ndims - 2] : 1; + const auto W = (ndims >= 3) ? dst_d.dims()[ndims - 1] : 1; + + const auto offset_shr = offset >> math::ilog2q(types::data_type_size( + rhs_arg_static_params_.dst_d.data_type())); + const auto c = (offset_shr % strides[0]) / strides[1]; + const auto n = offset_shr / strides[0]; + const auto offset_adj + = offset_shr - (c * strides[1]) - (n * (C_padded - 1) * D * H * W); + host_->mov(tmp_reg, + elem_size_bytes > 1 ? offset_adj << math::ilog2q(elem_size_bytes) + : offset_adj); +} + +template +void jit_uni_binary_injector_t::calculate_mb_sp_blocked_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const { // mb_sp_off = offset - (c * stride_c) - (n * (C - 1)DHW) - c % blk_size // output = rax @@ -1109,11 +1110,36 @@ void jit_uni_binary_injector_t::calculate_mb_sp_blocked( host_->sub(tmp_reg, rdx); } - calculate_mb_sp_ncsp(strides, tmp_reg); + calculate_mb_sp_ncsp_base(strides, tmp_reg); +} + +template +void jit_uni_binary_injector_t::calculate_mb_sp_blocked_partial( + const dim_t *strides, const std::size_t offset, + const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const { + // mb_sp_off = offset - (c * stride_c) - (n * (C - 1)DHW) - c % blk_size + + const auto dst_d = rhs_arg_static_params_.dst_d; + const auto ndims = dst_d.ndims(); + const auto C_padded = dst_d.padded_dims()[1]; + const auto D = (ndims >= 5) ? dst_d.dims()[ndims - 3] : 1; + const auto H = (ndims >= 4) ? dst_d.dims()[ndims - 2] : 1; + const auto W = (ndims >= 3) ? dst_d.dims()[ndims - 1] : 1; + const int blk_size = dst_d.blocking_desc().inner_blks[0]; + + const auto offset_shr = offset >> math::ilog2q(types::data_type_size( + rhs_arg_static_params_.dst_d.data_type())); + const auto c = (offset_shr % strides[0]) / strides[1]; + const auto n = offset_shr / strides[0]; + const auto offset_adj = offset_shr - (c * strides[1]) + - (n * (C_padded - 1) * D * H * W) - c % blk_size; + host_->mov(tmp_reg, + elem_size_bytes > 1 ? offset_adj << math::ilog2q(elem_size_bytes) + : offset_adj); } template -void jit_uni_binary_injector_t::calculate_mb_sp_nspc( +void jit_uni_binary_injector_t::calculate_mb_sp_nspc_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const { // offset = nDHWC + dHWC + hWC + wC + c // mb_sp_off = nDHW + dHW + hW + w @@ -1130,7 +1156,23 @@ void jit_uni_binary_injector_t::calculate_mb_sp_nspc( } template -void jit_uni_binary_injector_t::calculate_mb_sp_cspn( +void jit_uni_binary_injector_t::calculate_mb_sp_nspc_partial( + const dim_t *strides, const std::size_t offset, + const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const { + // offset = nDHWC + dHWC + hWC + wC + c + // mb_sp_off = nDHW + dHW + hW + w + // mb_sp_off = offset / C + const auto C = rhs_arg_static_params_.dst_d.padded_dims()[1]; + const auto offset_adj = (offset >> math::ilog2q(types::data_type_size( + rhs_arg_static_params_.dst_d.data_type()))) + / C; + host_->mov(tmp_reg, + elem_size_bytes > 1 ? offset_adj << math::ilog2q(elem_size_bytes) + : offset_adj); +} + +template +void jit_uni_binary_injector_t::calculate_mb_sp_cspn_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const { // offset = cDHWN + dHWN + hWN + wN + n // mb_sp_off = dHWN + hWN + wN + n @@ -1146,13 +1188,28 @@ void jit_uni_binary_injector_t::calculate_mb_sp_cspn( host_->mov(rax, rdx); } +template +void jit_uni_binary_injector_t::calculate_mb_sp_cspn_partial( + const dim_t *strides, const std::size_t offset, + const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const { + // offset = cDHWN + dHWN + hWN + wN + n + // mb_sp_off = dHWN + hWN + wN + n + // mb_sp_off = offset % stride_c + const auto offset_adj = (offset >> math::ilog2q(types::data_type_size( + rhs_arg_static_params_.dst_d.data_type()))) + % strides[1]; + host_->mov(tmp_reg, + elem_size_bytes > 1 ? offset_adj << math::ilog2q(elem_size_bytes) + : offset_adj); +} + template void jit_uni_binary_injector_t::append_mb_w_offset( const std::map &vmm_idx_to_out_addr, const std::map &vmm_idx_to_out_reg, const std::map &vmm_idx_to_out_elem_off_val, int vmm_idx, const Xbyak::Reg64 &addr_reg, const Xbyak::Reg64 &tmp_reg, - std::size_t elem_size_bytes) const { + std::size_t elem_size_bytes, bool is_first) const { const auto it_out_addr = vmm_idx_to_out_addr.find(vmm_idx); const auto it_out_reg = vmm_idx_to_out_reg.find(vmm_idx); @@ -1161,61 +1218,86 @@ void jit_uni_binary_injector_t::append_mb_w_offset( const bool is_out_reg = it_out_reg != vmm_idx_to_out_reg.end(); if (is_out_addr || is_out_reg) { - assert(rhs_arg_static_params_.is_dst_orig_set() - && "dst base addr offset not set"); Xbyak::Address out_addr = is_out_addr ? it_out_addr->second : host_->ptr[it_out_reg->second]; const auto it_off_val = vmm_idx_to_out_elem_off_val.find(vmm_idx); - calculate_no_broadcast(out_addr, - it_off_val != vmm_idx_to_out_elem_off_val.end() - ? it_off_val->second - : 0, - tmp_reg); - - const auto rax = host_->rax; - const auto rdx = host_->rdx; - const auto r8 = host_->r8; - const auto r9 = host_->r9; - - const injector_utils::conditional_register_preserve_guard_t - register_guard {is_out_reg ? utils::one_of( - it_out_reg->second, rax, rdx, r8, r9) - : false, - host_, {it_out_reg->second}}; + const auto &addr_cache_reg = rhs_arg_static_params_.rhs_addr_cache_reg; const auto dst_d = rhs_arg_static_params_.dst_d; const auto strides = dst_d.blocking_desc().strides; const auto layout = injector_utils::get_layout_type(dst_d); - switch (layout) { - case injector_utils::layout_t::ncsp: - calculate_mb_w_ncsp(strides, tmp_reg); - break; - case injector_utils::layout_t::c_blocked: - calculate_mb_w_blocked(strides, tmp_reg); - break; - case injector_utils::layout_t::nspc: - calculate_mb_w_nspc(strides, tmp_reg); - break; - case injector_utils::layout_t::cspn: - calculate_mb_w_cspn(strides, tmp_reg); - break; - default: assert(!"Unknown layout"); - } + if (is_first) { + calculate_no_broadcast_base(out_addr, tmp_reg); + + const auto rax = host_->rax; + const auto rdx = host_->rdx; + const auto r8 = host_->r8; + const auto r9 = host_->r9; + + const injector_utils::conditional_register_preserve_guard_t + register_guard {is_out_reg + ? utils::one_of(it_out_reg->second, rax, + rdx, r8, r9) + : false, + host_, {it_out_reg->second}}; + + switch (layout) { + case injector_utils::layout_t::ncsp: + calculate_mb_w_ncsp_base(strides, tmp_reg); + break; + case injector_utils::layout_t::c_blocked: + calculate_mb_w_blocked_base(strides, tmp_reg); + break; + case injector_utils::layout_t::nspc: + calculate_mb_w_nspc_base(strides, tmp_reg); + break; + case injector_utils::layout_t::cspn: + calculate_mb_w_cspn_base(strides, tmp_reg); + break; + default: assert(!"Unknown layout"); + } - if (elem_size_bytes == 1) { - host_->add(addr_reg, rax); + if (elem_size_bytes == 1) { + host_->add(addr_reg, rax); + } else { + const int shift_val = std::log2(elem_size_bytes); + host_->mov(tmp_reg, rax); + host_->sal(tmp_reg, shift_val); + host_->add(addr_reg, tmp_reg); + } + host_->mov(addr_cache_reg, addr_reg); } else { - const int shift_val = std::log2(elem_size_bytes); - host_->mov(tmp_reg, rax); - host_->sal(tmp_reg, shift_val); + host_->mov(addr_reg, addr_cache_reg); + } + + if (it_off_val != vmm_idx_to_out_elem_off_val.end()) { + switch (layout) { + case injector_utils::layout_t::ncsp: + calculate_mb_w_ncsp_partial(strides, it_off_val->second, + tmp_reg, elem_size_bytes); + break; + case injector_utils::layout_t::c_blocked: + calculate_mb_w_blocked_partial(strides, it_off_val->second, + tmp_reg, elem_size_bytes); + break; + case injector_utils::layout_t::nspc: + calculate_mb_w_nspc_partial(strides, it_off_val->second, + tmp_reg, elem_size_bytes); + break; + case injector_utils::layout_t::cspn: + calculate_mb_w_cspn_partial(strides, it_off_val->second, + tmp_reg, elem_size_bytes); + break; + default: assert(!"Unknown layout"); + } host_->add(addr_reg, tmp_reg); } } } template -void jit_uni_binary_injector_t::calculate_mb_w_ncsp( +void jit_uni_binary_injector_t::calculate_mb_w_ncsp_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const { // offset = (n * stride_n) + (c * stride_c) + (d * stride_d) + (h * stride_h) + (w * stride_w) // mb_w_off = (n * (stride_n/(C*D*H))) + (w * stride_w) @@ -1275,15 +1357,46 @@ void jit_uni_binary_injector_t::calculate_mb_w_ncsp( } template -void jit_uni_binary_injector_t::calculate_mb_w_blocked( +void jit_uni_binary_injector_t::calculate_mb_w_ncsp_partial( + const dim_t *strides, const std::size_t offset, + const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const { + // offset = (n * stride_n) + (c * stride_c) + (d * stride_d) + (h * stride_h) + (w * stride_w) + // mb_w_off = (n * (stride_n/(C*D*H))) + (w * stride_w) + const auto dst_d = rhs_arg_static_params_.dst_d; + const auto ndims = dst_d.ndims(); + const auto C_padded = dst_d.padded_dims()[1]; + const auto D = (ndims >= 5) ? dst_d.dims()[ndims - 3] : 1; + const auto H = (ndims >= 4) ? dst_d.dims()[ndims - 2] : 1; + + const auto offset_shr = offset >> math::ilog2q(types::data_type_size( + rhs_arg_static_params_.dst_d.data_type())); + const auto n = offset_shr / strides[0]; + const auto w = (offset_shr % strides[ndims - 2]) / strides[ndims - 1]; + const auto offset_adj = (n * (strides[0] / (C_padded * D * H))) + + (w * strides[ndims - 1]); + host_->mov(tmp_reg, + elem_size_bytes > 1 ? offset_adj << math::ilog2q(elem_size_bytes) + : offset_adj); +} + +template +void jit_uni_binary_injector_t::calculate_mb_w_blocked_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const { // mb_w_off = (n * (stride_n/(C*D*H))) + (w * stride_w) // output = rax - calculate_mb_sp_ncsp(strides, tmp_reg); + calculate_mb_sp_ncsp_base(strides, tmp_reg); +} + +template +void jit_uni_binary_injector_t::calculate_mb_w_blocked_partial( + const dim_t *strides, const std::size_t offset, + const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const { + // mb_w_off = (n * (stride_n/(C*D*H))) + (w * stride_w) + calculate_mb_w_ncsp_partial(strides, offset, tmp_reg, elem_size_bytes); } template -void jit_uni_binary_injector_t::calculate_mb_w_nspc( +void jit_uni_binary_injector_t::calculate_mb_w_nspc_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const { // offset = nDHWC + dHWC + hWC + wC + c // mb_w_off = nW + w @@ -1335,7 +1448,28 @@ void jit_uni_binary_injector_t::calculate_mb_w_nspc( } template -void jit_uni_binary_injector_t::calculate_mb_w_cspn( +void jit_uni_binary_injector_t::calculate_mb_w_nspc_partial( + const dim_t *strides, const std::size_t offset, + const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const { + // offset = nDHWC + dHWC + hWC + wC + c + // mb_w_off = nW + w + const auto dst_d = rhs_arg_static_params_.dst_d; + const auto ndims = dst_d.ndims(); + const auto W = (ndims >= 3) ? dst_d.dims()[ndims - 1] : 1; + + const auto offset_shr = offset >> math::ilog2q(types::data_type_size( + rhs_arg_static_params_.dst_d.data_type())); + const auto n = offset_shr / strides[0]; + const auto w = (offset_shr % strides[ndims >= 4 ? ndims - 2 : 0]) + / strides[ndims - 1]; + const auto offset_adj = n * W + w; + host_->mov(tmp_reg, + elem_size_bytes > 1 ? offset_adj << math::ilog2q(elem_size_bytes) + : offset_adj); +} + +template +void jit_uni_binary_injector_t::calculate_mb_w_cspn_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const { // offset = cDHWN + dHWN + hWN + wN + n // mb_w_off = wN + n @@ -1363,13 +1497,29 @@ void jit_uni_binary_injector_t::calculate_mb_w_cspn( } } +template +void jit_uni_binary_injector_t::calculate_mb_w_cspn_partial( + const dim_t *strides, const std::size_t offset, + const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const { + // offset = cDHWN + dHWN + hWN + wN + n + // mb_w_off = wN + n + const auto ndims = rhs_arg_static_params_.dst_d.ndims(); + const auto offset_shr = offset >> math::ilog2q(types::data_type_size( + rhs_arg_static_params_.dst_d.data_type())); + const auto offset_adj + = ndims >= 4 ? offset_shr % strides[ndims - 2] : offset_shr; + host_->mov(tmp_reg, + elem_size_bytes > 1 ? offset_adj << math::ilog2q(elem_size_bytes) + : offset_adj); +} + template void jit_uni_binary_injector_t::append_w_offset( const std::map &vmm_idx_to_out_addr, const std::map &vmm_idx_to_out_reg, const std::map &vmm_idx_to_out_elem_off_val, int vmm_idx, const Xbyak::Reg64 &addr_reg, const Xbyak::Reg64 &tmp_reg, - std::size_t elem_size_bytes) const { + std::size_t elem_size_bytes, bool is_first) const { const auto it_out_addr = vmm_idx_to_out_addr.find(vmm_idx); const auto it_out_reg = vmm_idx_to_out_reg.find(vmm_idx); @@ -1378,60 +1528,84 @@ void jit_uni_binary_injector_t::append_w_offset( const bool is_out_reg = it_out_reg != vmm_idx_to_out_reg.end(); if (is_out_addr || is_out_reg) { - assert(rhs_arg_static_params_.is_dst_orig_set() - && "dst base addr offset not set"); Xbyak::Address out_addr = is_out_addr ? it_out_addr->second : host_->ptr[it_out_reg->second]; const auto it_off_val = vmm_idx_to_out_elem_off_val.find(vmm_idx); - calculate_no_broadcast(out_addr, - it_off_val != vmm_idx_to_out_elem_off_val.end() - ? it_off_val->second - : 0, - tmp_reg); - - const auto rax = host_->rax; - const auto rdx = host_->rdx; - const auto r8 = host_->r8; - - const injector_utils::conditional_register_preserve_guard_t - register_guard {is_out_reg ? utils::one_of( - it_out_reg->second, rax, rdx, r8) - : false, - host_, {it_out_reg->second}}; + const auto &addr_cache_reg = rhs_arg_static_params_.rhs_addr_cache_reg; const auto dst_d = rhs_arg_static_params_.dst_d; const auto strides = dst_d.blocking_desc().strides; const auto layout = injector_utils::get_layout_type(dst_d); - switch (layout) { - case injector_utils::layout_t::ncsp: - calculate_w_ncsp(strides, tmp_reg); - break; - case injector_utils::layout_t::c_blocked: - calculate_w_blocked(strides, tmp_reg); - break; - case injector_utils::layout_t::nspc: - calculate_w_nspc(strides, tmp_reg); - break; - case injector_utils::layout_t::cspn: - calculate_w_cspn(strides, tmp_reg); - break; - default: assert(!"Unknown layout"); - } + if (is_first) { + calculate_no_broadcast_base(out_addr, tmp_reg); + + const auto rax = host_->rax; + const auto rdx = host_->rdx; + const auto r8 = host_->r8; + + const injector_utils::conditional_register_preserve_guard_t + register_guard {is_out_reg ? utils::one_of( + it_out_reg->second, rax, rdx, r8) + : false, + host_, {it_out_reg->second}}; + + switch (layout) { + case injector_utils::layout_t::ncsp: + calculate_w_ncsp_base(strides, tmp_reg); + break; + case injector_utils::layout_t::c_blocked: + calculate_w_blocked_base(strides, tmp_reg); + break; + case injector_utils::layout_t::nspc: + calculate_w_nspc_base(strides, tmp_reg); + break; + case injector_utils::layout_t::cspn: + calculate_w_cspn_base(strides, tmp_reg); + break; + default: assert(!"Unknown layout"); + } - if (elem_size_bytes == 1) { - host_->add(addr_reg, rax); + if (elem_size_bytes == 1) { + host_->add(addr_reg, rax); + } else { + const int shift_val = std::log2(elem_size_bytes); + host_->mov(tmp_reg, rax); + host_->sal(tmp_reg, shift_val); + host_->add(addr_reg, tmp_reg); + } + host_->mov(addr_cache_reg, addr_reg); } else { - const int shift_val = std::log2(elem_size_bytes); - host_->mov(tmp_reg, rax); - host_->sal(tmp_reg, shift_val); + host_->mov(addr_reg, addr_cache_reg); + } + + if (it_off_val != vmm_idx_to_out_elem_off_val.end()) { + switch (layout) { + case injector_utils::layout_t::ncsp: + calculate_w_ncsp_partial(strides, it_off_val->second, + tmp_reg, elem_size_bytes); + break; + case injector_utils::layout_t::c_blocked: + calculate_w_blocked_partial(strides, it_off_val->second, + tmp_reg, elem_size_bytes); + break; + case injector_utils::layout_t::nspc: + calculate_w_nspc_partial(strides, it_off_val->second, + tmp_reg, elem_size_bytes); + break; + case injector_utils::layout_t::cspn: + calculate_w_cspn_partial(strides, it_off_val->second, + tmp_reg, elem_size_bytes); + break; + default: assert(!"Unknown layout"); + } host_->add(addr_reg, tmp_reg); } } } template -void jit_uni_binary_injector_t::calculate_w_ncsp( +void jit_uni_binary_injector_t::calculate_w_ncsp_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const { // offset = (n * stride_n) + (c * stride_c) + (d * stride_d) + (h * stride_h) + (w * stride_w) // w_off = w * stride_w @@ -1459,13 +1633,36 @@ void jit_uni_binary_injector_t::calculate_w_ncsp( } template -void jit_uni_binary_injector_t::calculate_w_blocked( +void jit_uni_binary_injector_t::calculate_w_ncsp_partial( + const dim_t *strides, const std::size_t offset, + const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const { + // offset = (n * stride_n) + (c * stride_c) + (d * stride_d) + (h * stride_h) + (w * stride_w) + // w_off = w * stride_w + const auto ndims = rhs_arg_static_params_.dst_d.ndims(); + const auto offset_shr = offset >> math::ilog2q(types::data_type_size( + rhs_arg_static_params_.dst_d.data_type())); + const auto w = (offset_shr % strides[ndims - 2]) / strides[ndims - 1]; + const auto offset_adj = w * strides[ndims - 1]; + host_->mov(tmp_reg, + elem_size_bytes > 1 ? offset_adj << math::ilog2q(elem_size_bytes) + : offset_adj); +} + +template +void jit_uni_binary_injector_t::calculate_w_blocked_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const { - calculate_w_ncsp(strides, tmp_reg); + calculate_w_ncsp_base(strides, tmp_reg); } template -void jit_uni_binary_injector_t::calculate_w_nspc( +void jit_uni_binary_injector_t::calculate_w_blocked_partial( + const dim_t *strides, const std::size_t offset, + const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const { + calculate_w_ncsp_partial(strides, offset, tmp_reg, elem_size_bytes); +} + +template +void jit_uni_binary_injector_t::calculate_w_nspc_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const { // offset = nDHWC + dHWC + hWC + wC + c // w_off = w @@ -1492,12 +1689,36 @@ void jit_uni_binary_injector_t::calculate_w_nspc( } template -void jit_uni_binary_injector_t::calculate_w_cspn( +void jit_uni_binary_injector_t::calculate_w_nspc_partial( + const dim_t *strides, const std::size_t offset, + const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const { + // offset = nDHWC + dHWC + hWC + wC + c + // w_off = w + const auto ndims = rhs_arg_static_params_.dst_d.ndims(); + const auto offset_shr = offset >> math::ilog2q(types::data_type_size( + rhs_arg_static_params_.dst_d.data_type())); + const auto offset_adj + = (offset_shr % strides[ndims - 2]) / strides[ndims - 1]; + host_->mov(tmp_reg, + elem_size_bytes > 1 ? offset_adj << math::ilog2q(elem_size_bytes) + : offset_adj); +} + +template +void jit_uni_binary_injector_t::calculate_w_cspn_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const { // offset = cDHWN + dHWN + hWN + wN + n // w_off = w - // output = rax - calculate_w_nspc(strides, tmp_reg); + calculate_w_nspc_base(strides, tmp_reg); +} + +template +void jit_uni_binary_injector_t::calculate_w_cspn_partial( + const dim_t *strides, const std::size_t offset, + const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const { + // offset = cDHWN + dHWN + hWN + wN + n + // w_off = w + calculate_w_nspc_partial(strides, offset, tmp_reg, elem_size_bytes); } template diff --git a/src/cpu/x64/injectors/jit_uni_binary_injector.hpp b/src/cpu/x64/injectors/jit_uni_binary_injector.hpp index fdcd09778a1..a741e1559ad 100644 --- a/src/cpu/x64/injectors/jit_uni_binary_injector.hpp +++ b/src/cpu/x64/injectors/jit_uni_binary_injector.hpp @@ -81,6 +81,8 @@ bool all_binary_postop_rhs_per_oc_broadcast(const post_ops_t &post_ops, * stored inside rhs_addr_reg. * @param rhs_helper_reg - gpr register used as helper for calculations during data * loading phase. + * @param rhs_addr_cache_reg - gpr register used for caching part of calculated + * offset. * @param preserve_gpr_helpers - determines whether gpr registers specified above * should be preserved (pushed to stack and poped back afterwords) between * compute_vector_range calls. @@ -105,40 +107,24 @@ bool all_binary_postop_rhs_per_oc_broadcast(const post_ops_t &post_ops, struct rhs_arg_static_params_t { rhs_arg_static_params_t(std::size_t rhs_dt_helper_vmm_idx, const Xbyak::Reg64 &rhs_addr_reg, - const Xbyak::Reg64 &rhs_helper_reg, bool preserve_gpr_helpers, - bool preserve_vmm_helper, std::size_t abi_param_offset, - const memory_desc_wrapper &dst_d, std::size_t tail_size = 0u, - bool use_exact_tail_scalar_bcast = false); - rhs_arg_static_params_t(std::size_t rhs_dt_helper_vmm_idx, - const Xbyak::Reg64 &rhs_addr_reg, - const Xbyak::Reg64 &rhs_helper_reg, bool preserve_gpr_helpers, + const Xbyak::Reg64 &rhs_helper_reg, + const Xbyak::Reg64 &rhs_addr_cache_reg, bool preserve_gpr_helpers, bool preserve_vmm_helper, std::size_t abi_param_offset, std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d, std::size_t tail_size = 0u, bool use_exact_tail_scalar_bcast = false); rhs_arg_static_params_t(std::size_t rhs_dt_helper_vmm_idx, const Xbyak::Reg64 &rhs_addr_reg, - const Xbyak::Reg64 &rhs_helper_reg, bool preserve_gpr_helpers, - bool preserve_vmm_helper, std::size_t abi_param_offset, - const memory_desc_wrapper &dst_d, std::size_t tail_size, - const Xbyak::Opmask &tail_opmask, bool use_exact_tail_scalar_bcast, std::size_t rhs_prelu_helper_vmm_idx = 0); - rhs_arg_static_params_t(std::size_t rhs_dt_helper_vmm_idx, - const Xbyak::Reg64 &rhs_addr_reg, - const Xbyak::Reg64 &rhs_helper_reg, bool preserve_gpr_helpers, + const Xbyak::Reg64 &rhs_helper_reg, + const Xbyak::Reg64 &rhs_addr_cache_reg, bool preserve_gpr_helpers, bool preserve_vmm_helper, std::size_t abi_param_offset, std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d, std::size_t tail_size, const Xbyak::Opmask &tail_opmask, bool use_exact_tail_scalar_bcast, std::size_t rhs_prelu_helper_vmm_idx = 0); rhs_arg_static_params_t(std::size_t rhs_dt_helper_vmm_idx, const Xbyak::Reg64 &rhs_addr_reg, - const Xbyak::Reg64 &rhs_helper_reg, bool preserve_gpr_helpers, - bool preserve_vmm_helper, std::size_t abi_param_offset, - const memory_desc_wrapper &dst_d, std::size_t tail_size, - const Xbyak::Opmask &tail_opmask, const Xbyak::Reg64 ®_tail_size, - bool use_exact_tail_scalar_bcast, std::size_t rhs_prelu_helper_vmm_idx = 0); - rhs_arg_static_params_t(std::size_t rhs_dt_helper_vmm_idx, - const Xbyak::Reg64 &rhs_addr_reg, - const Xbyak::Reg64 &rhs_helper_reg, bool preserve_gpr_helpers, + const Xbyak::Reg64 &rhs_helper_reg, + const Xbyak::Reg64 &rhs_addr_cache_reg, bool preserve_gpr_helpers, bool preserve_vmm_helper, std::size_t abi_param_offset, std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d, std::size_t tail_size, const Xbyak::Opmask &tail_opmask, @@ -146,11 +132,11 @@ struct rhs_arg_static_params_t { bool use_exact_tail_scalar_bcast, std::size_t rhs_prelu_helper_vmm_idx = 0); bool is_opmask_set() const noexcept { return is_opmask_set_; } - bool is_dst_orig_set() const noexcept { return is_dst_orig_set_; } mutable std::size_t rhs_dt_helper_vmm_idx = 0; Xbyak::Reg64 rhs_addr_reg; Xbyak::Reg64 rhs_helper_reg; + Xbyak::Reg64 rhs_addr_cache_reg; bool preserve_gpr_helpers; bool preserve_vmm_helper; std::size_t abi_param_offset; @@ -167,15 +153,15 @@ struct rhs_arg_static_params_t { private: rhs_arg_static_params_t(std::size_t rhs_dt_helper_vmm_idx, const Xbyak::Reg64 &rhs_addr_reg, - const Xbyak::Reg64 &rhs_helper_reg, bool preserve_gpr_helpers, + const Xbyak::Reg64 &rhs_helper_reg, + const Xbyak::Reg64 &rhs_addr_cache_reg, bool preserve_gpr_helpers, bool preserve_vmm_helper, std::size_t abi_param_offset, std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d, std::size_t tail_size, const Xbyak::Opmask &tail_opmask, bool use_exact_tail_scalar_bcast, const Xbyak::Reg64 ®_tail_size, - bool is_opmask_set, bool is_dst_orig_set); + bool is_opmask_set); bool is_opmask_set_; - bool is_dst_orig_set_; }; /* @@ -227,41 +213,9 @@ enum class tail_lode_mode_t { STATIC, DYNAMIC, DEFAULT }; * @param vmm_idx_to_out_reg - vmm mapped to register containing address of destination * with offset, used to calculate offset in no_broadcast strategy, but also in other * strategies whose calculations are based on no_broadcast strategy. - * @param vmm_idx_to_out_elem_off_addr - vmm mapped to offset in elements stored under - * memory address intended to use in no_broadcast strategy. - * @param vmm_idx_to_out_elem_off_addr - vmm mapped to offset in elements stored under - * memory address intended to use in no_broadcast strategy. * @param vmm_idx_to_out_elem_off_val - vmm mapped to offset in elements passed as raw - * value intended to use in no_broadcast strategy - * @param vmm_idx_to_out_off_oprnd - vmm mapped to offset in elements inside operand - * intended to use in no_broadcast strategy - * @param vmm_idx_to_oc_elem_off_addr - vmm mapped to output channel offset in elements - * stored under memory address intended to use in per_oc broadcast strategies. - * @param vmm_idx_to_oc_elem_off_val - vmm mapped to output channel offset in elements - * passed as raw value intended to use in per_oc broadcast strategies. - * @param vmm_idx_to_oc_off_oprnd - vmm mapped to output channel offset in elements inside - * operand intended to use in per_oc broadcast strategies. - * @param vmm_idx_to_sp_elem_off_addr - vmm mapped to proper output spatial offset in - * elements stored under memory address intended to use in per_mb_spatial strategies. - * @param vmm_idx_to_sp_elem_off_val - vmm mapped to proper output spatial offset in - * elements passed as raw value intended to use in per_mb_spatial strategies. - * @param vmm_idx_to_sp_off_oprnd - vmm mapped to proper output spatial offset in - * elements inside operand intended to use in per_mb_spatial strategies. - * @param vmm_idx_to_mb_w_elem_off_addr - vmm mapped to proper output last dim - * per first dim offset in elements stored under memory address intended to use - * in per_mb_w strategies. - * @param vmm_idx_to_mb_w_elem_off_val - vmm mapped to proper output last dim - * per first dim offset in elements passed as raw value intended to use in - * per_mb_w strategies. - * @param vmm_idx_to_mb_w_off_oprnd - vmm mapped to proper output last dim - * per first dim offset in elements inside operand intended to use in per_mb_w - * strategies. - * @param vmm_idx_to_w_elem_off_addr - vmm mapped to proper output last dim - * offset in elements stored under memory address intended to use in per_w strategy. - * @param vmm_idx_to_w_elem_off_val - vmm mapped to proper output last dim - * offset in elements passed as raw value intended to use in per_w strategy. - * @param vmm_idx_to_w_off_oprnd - vmm mapped to proper output last dim offset - * in elements inside operand intended to use in per_w strategy. + * value intended to use in no_broadcast strategy, but also in other + * strategies whose calculations are based on no_broadcast strategy. * @param vmm_tail_idx - vmm indices that contains data don't fill the whole vector (tail). * @param is_dynamic_tail_load - determines whether to load with tail in * runtime (based on the value from reg_tail_size or opmask) or based on given @@ -271,26 +225,7 @@ enum class tail_lode_mode_t { STATIC, DYNAMIC, DEFAULT }; struct rhs_arg_dynamic_params_t { std::map vmm_idx_to_out_addr; std::map vmm_idx_to_out_reg; - - std::map vmm_idx_to_out_elem_off_addr; std::map vmm_idx_to_out_elem_off_val; - std::map vmm_idx_to_out_off_oprnd; - - std::map vmm_idx_to_oc_elem_off_addr; - std::map vmm_idx_to_oc_elem_off_val; - std::map vmm_idx_to_oc_off_oprnd; - - std::map vmm_idx_to_sp_elem_off_addr; - std::map vmm_idx_to_sp_elem_off_val; - std::map vmm_idx_to_sp_off_oprnd; - - std::map vmm_idx_to_mb_w_elem_off_addr; - std::map vmm_idx_to_mb_w_elem_off_val; - std::map vmm_idx_to_mb_w_off_oprnd; - - std::map vmm_idx_to_w_elem_off_addr; - std::map vmm_idx_to_w_elem_off_val; - std::map vmm_idx_to_w_off_oprnd; std::unordered_set vmm_tail_idx_; tail_lode_mode_t tail_load_mode = tail_lode_mode_t::DEFAULT; @@ -375,7 +310,8 @@ class jit_uni_binary_injector_t { Xbyak::Address prepare_rhs_arg_addr(std::size_t vmm_idx, std::size_t rhs_arg_idx, const dnnl_post_ops::entry_t &post_op, const rhs_arg_dynamic_params_t &rhs_arg_params, - const broadcasting_strategy_t rhs_broadcasting_strategy) const; + const broadcasting_strategy_t rhs_broadcasting_strategy, + bool is_first) const; /* * Loads data and applies particular binary operation. */ @@ -386,86 +322,129 @@ class jit_uni_binary_injector_t { /* * Helper functions responsible for preparing rhs tensor slice address. */ - void append_offset_from_operand( - const std::map &vmm_idx_to_elem_addr_off, - int vmm_idx, const Xbyak::Reg64 &addr_reg, - const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const; - void append_offset_under_mem_addr( - const std::map &vmm_idx_to_elem_addr_off, - int vmm_idx, const Xbyak::Reg64 &addr_reg, - const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const; - void append_value_offset( - const std::map &vmm_idx_to_elem_val_off, int vmm_idx, - const Xbyak::Reg64 &addr_reg, std::size_t elem_size_bytes) const; - void append_no_broadcast_offset( const std::map &vmm_idx_to_out_addr, const std::map &vmm_idx_to_out_reg, const std::map &vmm_idx_to_out_elem_off_val, int vmm_idx, const Xbyak::Reg64 &addr_reg, - const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const; - void calculate_no_broadcast(Xbyak::Address addr, std::size_t offset, - const Xbyak::Reg64 &out_reg) const; + const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes, + bool is_first) const; + void calculate_no_broadcast_base( + Xbyak::Address addr, const Xbyak::Reg64 &out_reg) const; + void calculate_no_broadcast_partial(const std::size_t offset, + const Xbyak::Reg64 &out_reg, std::size_t elem_size_bytes) const; void append_oc_offset( const std::map &vmm_idx_to_out_addr, const std::map &vmm_idx_to_out_reg, const std::map &vmm_idx_to_out_elem_off_val, int vmm_idx, const Xbyak::Reg64 &addr_reg, - const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const; - void calculate_oc_ncsp( + const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes, + bool is_first) const; + void calculate_oc_ncsp_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const; - void calculate_oc_blocked( + void calculate_oc_ncsp_partial(const dim_t *strides, + const std::size_t offset, const Xbyak::Reg64 &tmp_reg, + std::size_t elem_size_bytes) const; + void calculate_oc_blocked_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const; - void calculate_oc_nspc( + void calculate_oc_blocked_partial(const dim_t *strides, + const std::size_t offset, const Xbyak::Reg64 &tmp_reg, + std::size_t elem_size_bytes) const; + void calculate_oc_nspc_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const; - void calculate_oc_cspn( + void calculate_oc_nspc_partial(const dim_t *strides, + const std::size_t offset, const Xbyak::Reg64 &tmp_reg, + std::size_t elem_size_bytes) const; + void calculate_oc_cspn_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const; + void calculate_oc_cspn_partial(const dim_t *strides, + const std::size_t offset, const Xbyak::Reg64 &tmp_reg, + std::size_t elem_size_bytes) const; void append_mb_sp_offset( const std::map &vmm_idx_to_out_addr, const std::map &vmm_idx_to_out_reg, const std::map &vmm_idx_to_out_elem_off_val, int vmm_idx, const Xbyak::Reg64 &addr_reg, - const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const; - void calculate_mb_sp_ncsp( + const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes, + bool is_first) const; + void calculate_mb_sp_ncsp_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const; - void calculate_mb_sp_blocked( + void calculate_mb_sp_ncsp_partial(const dim_t *strides, + const std::size_t offset, const Xbyak::Reg64 &tmp_reg, + std::size_t elem_size_bytes) const; + void calculate_mb_sp_blocked_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const; - void calculate_mb_sp_nspc( + void calculate_mb_sp_blocked_partial(const dim_t *strides, + const std::size_t offset, const Xbyak::Reg64 &tmp_reg, + std::size_t elem_size_bytes) const; + void calculate_mb_sp_nspc_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const; - void calculate_mb_sp_cspn( + void calculate_mb_sp_nspc_partial(const dim_t *strides, + const std::size_t offset, const Xbyak::Reg64 &tmp_reg, + std::size_t elem_size_bytes) const; + void calculate_mb_sp_cspn_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const; + void calculate_mb_sp_cspn_partial(const dim_t *strides, + const std::size_t offset, const Xbyak::Reg64 &tmp_reg, + std::size_t elem_size_bytes) const; void append_mb_w_offset( const std::map &vmm_idx_to_out_addr, const std::map &vmm_idx_to_out_reg, const std::map &vmm_idx_to_out_elem_off_val, int vmm_idx, const Xbyak::Reg64 &addr_reg, - const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const; - void calculate_mb_w_ncsp( + const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes, + bool is_first) const; + void calculate_mb_w_ncsp_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const; - void calculate_mb_w_blocked( + void calculate_mb_w_ncsp_partial(const dim_t *strides, + const std::size_t offset, const Xbyak::Reg64 &tmp_reg, + std::size_t elem_size_bytes) const; + void calculate_mb_w_blocked_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const; - void calculate_mb_w_nspc( + void calculate_mb_w_blocked_partial(const dim_t *strides, + const std::size_t offset, const Xbyak::Reg64 &tmp_reg, + std::size_t elem_size_bytes) const; + void calculate_mb_w_nspc_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const; - void calculate_mb_w_cspn( + void calculate_mb_w_nspc_partial(const dim_t *strides, + const std::size_t offset, const Xbyak::Reg64 &tmp_reg, + std::size_t elem_size_bytes) const; + void calculate_mb_w_cspn_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const; + void calculate_mb_w_cspn_partial(const dim_t *strides, + const std::size_t offset, const Xbyak::Reg64 &tmp_reg, + std::size_t elem_size_bytes) const; void append_w_offset( const std::map &vmm_idx_to_out_addr, const std::map &vmm_idx_to_out_reg, const std::map &vmm_idx_to_out_elem_off_val, int vmm_idx, const Xbyak::Reg64 &addr_reg, - const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const; - void calculate_w_ncsp( + const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes, + bool is_first) const; + void calculate_w_ncsp_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const; - void calculate_w_blocked( + void calculate_w_ncsp_partial(const dim_t *strides, + const std::size_t offset, const Xbyak::Reg64 &tmp_reg, + std::size_t elem_size_bytes) const; + void calculate_w_blocked_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const; - void calculate_w_nspc( + void calculate_w_blocked_partial(const dim_t *strides, + const std::size_t offset, const Xbyak::Reg64 &tmp_reg, + std::size_t elem_size_bytes) const; + void calculate_w_nspc_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const; - void calculate_w_cspn( + void calculate_w_nspc_partial(const dim_t *strides, + const std::size_t offset, const Xbyak::Reg64 &tmp_reg, + std::size_t elem_size_bytes) const; + void calculate_w_cspn_base( const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const; + void calculate_w_cspn_partial(const dim_t *strides, + const std::size_t offset, const Xbyak::Reg64 &tmp_reg, + std::size_t elem_size_bytes) const; template typename std::enable_if::value diff --git a/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.cpp b/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.cpp index 31d9d74645c..3baab5bd692 100644 --- a/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.cpp +++ b/src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.cpp @@ -59,7 +59,7 @@ jit_avx2_1x1_conv_kernel_f32::jit_avx2_1x1_conv_kernel_f32( const size_t tail_size = jcp.oc_without_padding % isa_simd_width_; rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx, r13, r14, - preserve_gpr, preserve_vmm, + r15, preserve_gpr, preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), memory_desc_wrapper(dst_md), tail_size, use_exact_tail_scalar_bcast}; diff --git a/src/cpu/x64/jit_avx2_conv_kernel_f32.cpp b/src/cpu/x64/jit_avx2_conv_kernel_f32.cpp index 14430e66636..aee5039dff9 100644 --- a/src/cpu/x64/jit_avx2_conv_kernel_f32.cpp +++ b/src/cpu/x64/jit_avx2_conv_kernel_f32.cpp @@ -56,7 +56,7 @@ jit_avx2_conv_fwd_kernel_f32::jit_avx2_conv_fwd_kernel_f32( const size_t tail_size = jcp.oc_without_padding % isa_simd_width_; rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx, r13, r14, - preserve_gpr, preserve_vmm, + r15, preserve_gpr, preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), memory_desc_wrapper(dst_md), tail_size, use_exact_tail_scalar_bcast}; diff --git a/src/cpu/x64/jit_avx512_common_1x1_conv_kernel.cpp b/src/cpu/x64/jit_avx512_common_1x1_conv_kernel.cpp index 7f12a7ab049..e5df8f2a452 100644 --- a/src/cpu/x64/jit_avx512_common_1x1_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_common_1x1_conv_kernel.cpp @@ -60,7 +60,7 @@ jit_avx512_common_1x1_conv_kernel::jit_avx512_common_1x1_conv_kernel( static constexpr bool use_exact_tail_scalar_bcast = true; const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx, - r14, r15, preserve_gpr, preserve_vmm, + r14, r15, r12, preserve_gpr, preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), memory_desc_wrapper(dst_md), tail_size, k_load_dim_mask, use_exact_tail_scalar_bcast}; diff --git a/src/cpu/x64/jit_avx512_common_conv_kernel.cpp b/src/cpu/x64/jit_avx512_common_conv_kernel.cpp index 86b6a0a95a2..5ba016646f7 100644 --- a/src/cpu/x64/jit_avx512_common_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_common_conv_kernel.cpp @@ -101,7 +101,7 @@ _jit_avx512_common_conv_fwd_kernel::_jit_avx512_common_conv_fwd_kernel( static constexpr bool use_exact_tail_scalar_bcast = false; const binary_injector::rhs_arg_static_params_t rhs_args_static_params { - helper_vmm_idx, reg_tmp, r15, preserve_gpr, preserve_vmm, + helper_vmm_idx, reg_tmp, r15, r14, preserve_gpr, preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), memory_desc_wrapper(dst_md), tail_size, postops_mask, use_exact_tail_scalar_bcast}; diff --git a/src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.cpp index 21d7c73792f..43a014955d5 100644 --- a/src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.cpp @@ -46,13 +46,14 @@ jit_avx512_core_amx_1x1_fwd_kernel_t::jit_avx512_core_amx_1x1_fwd_kernel_t( using namespace binary_injector; const auto &rhs_addr_reg = bin_injector_helper_reg_1; const auto &rhs_helper_reg = bin_injector_helper_reg_2; + const auto &rhs_addr_cache_reg = bin_injector_helper_reg_3; static constexpr bool preserve_gpr = false; static constexpr bool preserve_vmm = false; const size_t tail_size = jcp.oc_without_padding % isa_simd_width_; static constexpr bool use_exact_tail_scalar_bcast = true; const rhs_arg_static_params_t rhs_arg_static_params {31, rhs_addr_reg, - rhs_helper_reg, preserve_gpr, preserve_vmm, + rhs_helper_reg, rhs_addr_cache_reg, preserve_gpr, preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), memory_desc_wrapper(dst_md), tail_size, ktail_mask, use_exact_tail_scalar_bcast}; @@ -146,7 +147,8 @@ void jit_avx512_core_amx_1x1_fwd_kernel_t::interleave_store() { const injector_utils::conditional_register_preserve_guard_t cond_register_guard(jcp.with_binary, this, {bin_injector_helper_reg_1, - bin_injector_helper_reg_2}); + bin_injector_helper_reg_2, + bin_injector_helper_reg_3}); const int wsp_row_offset = jcp.typesize_acc * (osb * jcp.nb_oc_blocking * jcp.max_width * jcp.oc_block + ocb * jcp.max_width * jcp.oc_block diff --git a/src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.hpp b/src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.hpp index a48225bac2d..0774d67178b 100644 --- a/src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.hpp +++ b/src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.hpp @@ -107,6 +107,7 @@ struct jit_avx512_core_amx_1x1_fwd_kernel_t : public jit_generator { const Xbyak::Reg64 bin_injector_helper_reg_1 = r14; const Xbyak::Reg64 bin_injector_helper_reg_2 = r15; + const Xbyak::Reg64 bin_injector_helper_reg_3 = r11; const Xbyak::Opmask ktail_mask = k2; diff --git a/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp index ee5838fd4d6..4908eea8e6b 100644 --- a/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp @@ -1057,13 +1057,15 @@ jit_avx512_core_amx_fwd_kernel_t::jit_avx512_core_amx_fwd_kernel_t( using namespace binary_injector; const auto &rhs_addr_reg = bin_injector_helper_reg_1; const auto &rhs_helper_reg = bin_injector_helper_reg_2; + const auto &rhs_addr_cache_reg = bin_injector_helper_reg_3; static constexpr bool preserve_gpr = false; static constexpr bool preserve_vmm = false; const size_t tail_size = jcp.oc_without_padding % isa_simd_width_; static constexpr bool use_exact_tail_scalar_bcast = true; const binary_injector::rhs_arg_static_params_t rhs_arg_static_params { - 31, rhs_addr_reg, rhs_helper_reg, preserve_gpr, preserve_vmm, + 31, rhs_addr_reg, rhs_helper_reg, rhs_addr_cache_reg, + preserve_gpr, preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), memory_desc_wrapper(dst_md), tail_size, ktail_mask, use_exact_tail_scalar_bcast}; @@ -1603,7 +1605,8 @@ void jit_avx512_core_amx_fwd_kernel_t::store_output(int width, int tail, const injector_utils::conditional_register_preserve_guard_t cond_register_guard(jcp.with_binary, this, {bin_injector_helper_reg_1, - bin_injector_helper_reg_2}); + bin_injector_helper_reg_2, + bin_injector_helper_reg_3}); for (int tw = 0; tw < width && do_store; tw++) { // height diff --git a/src/cpu/x64/jit_avx512_core_amx_conv_kernel.hpp b/src/cpu/x64/jit_avx512_core_amx_conv_kernel.hpp index 2e04f8e6921..1fb45f179bc 100644 --- a/src/cpu/x64/jit_avx512_core_amx_conv_kernel.hpp +++ b/src/cpu/x64/jit_avx512_core_amx_conv_kernel.hpp @@ -341,6 +341,7 @@ struct jit_avx512_core_amx_fwd_kernel_t : public jit_generator { const Xbyak::Reg64 bin_injector_helper_reg_1 = r14; const Xbyak::Reg64 bin_injector_helper_reg_2 = r15; + const Xbyak::Reg64 bin_injector_helper_reg_3 = r11; const Xbyak::Reg64 reg_d_weights = reg_zp_compensation; const Xbyak::Reg64 reg_d_bias = reg_src_zero_point; diff --git a/src/cpu/x64/jit_avx512_core_bf16_1x1_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_bf16_1x1_conv_kernel.cpp index ad460bb30e2..2beeda3b321 100644 --- a/src/cpu/x64/jit_avx512_core_bf16_1x1_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_core_bf16_1x1_conv_kernel.cpp @@ -56,7 +56,7 @@ jit_avx512_core_bf16_1x1_conv_kernel::jit_avx512_core_bf16_1x1_conv_kernel( static constexpr bool use_exact_tail_scalar_bcast = true; const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx, - r14, r15, preserve_gpr, preserve_vmm, + r14, r15, r12, preserve_gpr, preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), memory_desc_wrapper(dst_md), tail_size, k_load_dim_tail_mask, use_exact_tail_scalar_bcast}; diff --git a/src/cpu/x64/jit_avx512_core_bf16_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_bf16_conv_kernel.cpp index 312a05b3089..be7f8b2ed0a 100644 --- a/src/cpu/x64/jit_avx512_core_bf16_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_core_bf16_conv_kernel.cpp @@ -113,7 +113,7 @@ _jit_avx512_core_bf16_fwd_kernel::_jit_avx512_core_bf16_fwd_kernel( static constexpr bool use_exact_tail_scalar_bcast = true; const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx, - r14, r15, preserve_gpr, preserve_vmm, + r14, r15, r12, preserve_gpr, preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), memory_desc_wrapper(dst_md), tail_size, postops_mask, use_exact_tail_scalar_bcast}; diff --git a/src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.cpp index 155678acb0b..ea82810002a 100644 --- a/src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2021 Intel Corporation +* Copyright 2019-2022 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,7 +46,7 @@ jit_avx512_dw_conv_fwd_kernel_bf16::jit_avx512_dw_conv_fwd_kernel_bf16( % (cpu_isa_traits::vlen / sizeof(float)); const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx, - r14, r15, preserve_gpr, preserve_vmm, + r14, r15, r12, preserve_gpr, preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), memory_desc_wrapper(dst_md), tail_size, k_oc_tail_mask, use_exact_tail_scalar_bcast}; diff --git a/src/cpu/x64/jit_avx512_core_fork_bf16_dw_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_fork_bf16_dw_conv_kernel.cpp index 03e441c5aa9..11717077a7f 100644 --- a/src/cpu/x64/jit_avx512_core_fork_bf16_dw_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_core_fork_bf16_dw_conv_kernel.cpp @@ -591,7 +591,7 @@ void jit_avx512_fork_dw_conv_fwd_kernel_bf16::generate() { % (cpu_isa_traits::vlen / sizeof(float)); static constexpr bool use_exact_tail_scalar_bcast = false; const binary_injector::rhs_arg_static_params_t rhs_sp { - helper_vmm_idx, r10, r11, preserve_gpr, + helper_vmm_idx, r10, r11, r12, preserve_gpr, preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), memory_desc_wrapper(&dst_md_), tail_size, k_oc_tail_mask, use_exact_tail_scalar_bcast}; diff --git a/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp index 6e90b9b1b6b..272b5cf1b42 100644 --- a/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp @@ -65,7 +65,7 @@ _jit_avx512_core_x8s8s32x_1x1_conv_kernel:: static constexpr bool use_exact_tail_scalar_bcast = true; const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx, - r14, r15, preserve_gpr, preserve_vmm, + r14, r15, r13, preserve_gpr, preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), memory_desc_wrapper(dst_md), tail_size, postops_mask, use_exact_tail_scalar_bcast}; diff --git a/src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.cpp index e9246909850..8c0083b14a8 100644 --- a/src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.cpp @@ -72,7 +72,7 @@ _jit_avx512_core_x8s8s32x_fwd_kernel::_jit_avx512_core_x8s8s32x_fwd_kernel( static constexpr bool use_exact_tail_scalar_bcast = false; const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx, - r14, r15, preserve_gpr, preserve_vmm, + r14, r15, r13, preserve_gpr, preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), memory_desc_wrapper(dst_md), tail_size, postops_mask, use_exact_tail_scalar_bcast}; diff --git a/src/cpu/x64/jit_avx512_core_x8s8s32x_deconvolution.cpp b/src/cpu/x64/jit_avx512_core_x8s8s32x_deconvolution.cpp index 18142e05f83..12596ec8b33 100644 --- a/src/cpu/x64/jit_avx512_core_x8s8s32x_deconvolution.cpp +++ b/src/cpu/x64/jit_avx512_core_x8s8s32x_deconvolution.cpp @@ -59,7 +59,7 @@ jit_avx512_core_x8s8s32x_deconv_fwd_kernel:: const binary_injector::rhs_arg_static_params_t rhs_sp { static_cast(Xbyak::Xmm(31).getIdx()), this->r14, - this->r15, preserve_gpr, preserve_vmm, + this->r15, this->r13, preserve_gpr, preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), memory_desc_wrapper(dst_md), tail_size, ktail_mask, use_exact_tail_scalar_bcast}; diff --git a/src/cpu/x64/jit_brgemm_post_ops.hpp b/src/cpu/x64/jit_brgemm_post_ops.hpp index a7079caf153..85557e3b009 100644 --- a/src/cpu/x64/jit_brgemm_post_ops.hpp +++ b/src/cpu/x64/jit_brgemm_post_ops.hpp @@ -299,7 +299,7 @@ struct jit_brgemm_kernel_post_ops : public jit_generator { const binary_injector::rhs_arg_static_params_t rhs_sp { static_cast(Xbyak::Zmm(28).getIdx()), this->r14, - this->r15, preserve_gpr, preserve_vmm, + this->r15, this->r13, preserve_gpr, preserve_vmm, GET_OFF(ptr_binary_post_ops_rhs), GET_OFF(dst_orig), memory_desc_wrapper(brg.dst_md), static_cast(brg.load_dim % brg.ld_block), diff --git a/src/cpu/x64/jit_gemm_convolution_utils.cpp b/src/cpu/x64/jit_gemm_convolution_utils.cpp index 51c8c860251..11e214e130f 100644 --- a/src/cpu/x64/jit_gemm_convolution_utils.cpp +++ b/src/cpu/x64/jit_gemm_convolution_utils.cpp @@ -66,7 +66,7 @@ struct jit_pp_kernel_t : pp_kernel_t, public jit_generator { static constexpr size_t tail_size = 0; static constexpr bool use_exact_tail_scalar_bcast = false; const binary_injector::rhs_arg_static_params_t rhs_sp { - helper_vmm_idx, r13, r14, preserve_gpr, + helper_vmm_idx, r13, r14, r15, preserve_gpr, preserve_vmm, PARAM_OFF(post_ops_binary_rhs_arg_vec), PARAM_OFF(dst_orig), memory_desc_wrapper(pd->dst_md()), tail_size, kreg_rem_mask, use_exact_tail_scalar_bcast}; diff --git a/src/cpu/x64/jit_gemm_inner_product_utils.cpp b/src/cpu/x64/jit_gemm_inner_product_utils.cpp index eaf3739bda5..62772bf8d42 100644 --- a/src/cpu/x64/jit_gemm_inner_product_utils.cpp +++ b/src/cpu/x64/jit_gemm_inner_product_utils.cpp @@ -325,7 +325,7 @@ jit_pp_kernel_t::jit_pp_kernel_t(size_t OC, size_t MB, dim_t dst_mb_stride, // for the OC tail_size = !!tail_size ? tail_size : 1; const binary_injector::rhs_arg_static_params_t rhs_arg_static_params { - helper_vmm_idx, eltwise_reserved_gpr_, r14, preserve_gpr, + helper_vmm_idx, eltwise_reserved_gpr_, r14, r15, preserve_gpr, preserve_vmm, PARAM_OFF(post_ops_binary_rhs_arg_vec), PARAM_OFF(dst_orig), dst_md_wrapper, tail_size, opmask_binary, reg_tmp, use_exact_tail_scalar_bcast, prelu_helper_vmm_idx}; diff --git a/src/cpu/x64/jit_gemm_x8s8s32x_convolution_utils.cpp b/src/cpu/x64/jit_gemm_x8s8s32x_convolution_utils.cpp index fae62b9d006..9f02cd6b950 100644 --- a/src/cpu/x64/jit_gemm_x8s8s32x_convolution_utils.cpp +++ b/src/cpu/x64/jit_gemm_x8s8s32x_convolution_utils.cpp @@ -85,7 +85,7 @@ struct jit_pp_ker_t : pp_ker_t, public jit_generator { static constexpr size_t tail_size = 0; static constexpr bool use_exact_tail_scalar_bcast = false; const binary_injector::rhs_arg_static_params_t rhs_sp { - helper_vmm_idx, r13, r14, preserve_gpr, + helper_vmm_idx, r13, r14, r15, preserve_gpr, preserve_vmm, PARAM_OFF(post_ops_binary_rhs_arg_vec), PARAM_OFF(dst_orig), memory_desc_wrapper(pd->dst_md()), tail_size, kreg_rem_mask_short, use_exact_tail_scalar_bcast}; diff --git a/src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.cpp b/src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.cpp index 9b208172062..4d9e71c0080 100644 --- a/src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.cpp +++ b/src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2017-2021 Intel Corporation +* Copyright 2017-2022 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -53,7 +53,7 @@ jit_sse41_1x1_conv_kernel_f32::jit_sse41_1x1_conv_kernel_f32( static constexpr bool use_exact_tail_scalar_bcast = false; const binary_injector::rhs_arg_static_params_t rhs_arg_static_params { - helper_vmm_idx, r13, r14, preserve_gpr, preserve_vmm, + helper_vmm_idx, r13, r14, r15, preserve_gpr, preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), memory_desc_wrapper(dst_md), tail_size, use_exact_tail_scalar_bcast}; diff --git a/src/cpu/x64/jit_sse41_conv_kernel_f32.cpp b/src/cpu/x64/jit_sse41_conv_kernel_f32.cpp index 53b546f100e..ec60734e1b9 100644 --- a/src/cpu/x64/jit_sse41_conv_kernel_f32.cpp +++ b/src/cpu/x64/jit_sse41_conv_kernel_f32.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2017-2021 Intel Corporation +* Copyright 2017-2022 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -54,7 +54,7 @@ jit_sse41_conv_fwd_kernel_f32::jit_sse41_conv_fwd_kernel_f32( static constexpr bool use_exact_tail_scalar_bcast = false; const binary_injector::rhs_arg_static_params_t rhs_arg_static_params { - helper_vmm_idx, r14, r15, preserve_gpr, preserve_vmm, + helper_vmm_idx, r14, r15, r12, preserve_gpr, preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), memory_desc_wrapper(dst_md), tail_size, use_exact_tail_scalar_bcast}; diff --git a/src/cpu/x64/jit_uni_binary_kernel.cpp b/src/cpu/x64/jit_uni_binary_kernel.cpp index d7360813ea8..5210c4e1856 100644 --- a/src/cpu/x64/jit_uni_binary_kernel.cpp +++ b/src/cpu/x64/jit_uni_binary_kernel.cpp @@ -127,9 +127,10 @@ void jit_uni_binary_kernel_t::init_post_ops_injector() { reg_elt_inj_table_, elt_inj_opmask_, true /*is_fwd*/, false /*use_dst*/); const binary_injector::rhs_arg_static_params_t rhs_arg_bsp {10, reg_tmp_, - reg_elt_inj_table_, true /*preserve gpr*/, true /*preserve vmm*/, - PARAM_OFF(post_ops_binary_rhs_arg_vec), PARAM_OFF(dst_orig), dst_d, - tail_size_, tail_opmask_, false /*use_exact_tail_scalar_bcast*/}; + reg_elt_inj_table_, r13, true /*preserve gpr*/, + true /*preserve vmm*/, PARAM_OFF(post_ops_binary_rhs_arg_vec), + PARAM_OFF(dst_orig), dst_d, tail_size_, tail_opmask_, + false /*use_exact_tail_scalar_bcast*/}; const binary_injector::static_params_t bsp(this->param1, get_supported_postops_bcast_strategies(), rhs_arg_bsp); diff --git a/src/cpu/x64/jit_uni_dw_conv_kernel_f32.cpp b/src/cpu/x64/jit_uni_dw_conv_kernel_f32.cpp index 01b0aaa0340..ac4f40df5db 100644 --- a/src/cpu/x64/jit_uni_dw_conv_kernel_f32.cpp +++ b/src/cpu/x64/jit_uni_dw_conv_kernel_f32.cpp @@ -48,7 +48,7 @@ jit_uni_dw_conv_fwd_kernel_f32::jit_uni_dw_conv_fwd_kernel_f32( const size_t tail_size = jcp.oc_without_padding % (cpu_isa_traits::vlen / sizeof(float)); rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx, r14, r15, - preserve_gpr, preserve_vmm, + r12, preserve_gpr, preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), memory_desc_wrapper(dst_md), tail_size, k_oc_tail_mask, use_exact_tail_scalar_bcast}; diff --git a/src/cpu/x64/jit_uni_fork_dw_conv_kernel_f32.cpp b/src/cpu/x64/jit_uni_fork_dw_conv_kernel_f32.cpp index fdbb5e4a254..9b4dfb4878a 100644 --- a/src/cpu/x64/jit_uni_fork_dw_conv_kernel_f32.cpp +++ b/src/cpu/x64/jit_uni_fork_dw_conv_kernel_f32.cpp @@ -772,7 +772,7 @@ void jit_uni_fork_dw_conv_fwd_kernel_f32::generate() { % (cpu_isa_traits::vlen / sizeof(float)); static constexpr bool use_exact_tail_scalar_bcast = false; const binary_injector::rhs_arg_static_params_t rhs_sp { - helper_vmm_idx, r10, r11, preserve_gpr, + helper_vmm_idx, r10, r11, r12, preserve_gpr, preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), memory_desc_wrapper(&dst_md_), tail_size, k_oc_tail_mask, use_exact_tail_scalar_bcast}; diff --git a/src/cpu/x64/jit_uni_i8i8_pooling.cpp b/src/cpu/x64/jit_uni_i8i8_pooling.cpp index 26fe358e1c1..e4ce38dd84d 100644 --- a/src/cpu/x64/jit_uni_i8i8_pooling.cpp +++ b/src/cpu/x64/jit_uni_i8i8_pooling.cpp @@ -263,7 +263,7 @@ struct jit_uni_i8i8_pooling_fwd_ker_t : public jit_generator { static constexpr std::size_t tmp_vmm_injector = 0u; const binary_injector::rhs_arg_static_params_t rhs_sp { - tmp_vmm_injector, r14, r15, preserve_gpr, preserve_vmm, + tmp_vmm_injector, r14, r15, r13, preserve_gpr, preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), memory_desc_wrapper(*dst_md), c_tail_elems, mask(post_op_tail_opmask_idx_), diff --git a/src/cpu/x64/jit_uni_pool_kernel.cpp b/src/cpu/x64/jit_uni_pool_kernel.cpp index 936e5db1fc9..6e4787ac3e6 100644 --- a/src/cpu/x64/jit_uni_pool_kernel.cpp +++ b/src/cpu/x64/jit_uni_pool_kernel.cpp @@ -63,7 +63,7 @@ jit_uni_pool_kernel::jit_uni_pool_kernel( const binary_injector::rhs_arg_static_params_t rhs_sp { static_cast(this->xmm4.getIdx()), this->r14, - this->r15, preserve_gpr, preserve_vmm, + this->r15, this->r13, preserve_gpr, preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), memory_desc_wrapper(jpp.tag_kind == jit_memory_tag_kind_t::ncsp ? *(jpp.tmp_md) diff --git a/src/cpu/x64/jit_uni_reduction_kernel.cpp b/src/cpu/x64/jit_uni_reduction_kernel.cpp index 7e42f320a8f..891713143c6 100644 --- a/src/cpu/x64/jit_uni_reduction_kernel.cpp +++ b/src/cpu/x64/jit_uni_reduction_kernel.cpp @@ -156,9 +156,9 @@ void jit_uni_reduction_kernel_t::init_post_ops_injector( const binary_injector::rhs_arg_static_params_t rhs_arg_bsp { static_cast(rhs_dt_helper_vmm_.getIdx()), reg_po_injector_helper_1_, reg_po_injector_helper_2_, - true /*preserve gpr*/, true /*preserve vmm*/, - GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), dst_d, - store_tail_size_, k_tail_store_mask_, + reg_po_injector_helper_3_, true /*preserve gpr*/, + true /*preserve vmm*/, GET_OFF(post_ops_binary_rhs_arg_vec), + GET_OFF(dst_orig), dst_d, store_tail_size_, k_tail_store_mask_, false /*use_exact_tail_scalar_bcast*/}; const binary_injector::static_params_t bsp( reg_param_, get_supported_postops_bcast_strategies(), rhs_arg_bsp); diff --git a/src/cpu/x64/jit_uni_reduction_kernel.hpp b/src/cpu/x64/jit_uni_reduction_kernel.hpp index 238279a1a41..1de627383ec 100644 --- a/src/cpu/x64/jit_uni_reduction_kernel.hpp +++ b/src/cpu/x64/jit_uni_reduction_kernel.hpp @@ -135,6 +135,7 @@ struct jit_uni_reduction_kernel_t : public jit_uni_reduction_kernel_base_t { const Xbyak::Opmask elt_inj_opmask_ = k1; const Xbyak::Reg64 reg_po_injector_helper_1_ = r14; const Xbyak::Reg64 reg_po_injector_helper_2_ = r15; + const Xbyak::Reg64 reg_po_injector_helper_3_ = r12; // post-ops injector does not use avx512_core_bf16 instructions static constexpr cpu_isa_t inject_isa_ diff --git a/src/cpu/x64/jit_uni_resampling_kernel.cpp b/src/cpu/x64/jit_uni_resampling_kernel.cpp index b458d7ae75f..42d6096fcb0 100644 --- a/src/cpu/x64/jit_uni_resampling_kernel.cpp +++ b/src/cpu/x64/jit_uni_resampling_kernel.cpp @@ -57,7 +57,7 @@ jit_uni_resampling_kernel_t::jit_uni_resampling_kernel_t( const binary_injector::rhs_arg_static_params_t rhs_sp { static_cast(vmm_post_op_helper_.getIdx()), r14, r15, - preserve_gpr, preserve_vmm, + r13, preserve_gpr, preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), dst_d, tail_size_, k_tail_mask_, use_exact_tail_scalar_bcast}; diff --git a/src/cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.cpp b/src/cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.cpp index 31274ea3e9e..556b3ab42a6 100644 --- a/src/cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.cpp +++ b/src/cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.cpp @@ -54,7 +54,7 @@ _jit_uni_x8s8s32x_1x1_conv_kernel::_jit_uni_x8s8s32x_1x1_conv_kernel( using namespace binary_injector; static constexpr bool preserve_gpr = true; static constexpr bool preserve_vmm = true; - rhs_arg_static_params_t rhs_arg_static_params {15, r13, r14, + rhs_arg_static_params_t rhs_arg_static_params {15, r13, r14, r15, preserve_gpr, preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), memory_desc_wrapper(dst_md)}; diff --git a/src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.cpp b/src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.cpp index 44ce1a37ec9..91b1ef84f4a 100644 --- a/src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.cpp +++ b/src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.cpp @@ -67,7 +67,7 @@ _jit_uni_x8s8s32x_fwd_kernel::_jit_uni_x8s8s32x_fwd_kernel( const size_t tail_size = 0; const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx, - r13, r14, preserve_gpr, preserve_vmm, + r13, r14, r15, preserve_gpr, preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), memory_desc_wrapper(dst_md), tail_size, true}; const static_params_t static_params { diff --git a/src/cpu/x64/jit_uni_x8s8s32x_deconvolution.cpp b/src/cpu/x64/jit_uni_x8s8s32x_deconvolution.cpp index 2c8b1c9c973..a84f0760047 100644 --- a/src/cpu/x64/jit_uni_x8s8s32x_deconvolution.cpp +++ b/src/cpu/x64/jit_uni_x8s8s32x_deconvolution.cpp @@ -416,7 +416,7 @@ _jit_uni_x8s8s32x_deconv_fwd_kernelr14, this->r15, preserve_gpr, preserve_vmm, + this->r14, this->r15, this->r13, preserve_gpr, preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), dst_d, tail_size, Xbyak::Opmask(2), use_exact_tail_scalar_bcast}; const binary_injector::static_params_t bsp {this->param1_, rhs_sp};