Skip to content

Commit

Permalink
Implement CPU plugin just-in-time emitter for IsNaN operation (#24808)
Browse files Browse the repository at this point in the history
Closes: #24420
### Details:
 - Adds an aarch64 CPU-plugin just-in-time emitter for the opset10 IsNaN operation (f32, ASIMD), registers it in the eltwise executor/generator, and extends the activation layer tests to cover IsNaN.

### Tickets:
 - [CVS-137700](https://jira.devtools.intel.com/browse/CVS-137700)
  • Loading branch information
awayzjj authored Jun 24, 2024
1 parent 045735c commit bce4ed1
Show file tree
Hide file tree
Showing 9 changed files with 105 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,69 @@ void jit_is_inf_emitter::register_table_entries() {
push_arg_entry_of("inf_neg", 0xFF800000, true);
}

/// IS_NAN ///
/// Node-based constructor: derives the execution precision from the operation's
/// input types and validates that the node is an opset10 IsNaN.
jit_is_nan_emitter::jit_is_nan_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
                                       dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
                                       const std::shared_ptr<ov::Node>& node)
    : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {
    // Refuse to build an emitter for anything that is not ov::op::v10::IsNaN.
    const auto is_nan_op = ov::as_type_ptr<ov::op::v10::IsNaN>(node);
    if (!is_nan_op) {
        OV_CPU_JIT_EMITTER_THROW("Can't cast to ov::op::v10::IsNaN");
    }

    // Register the constants ("zero", "one") used by emit_isa().
    prepare_table();
}

// Precision-based constructor: used when no ov::Node is available (e.g. fused
// eltwise kernels construct emitters directly from the requested precision).
jit_is_nan_emitter::jit_is_nan_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const ov::element::Type exec_prc)
: jit_emitter(host, host_isa, exec_prc) {
// Register the constants ("zero", "one") used by emit_isa().
prepare_table();
}

// IsNaN is a unary element-wise operation, so exactly one input vector is consumed.
size_t jit_is_nan_emitter::get_inputs_count() const {
    return 1;
}

// One auxiliary vector register is needed: emit_isa() uses it to broadcast
// the "zero" and "one" table constants.
size_t jit_is_nan_emitter::get_aux_vecs_count() const {
    return 1;
}

// One auxiliary general-purpose register is needed for addressing the constants table.
size_t jit_is_nan_emitter::get_aux_gprs_count() const {
    return 1;
}

// The aarch64 implementation handles a single f32 input only; 'node' is
// intentionally ignored (kept for signature parity with other emitters).
std::set<std::vector<element::Type>> jit_is_nan_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
    return std::set<std::vector<element::Type>>{{element::f32}};
}

// Dispatches code generation to the ISA-specialized emit_isa(). Only the
// baseline ASIMD (NEON) ISA is supported on aarch64.
void jit_is_nan_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
    if (host_isa_ != dnnl::impl::cpu::aarch64::asimd) {
        OV_CPU_JIT_EMITTER_THROW("Can't create jit eltwise kernel");
    }
    emit_isa<dnnl::impl::cpu::aarch64::asimd>(in_vec_idxs, out_vec_idxs);
}

// Emits ASIMD code computing IsNaN element-wise: dst[i] = (src[i] is NaN) ? 1.0f : 0.0f.
// Uses the IEEE-754 self-comparison trick: x == x is false iff x is NaN.
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_is_nan_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string());

using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;

TReg src = TReg(in_vec_idxs[0]);
TReg dst = TReg(out_vec_idxs[0]);
TReg aux = TReg(aux_vec_idxs[0]);

// According to the IEEE standard, NaN values have the odd property that comparisons
// involving them are always false: after this self-compare, each dst lane holds
// all-ones (0xFFFFFFFF) for non-NaN inputs and all-zeros for NaN inputs.
h->fcmeq(dst.s, src.s, src.s);
// Invert the mask by comparing against 0.0f: the all-zeros lanes (NaN inputs)
// equal 0.0f and become all-ones; the all-ones lanes (a NaN bit pattern when
// reinterpreted as float) compare false and become zero.
h->ld1r(aux.s, table_val2("zero"));
h->fcmeq(dst.s, dst.s, aux.s);
// Sets elements in 'dst' to 1.0 where the comparison was true: AND the all-ones
// mask with the bit pattern of 1.0f (0x3f800000 from the "one" table entry).
h->ld1r(aux.s, table_val2("one"));
h->and_(dst.b16, dst.b16, aux.b16);
}

// Registers the IEEE-754 single-precision bit patterns consumed by emit_isa():
// 0x00000000 is 0.0f (mask-inversion compare) and 0x3f800000 is 1.0f (result value).
// Entries are looked up by name via table_val2(), so registration order is immaterial.
void jit_is_nan_emitter::register_table_entries() {
    push_arg_entry_of("zero", 0x00000000, true);
    push_arg_entry_of("one", 0x3f800000, true);
}

/// MAX ///
jit_maximum_emitter::jit_maximum_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,33 @@ class jit_hswish_emitter : public jit_emitter {
void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
};

// JIT emitter generating aarch64 ASIMD code for the opset10 IsNaN operation
// (element-wise dst = 1.0f where the input is NaN, 0.0f otherwise).
class jit_is_nan_emitter : public jit_emitter {
public:
// Construct from an explicit execution precision (defaults to f32, the only
// precision the aarch64 implementation supports).
jit_is_nan_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const ov::element::Type exec_prc = ov::element::f32);

// Construct from an ov::Node; the node must be castable to ov::op::v10::IsNaN.
jit_is_nan_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node);

// Unary operation: one input vector.
size_t get_inputs_count() const override;

// One auxiliary vector register (for broadcasting table constants).
size_t get_aux_vecs_count() const override;

