Skip to content

Commit

Permalink
[CPU] [ARM64] Exp injector & Sigmoid injector
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Mar 26, 2024
1 parent c2d478b commit 6fab6f1
Show file tree
Hide file tree
Showing 6 changed files with 359 additions and 259 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "jit_eltwise_emitters.hpp"

#include <memory>
#include "jit_eltwise_injectors.hpp"
#include "common/utils.hpp"
#include "emitters/utils.hpp"

Expand Down Expand Up @@ -240,115 +241,22 @@ jit_exp_emitter::jit_exp_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,

size_t jit_exp_emitter::get_inputs_count() const { return 1; }

size_t jit_exp_emitter::get_aux_vecs_count() const { return 4; }
size_t jit_exp_emitter::get_aux_vecs_count() const {
return jit_exp_injector::get_aux_vecs_count();
}

size_t jit_exp_emitter::get_aux_gprs_count() const { return 1; }

void jit_exp_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
emit_isa<dnnl::impl::cpu::aarch64::asimd>(in_vec_idxs, out_vec_idxs);
jit_exp_injector::emit_impl<dnnl::impl::cpu::aarch64::asimd>(h, host_isa_, entry_map_, exec_prc_, in_vec_idxs, aux_vec_idxs, out_vec_idxs, p_table);
} else {
OV_CPU_JIT_EMITTER_THROW("Can't create jit eltwise kernel");
}
}

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_exp_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
if (exec_prc_ != ov::element::f32) {
OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string());
}

using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
const TReg vmm_src(in_vec_idxs[0]);
const TReg vmm_dst(out_vec_idxs[0]);
const TReg vmm_aux1(aux_vec_idxs[0]);
const TReg vmm_aux2(aux_vec_idxs[1]);
const TReg vmm_aux0(aux_vec_idxs[2]);

const TReg vmm_mask(aux_vec_idxs[3]);

h->ld1r(vmm_aux0.s, table_val2("exp_ln_flt_max_f"));
h->fmin(vmm_dst.s, vmm_src.s, vmm_aux0.s);
h->ld1r(vmm_aux0.s, table_val2("exp_ln_flt_min_f"));

// get mask of values lower than log(FLT_MIN) to zero them in the output
h->fcmgt(vmm_mask.s, vmm_src.s, vmm_aux0.s);

h->fmax(vmm_dst.s, vmm_dst.s, vmm_aux0.s);
h->mov(vmm_aux1.b16, vmm_dst.b16);

// calculate exp(x)
// fx = x * log2ef + 0.5
h->ld1r(vmm_aux0.s, table_val2("exp_log2ef"));
h->ld1r(vmm_aux2.s, table_val2("half"));
h->fmla(vmm_aux2.s, vmm_dst.s, vmm_aux0.s);

// tmp = floorf(fx)
h->frintm(vmm_aux2.s, vmm_aux2.s);

// keep vmm_src = fx for further computations
h->mov(vmm_dst.b16, vmm_aux2.b16);

// x = x - fx * ln2
h->ld1r(vmm_aux0.s, table_val2("ln2f"));
h->fmls(vmm_aux1.s, vmm_aux2.s, vmm_aux0.s);

// We do not count 2^n here, because n can reach 128 and 2^128 is not
// representable by fp32, so to get around this problem, instead of computing
// 2^n * exp(r) will be counted 2*2^(n-1)*exp(r), because 2^127
// and 2 are numbers representable in fp32.

// compute 2^(n-1)
h->ld1r(vmm_aux0.s, table_val2("one"));
h->fsub(vmm_dst.s, vmm_dst.s, vmm_aux0.s);
h->fcvtzs(vmm_aux2.s, vmm_dst.s);

h->ld1r(vmm_aux0.s, table_val2("exponent_bias"));
h->add(vmm_aux2.s, vmm_aux2.s, vmm_aux0.s);

const int n_mantissa_bits = 23;
h->sqshl(vmm_aux2.s, vmm_aux2.s, n_mantissa_bits);

// set zeroes at those points which were < log(FLT_MIN)
h->and_(vmm_aux2.b16, vmm_mask.b16, vmm_aux2.b16);

// compute polynomial
h->ld1r(vmm_aux0.s, table_val2("exp_pol5"));
h->ld1r(vmm_dst.s, table_val2("exp_pol4"));
h->fmla(vmm_dst.s, vmm_aux1.s, vmm_aux0.s);

h->ld1r(vmm_aux0.s, table_val2("exp_pol3"));
h->fmla(vmm_aux0.s, vmm_dst.s, vmm_aux1.s);

h->ld1r(vmm_dst.s, table_val2("exp_pol2"));
h->fmla(vmm_dst.s, vmm_aux0.s, vmm_aux1.s);

h->ld1r(vmm_aux0.s, table_val2("exp_pol1"));
h->fmla(vmm_aux0.s, vmm_dst.s, vmm_aux1.s);

h->ld1r(vmm_dst.s, table_val2("one"));
h->fmla(vmm_dst.s, vmm_aux0.s, vmm_aux1.s);

