Skip to content

Commit

Permalink
[ARM plugin] Activation fixes (openvinotoolkit#490)
Browse files Browse the repository at this point in the history
  • Loading branch information
alvoron authored Dec 5, 2022
1 parent fa868c8 commit ae1f96d
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 7 deletions.
46 changes: 40 additions & 6 deletions modules/arm_plugin/src/arm_converter/arm_converter_activation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
#include <arm_compute/runtime/NEON/functions/NEElementwiseUnaryLayer.h>
#include <arm_compute/runtime/NEON/functions/NEFloor.h>
#include <arm_compute/runtime/NEON/functions/NEPReluLayer.h>
#include <ngraph/runtime/reference/abs.hpp>
#include <ngraph/runtime/reference/clamp.hpp>
#include <ngraph/runtime/reference/floor.hpp>
#include <ngraph/runtime/reference/hsigmoid.hpp>
#include <ngraph/runtime/reference/hard_sigmoid.hpp>
#include <ngraph/runtime/reference/selu.hpp>
Expand Down Expand Up @@ -40,13 +43,34 @@ template<> Converter::Conversion::Ptr Converter::Convert(const opset::PRelu& nod
}

template<> Converter::Conversion::Ptr Converter::Convert(const opset::Abs& node) {
arm_compute::ActivationLayerInfo info(arm_compute::ActivationLayerInfo::ActivationFunction::ABS);
return ConvertActivation(node, info, this);
if (node.input(0).get_element_type() == ngraph::element::f32 ||
node.input(0).get_element_type() == ngraph::element::f16) {
arm_compute::ActivationLayerInfo info(arm_compute::ActivationLayerInfo::ActivationFunction::ABS);
return ConvertActivation(node, info, this);
} else {
auto make = [&] (auto refFunction) {
return this->MakeConversion(refFunction, node.input(0), node.output(0), ngraph::shape_size(node.get_output_shape(0)));
};
return CallSwitch(
AP_WRAP(make, ngraph::runtime::reference::abs),
node.input(0), intTypes);
}
}

template<> Converter::Conversion::Ptr Converter::Convert(const opset::Clamp& node) {
arm_compute::ActivationLayerInfo info(arm_compute::ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, node.get_max(), node.get_min());
return ConvertActivation(node, info, this);
if (node.input(0).get_element_type() == ngraph::element::f32 ||
node.input(0).get_element_type() == ngraph::element::f16) {
arm_compute::ActivationLayerInfo info(arm_compute::ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, node.get_max(), node.get_min());
return ConvertActivation(node, info, this);
} else {
auto make = [&] (auto refFunction) {
return this->MakeConversion(refFunction, node.input(0), node.output(0),
static_cast<std::int32_t>(node.get_min()), static_cast<std::int32_t>(node.get_max()), ngraph::shape_size(node.get_input_shape(0)));
};
return CallSwitch(
AP_WRAP(make, ngraph::runtime::reference::clamp),
node.input(0), std::tuple<std::int32_t>{});
}
}

template<> Converter::Conversion::Ptr Converter::Convert(const opset::Sqrt& node) {
Expand All @@ -68,7 +92,17 @@ template<> Converter::Conversion::Ptr Converter::Convert(const opset::Exp& node)
}

template<> Converter::Conversion::Ptr Converter::Convert(const opset::Floor& node) {
return MakeConversion<arm_compute::NEFloor>(node.input(0), node.output(0));
if (node.input(0).get_element_type() == ngraph::element::f32 ||
node.input(0).get_element_type() == ngraph::element::f16) {
return MakeConversion<arm_compute::NEFloor>(node.input(0), node.output(0));
} else {
auto make = [&] (auto refFunction) {
return this->MakeConversion(refFunction, node.input(0), node.output(0), ngraph::shape_size(node.get_output_shape(0)));
};
return CallSwitch(
AP_WRAP(make, ngraph::runtime::reference::floor),
node.input(0), allTypes);
}
}

template<> Converter::Conversion::Ptr Converter::Convert(const opset::HSwish& node) {
Expand All @@ -88,7 +122,7 @@ template<> Converter::Conversion::Ptr Converter::Convert(const opset::Gelu& node

template<> Converter::Conversion::Ptr Converter::Convert(const opset::Swish& node) {
float beta = 1.0;
if (ov::get_constant_from_source(node.input_value(1)) != nullptr) {
if (node.get_input_size() > 1 && ov::get_constant_from_source(node.input_value(1)) != nullptr) {
beta = ov::get_constant_from_source(node.input_value(1))->cast_vector<float>()[0];
}
arm_compute::ActivationLayerInfo info(arm_compute::ActivationLayerInfo::ActivationFunction::SWISH, beta);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ template <> Converter::Conversion::Ptr Converter::Convert(const opset::Convert&
return make(ngraph::runtime::reference::convert<float, std::uint16_t>);
case ngraph::element::Type_t::u32 :
return make(ngraph::runtime::reference::convert<float, std::uint16_t>);
case ngraph::element::Type_t::i8 :
return make(ngraph::runtime::reference::convert<float, std::int8_t>);
case ngraph::element::Type_t::i16 :
return make(ngraph::runtime::reference::convert<float, std::int16_t>);
default:
Expand Down
4 changes: 3 additions & 1 deletion modules/arm_plugin/src/arm_converter/arm_converter_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ static void FillLayerInfo(const Pool& node, arm_compute::PoolingLayerInfo& pool_
}

template<> Converter::Conversion::Ptr Converter::Convert(const opset::MaxPool& node) {
if (node.get_input_shape(0).size() == 4) {
if (node.get_input_shape(0).size() == 4 &&
(node.input(0).get_element_type() == ngraph::element::f32 ||
node.input(0).get_element_type() == ngraph::element::f16)) {
arm_compute::PoolingLayerInfo pool_info;
FillLayerInfo(node, pool_info);
pool_info.pool_type = arm_compute::PoolingType::MAX;
Expand Down

0 comments on commit ae1f96d

Please sign in to comment.