// One auxiliary GPR (for addressing the constants table).
size_t get_aux_gprs_count() const override;

// Reports {f32} regardless of the (optional) node argument.
static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);

private:
// Dispatches to the ISA-specialized emit_isa(); only ASIMD is supported.
void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;

// Generates the NaN-detection instruction sequence for the given ISA.
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;

// Registers the "zero" and "one" float constants used by emit_isa().
void register_table_entries() override;
};

class jit_maximum_emitter : public jit_emitter {
public:
jit_maximum_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ bool JitEltwiseExecutor::isSupported(
Algorithm::EltwiseGeluTanh,
Algorithm::EltwiseHswish,
Algorithm::EltwiseIsInf,
Algorithm::EltwiseIsNaN,
Algorithm::EltwiseMaximum,
Algorithm::EltwiseMinimum,
Algorithm::EltwiseMish,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,7 @@ std::shared_ptr<jit_emitter> jit_uni_eltwise_generic<isa>::create_eltwise_emitte
OV_CASE(Algorithm::EltwiseFloor, ov::intel_cpu::aarch64::jit_floor_emitter),
OV_CASE(Algorithm::EltwiseHswish, ov::intel_cpu::aarch64::jit_hswish_emitter),
OV_CASE(Algorithm::EltwiseIsInf, ov::intel_cpu::aarch64::jit_is_inf_emitter),
OV_CASE(Algorithm::EltwiseIsNaN, ov::intel_cpu::aarch64::jit_is_nan_emitter),
OV_CASE(Algorithm::EltwiseMaximum, ov::intel_cpu::aarch64::jit_maximum_emitter),
OV_CASE(Algorithm::EltwiseMinimum, ov::intel_cpu::aarch64::jit_minimum_emitter),
OV_CASE(Algorithm::EltwiseMish, ov::intel_cpu::aarch64::jit_mish_emitter),
Expand Down Expand Up @@ -823,6 +824,7 @@ std::set<std::vector<element::Type>> eltwise_precision_helper::get_supported_pre
OV_CASE(Algorithm::EltwiseGeluTanh, jit_gelu_tanh_emitter),
OV_CASE(Algorithm::EltwiseHswish, jit_hswish_emitter),
OV_CASE(Algorithm::EltwiseIsInf, jit_is_inf_emitter),
OV_CASE(Algorithm::EltwiseIsNaN, jit_is_nan_emitter),
OV_CASE(Algorithm::EltwiseMaximum, jit_maximum_emitter),
OV_CASE(Algorithm::EltwiseMinimum, jit_minimum_emitter),
OV_CASE(Algorithm::EltwiseMish, jit_mish_emitter),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ std::string ActivationLayerCPUTest::getPrimitiveType(const utils::ActivationType
(activation_type == utils::ActivationTypes::HSwish) ||
(activation_type == utils::ActivationTypes::IsInf) ||
(activation_type == utils::ActivationTypes::HardSigmoid) ||
(activation_type == utils::ActivationTypes::IsNaN) ||
(activation_type == utils::ActivationTypes::Mish) ||
(activation_type == utils::ActivationTypes::GeluErf) ||
(activation_type == utils::ActivationTypes::GeluTanh) ||
Expand All @@ -190,7 +191,8 @@ std::string ActivationLayerCPUTest::getPrimitiveType(const utils::ActivationType
return "";
}
#endif
if (activation_type == utils::ActivationTypes::Floor) {
if ((activation_type == utils::ActivationTypes::Floor) ||
(activation_type == utils::ActivationTypes::IsNaN)) {
return "ref";
}
return "acl";
Expand Down Expand Up @@ -227,6 +229,7 @@ const std::map<utils::ActivationTypes, std::vector<std::vector<float>>>& activat
{GeluTanh, {{}}},
{SoftSign, {{}}},
{SoftPlus, {{}}},
{IsNaN, {{}}},
};

return activationTypes;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ const std::map<ActivationTypes, std::vector<std::vector<float>>> activationTypes
{ActivationTypes::GeluErf, {}},
{ActivationTypes::GeluTanh, {}},
{ActivationTypes::Swish, {{0.4f}}},
{ActivationTypes::IsInf, {}}
{ActivationTypes::IsInf, {}},
{ActivationTypes::IsNaN, {{}}},
};

// List of operations that should be tested also with integer precision
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ static std::map<ActivationTypes, std::string> activationNames = {
{ActivationTypes::GeluTanh, "GeluTanh"},
{ActivationTypes::SoftSign, "SoftSign"},
{ActivationTypes::IsInf, "IsInf"},
{ActivationTypes::IsNaN, "IsNaN"},
};

typedef std::tuple<
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ enum ActivationTypes {
GeluErf,
GeluTanh,
SoftSign,
IsInf
IsInf,
IsNaN,
};

enum MinMaxOpType {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "openvino/op/hsigmoid.hpp"
#include "openvino/op/hswish.hpp"
#include "openvino/op/is_inf.hpp"
#include "openvino/op/is_nan.hpp"
#include "openvino/op/log.hpp"
#include "openvino/op/mish.hpp"
#include "openvino/op/negative.hpp"
Expand Down Expand Up @@ -147,6 +148,8 @@ std::shared_ptr<ov::Node> make_activation(const ov::Output<Node>& in,
return std::make_shared<ov::op::v9::SoftSign>(in);
case ov::test::utils::ActivationTypes::IsInf:
return std::make_shared<ov::op::v10::IsInf>(in);
case ov::test::utils::ActivationTypes::IsNaN:
return std::make_shared<ov::op::v10::IsNaN>(in);
default:
OPENVINO_THROW("Can't create layer for this activation type");
}
Expand Down

0 comments on commit bce4ed1

Please sign in to comment.