// y = y * 2^n
h->fmul(vmm_dst.s, vmm_dst.s, vmm_aux2.s);
h->ld1r(vmm_aux0.s, table_val2("two"));
h->fmul(vmm_dst.s, vmm_dst.s, vmm_aux0.s);
}

void jit_exp_emitter::register_table_entries() {
push_arg_entry_of("exp_ln_flt_max_f", 0x42b17218, true);
push_arg_entry_of("exp_ln_flt_min_f", 0xc2aeac50, true);
push_arg_entry_of("exp_log2ef", 0x3fb8aa3b, true);
push_arg_entry_of("one", 0x3f800000, true);
push_arg_entry_of("two", 0x40000000, true);
push_arg_entry_of("half", 0x3f000000, true);
push_arg_entry_of("ln2f", 0x3f317218, true);
push_arg_entry_of("exponent_bias", 0x0000007f, true);
push_arg_entry_of("exp_pol1", 0x3f7ffffb, true);
push_arg_entry_of("exp_pol2", 0x3efffee3, true);
push_arg_entry_of("exp_pol3", 0x3e2aad40, true);
push_arg_entry_of("exp_pol4", 0x3d2b9d0d, true);
push_arg_entry_of("exp_pol5", 0x3c07cfce, true);
jit_exp_injector::push_entry_map(entry_map_);
}

std::set<std::vector<element::Type>> jit_exp_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
Expand Down Expand Up @@ -737,159 +645,32 @@ jit_sigmoid_emitter::jit_sigmoid_emitter(dnnl::impl::cpu::aarch64::jit_generator
prepare_table();
}

size_t jit_sigmoid_emitter::get_inputs_count() const { return 1; }
size_t jit_sigmoid_emitter::get_inputs_count() const {return 1; }

size_t jit_sigmoid_emitter::get_aux_vecs_count() const { return 5; }
size_t jit_sigmoid_emitter::get_aux_vecs_count() const {
return jit_sigmoid_injector::get_aux_vecs_count();
}

size_t jit_sigmoid_emitter::get_aux_gprs_count() const { return 1; }

void jit_sigmoid_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
emit_isa<dnnl::impl::cpu::aarch64::asimd>(in_vec_idxs, out_vec_idxs);
jit_sigmoid_injector::emit_impl<dnnl::impl::cpu::aarch64::asimd>(
h,
host_isa_,
entry_map_,
exec_prc_,
in_vec_idxs,
aux_vec_idxs,
out_vec_idxs,
p_table);
} else {
OPENVINO_THROW("Can't create jit eltwise kernel");
}
}

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_sigmoid_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
if (exec_prc_ != ov::element::f32) {
OPENVINO_THROW("unsupported precision: " + exec_prc_.to_string());
}

// TODO: will be refactored
const auto exp_compute_vector_fwd = [&]() {
using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
const TReg vmm_src(in_vec_idxs[0]);
const TReg vmm_dst(out_vec_idxs[0]);
const TReg vmm_aux1(aux_vec_idxs[0]);
const TReg vmm_aux2(aux_vec_idxs[1]);
const TReg vmm_aux0(aux_vec_idxs[2]);

const TReg vmm_mask(aux_vec_idxs[3]);

h->ld1r(vmm_aux0.s, table_val2("exp_ln_flt_max_f"));
h->fmin(vmm_dst.s, vmm_src.s, vmm_aux0.s);
h->ld1r(vmm_aux0.s, table_val2("exp_ln_flt_min_f"));

// get mask of values lower than log(FLT_MIN) to zero them in the output
h->fcmgt(vmm_mask.s, vmm_src.s, vmm_aux0.s);

h->fmax(vmm_dst.s, vmm_dst.s, vmm_aux0.s);
h->mov(vmm_aux1.b16, vmm_dst.b16);

// calculate exp(x)
// fx = x * log2ef + 0.5
h->ld1r(vmm_aux0.s, table_val2("exp_log2ef"));
h->ld1r(vmm_aux2.s, table_val2("half"));
h->fmla(vmm_aux2.s, vmm_dst.s, vmm_aux0.s);

// tmp = floorf(fx)
h->frintm(vmm_aux2.s, vmm_aux2.s);

// keep vmm_src = fx for further computations
h->mov(vmm_dst.b16, vmm_aux2.b16);

// x = x - fx * ln2
h->ld1r(vmm_aux0.s, table_val2("ln2f"));
h->fmls(vmm_aux1.s, vmm_aux2.s, vmm_aux0.s);

// We do not count 2^n here, because n can reach 128 and 2^128 is not
// representable by fp32, so to get around this problem, instead of computing
// 2^n * exp(r) will be counted 2*2^(n-1)*exp(r), because 2^127
// and 2 are numbers representable in fp32.

// compute 2^(n-1)
h->ld1r(vmm_aux0.s, table_val2("one"));
h->fsub(vmm_dst.s, vmm_dst.s, vmm_aux0.s);
h->fcvtzs(vmm_aux2.s, vmm_dst.s);

h->ld1r(vmm_aux0.s, table_val2("exponent_bias"));
h->add(vmm_aux2.s, vmm_aux2.s, vmm_aux0.s);

h->sqshl(vmm_aux2.s, vmm_aux2.s, 23);

// set zeroes at those points which were < log(FLT_MIN)
h->and_(vmm_aux2.b16, vmm_mask.b16, vmm_aux2.b16);

// compute polynomial
h->ld1r(vmm_aux0.s, table_val2("exp_pol5"));
h->ld1r(vmm_dst.s, table_val2("exp_pol4"));
h->fmla(vmm_dst.s, vmm_aux1.s, vmm_aux0.s);

h->ld1r(vmm_aux0.s, table_val2("exp_pol3"));
h->fmla(vmm_aux0.s, vmm_dst.s, vmm_aux1.s);

h->ld1r(vmm_dst.s, table_val2("exp_pol2"));
h->fmla(vmm_dst.s, vmm_aux0.s, vmm_aux1.s);

h->ld1r(vmm_aux0.s, table_val2("exp_pol1"));
h->fmla(vmm_aux0.s, vmm_dst.s, vmm_aux1.s);

h->ld1r(vmm_dst.s, table_val2("one"));
h->fmla(vmm_dst.s, vmm_aux0.s, vmm_aux1.s);

// y = y * 2^n
h->fmul(vmm_dst.s, vmm_dst.s, vmm_aux2.s);
h->ld1r(vmm_aux0.s, table_val2("two"));
h->fmul(vmm_dst.s, vmm_dst.s, vmm_aux0.s);
};


using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
const TReg vmm_src(in_vec_idxs[0]);
const TReg vmm_dst(out_vec_idxs[0]);
const TReg vmm_aux1(aux_vec_idxs[0]);
const TReg vmm_aux2(aux_vec_idxs[1]);
//const TReg vmm_aux3(aux_vec_idxs[1]); // <= remove later
const TReg vmm_aux0(aux_vec_idxs[2]);

const TReg vmm_mask(aux_vec_idxs[4]);

// To avoid exp(x) overflow happened at x > logf(FLT_MAX), negate positive,
// compute exp(x), where x <= 0 to get 0 <= exp(x) <= 1 and restore value
// sign at the end. This is possible due to logistic is symmetric function.
// IMPORTANT: we use vmm_mask for the mask as exp_compute does not use it.
// we store the original sign and make x negative
h->eor(vmm_aux0.b16, vmm_aux0.b16, vmm_aux0.b16);
h->fcmgt(vmm_mask.s, vmm_src.s, vmm_aux0.s);

h->ld1r(vmm_aux0.s, table_val2("sign_mask"));
h->orr(vmm_src.b16, vmm_src.b16, vmm_aux0.b16);

exp_compute_vector_fwd();

// dup exp(x)
h->mov(vmm_aux1.b16, vmm_dst.b16);
// (exp(x) + 1)
h->ld1r(vmm_aux0.s, table_val2("one"));
h->fadd(vmm_aux1.s, vmm_aux1.s, vmm_aux0.s);
// y = exp(x) / (exp(x) + 1)
h->fdiv(vmm_dst.s, vmm_dst.s, vmm_aux1.s);

// Now we have to apply the "symmetry" based on original sign
h->ld1r(vmm_aux2.s, table_val2("one"));
h->fsub(vmm_aux2.s, vmm_aux2.s, vmm_dst.s);

h->bsl(vmm_mask.b16, vmm_aux2.b16, vmm_dst.b16);
h->mov(vmm_dst.b16, vmm_mask.b16);
}

void jit_sigmoid_emitter::register_table_entries() {
// jit_exp_emitter
push_arg_entry_of("exp_ln_flt_max_f", 0x42b17218, true);
push_arg_entry_of("exp_ln_flt_min_f", 0xc2aeac50, true);
push_arg_entry_of("exp_log2ef", 0x3fb8aa3b, true);
push_arg_entry_of("one", 0x3f800000, true);
push_arg_entry_of("two", 0x40000000, true);
push_arg_entry_of("half", 0x3f000000, true);
push_arg_entry_of("ln2f", 0x3f317218, true);
push_arg_entry_of("exponent_bias", 0x0000007f, true);
push_arg_entry_of("exp_pol1", 0x3f7ffffb, true);
push_arg_entry_of("exp_pol2", 0x3efffee3, true);
push_arg_entry_of("exp_pol3", 0x3e2aad40, true);
push_arg_entry_of("exp_pol4", 0x3d2b9d0d, true);
push_arg_entry_of("exp_pol5", 0x3c07cfce, true);
jit_exp_injector::push_entry_map(entry_map_);

push_arg_entry_of("sign_mask", 0x80000000, true);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,6 @@ class jit_exp_emitter : public jit_emitter {

private:
void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
};


Expand Down Expand Up @@ -310,9 +307,6 @@ class jit_sigmoid_emitter : public jit_emitter {

private:
void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
};

class jit_subtract_emitter : public jit_emitter {
Expand Down
Loading

0 comments on commit 6fab6f1

Please sign in to comment.