Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] implement BGR2RGB for postprocess #22022

Merged
11 changes: 11 additions & 0 deletions src/core/include/openvino/core/preprocess/output_model_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "openvino/core/core_visibility.hpp"
#include "openvino/core/layout.hpp"
#include "openvino/core/preprocess/color_format.hpp"

namespace ov {
namespace preprocess {
Expand Down Expand Up @@ -42,6 +43,16 @@ class OPENVINO_API OutputModelInfo final {
///
/// \return Reference to 'this' to allow chaining with other calls in a builder-like manner
OutputModelInfo& set_layout(const ov::Layout& layout);

/// \brief Set color format for model's output tensor
///
/// \param format Color format for model's output tensor.
///
/// \param sub_names Optional list of sub-names, not used, placeholder for future.
///
/// \return Reference to 'this' to allow chaining with other calls in a builder-like manner
OutputModelInfo& set_color_format(const ov::preprocess::ColorFormat& format,
const std::vector<std::string>& sub_names = {});
};

} // namespace preprocess
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,14 @@ class OPENVINO_API PostProcessSteps final {
///
/// \return Reference to 'this' to allow chaining with other calls in a builder-like manner
PostProcessSteps& custom(const CustomPostprocessOp& postprocess_cb);

/// \brief Converts color format for user's output tensor. Requires destinantion color format to be specified by
/// OutputTensorInfo::set_color_format.
///
/// \param dst_format Destination color format of input image
///
/// \return Reference to 'this' to allow chaining with other calls in a builder-like manner
PostProcessSteps& convert_color(const ov::preprocess::ColorFormat& dst_format);
};

} // namespace preprocess
Expand Down
11 changes: 11 additions & 0 deletions src/core/src/preprocess/pre_post_process.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,12 @@ OutputModelInfo& OutputModelInfo::set_layout(const Layout& layout) {
return *this;
}

OutputModelInfo& OutputModelInfo::set_color_format(const ov::preprocess::ColorFormat& format,
const std::vector<std::string>& sub_names) {
m_impl->set_color_format(format);
return *this;
}

// --------------------- PostProcessSteps ------------------

PostProcessSteps::PostProcessSteps() : m_impl(std::unique_ptr<PostProcessStepsImpl>(new PostProcessStepsImpl())) {}
Expand All @@ -381,6 +387,11 @@ PostProcessSteps& PostProcessSteps::convert_layout(const std::vector<uint64_t>&
return *this;
}

PostProcessSteps& PostProcessSteps::convert_color(const ov::preprocess::ColorFormat& dst_format) {
m_impl->add_convert_color_impl(dst_format);
return *this;
}

PostProcessSteps& PostProcessSteps::custom(const CustomPostprocessOp& postprocess_cb) {
// 'true' indicates that custom postprocessing step will trigger validate_and_infer_types
m_impl->actions().emplace_back(
Expand Down
4 changes: 4 additions & 0 deletions src/core/src/preprocess/preprocess_impls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,10 @@ void OutputInfo::OutputInfoImpl::build(ov::ResultVector& results) {
if (get_tensor_data()->is_element_type_set()) {
context.target_element_type() = get_tensor_data()->get_element_type();
}
if (get_model_data()->is_color_format_set()) {
context.color_format() = get_model_data()->get_color_format();
}

// Apply post-processing
node = result->get_input_source_output(0);
bool post_processing_applied = false;
Expand Down
21 changes: 20 additions & 1 deletion src/core/src/preprocess/preprocess_impls.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,26 @@ class ModelInfoImpl {

class InputModelInfo::InputModelInfoImpl : public ModelInfoImpl {};

class OutputModelInfo::OutputModelInfoImpl : public ModelInfoImpl {};
class OutputModelInfo::OutputModelInfoImpl : public ModelInfoImpl {
public:
void set_color_format(const ColorFormat& color_format, const std::vector<std::string>& sub_names = {}) {
m_color_format_set = (color_format == ColorFormat::RGB) || (color_format == ColorFormat::BGR);
OPENVINO_ASSERT(m_color_format_set);
m_color_format = color_format;
m_planes_sub_names = sub_names;
}
bool is_color_format_set() const {
return m_color_format_set;
}
const ColorFormat& get_color_format() const {
return m_color_format;
}

private:
ColorFormat m_color_format = ColorFormat::UNDEFINED;
std::vector<std::string> m_planes_sub_names{};
bool m_color_format_set = false;
};

/// \brief OutputInfoImpl - internal data structure
struct OutputInfo::OutputInfoImpl {
Expand Down
71 changes: 71 additions & 0 deletions src/core/src/preprocess/preprocess_steps_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -695,5 +695,76 @@ void PostStepsList::add_convert_layout_impl(const std::vector<uint64_t>& dims) {
},
"convert layout " + vector_to_string(dims));
}

void PostStepsList::add_convert_color_impl(const ColorFormat& dst_format) {
m_actions.emplace_back(
[dst_format](const Output<Node>& node, PostprocessingContext& context) {
if (context.color_format() == dst_format) {
return std::make_tuple(node, false);
} else if ((context.color_format() == ColorFormat::RGB || context.color_format() == ColorFormat::BGR) &&
(dst_format == ColorFormat::RGB || dst_format == ColorFormat::BGR)) {
auto res = reverse_channels({node}, context);
context.color_format() = dst_format;
return res;
} else {
OPENVINO_THROW("Source color format '",
color_format_name(context.color_format()),
"' is not convertible to '",
color_format_name(dst_format),
"'");
}
},
"convert color (" + color_format_name(dst_format) + ")");
}

std::tuple<Output<Node>, bool> PostStepsList::reverse_channels(const Output<Node>& node,
PostprocessingContext& context) {
OPENVINO_ASSERT(ov::layout::has_channels(context.layout()),
"Layout ",
context.layout().to_string(),
" doesn't have `channels` dimension");
const auto& shape = node.get_partial_shape();
if (shape.rank().is_static()) {
// This block of code is to preserve output shape if it contains dynamic dimensions
// Otherwise, dynamic version will transform shape {?,3,?,?} to {?,?,?,?} which is still ok but not desired
auto channels_idx = get_and_check_channels_idx(context.layout(), shape);
if (shape[channels_idx].is_static()) {
auto channels_count = shape[channels_idx].get_length();
// Add range from constants
auto range_from = op::v0::Constant::create(element::i64, {}, {channels_count - 1});
auto range_to = op::v0::Constant::create(element::i64, {}, {-1});
auto range_step = op::v0::Constant::create(element::i64, {}, {-1});
auto range = std::make_shared<op::v4::Range>(range_from, range_to, range_step, element::i32);

auto constant_axis = op::v0::Constant::create(element::i32, {1}, {channels_idx});
auto convert = std::make_shared<op::v8::Gather>(node, range, constant_axis);
return std::make_tuple(convert, false);
}
}

auto channels_idx = ov::layout::channels_idx(context.layout());
// Get shape of user's input tensor (e.g. Tensor[1, 3, 224, 224] -> {1, 3, 224, 224})
auto shape_of = std::make_shared<ov::op::v0::ShapeOf>(node); // E.g. {1, 3, 224, 224}

auto constant_chan_idx = op::v0::Constant::create(element::i32, {}, {channels_idx}); // E.g. 1
auto constant_chan_axis = op::v0::Constant::create(element::i32, {}, {0});
// Gather will return scalar with number of channels (e.g. 3)
auto gather_channels_num = std::make_shared<op::v8::Gather>(shape_of, constant_chan_idx, constant_chan_axis);

// Create Range from channels_num-1 to 0 (e.g. {2, 1, 0})
auto const_minus1 = op::v0::Constant::create(element::i64, {}, {-1});
auto channels_num_minus1 = std::make_shared<op::v1::Add>(gather_channels_num, const_minus1); // E.g. 3-1=2
// Add range
auto range_to = op::v0::Constant::create(element::i64, {}, {-1});
auto range_step = op::v0::Constant::create(element::i64, {}, {-1});
// E.g. {2, 1, 0}
auto range = std::make_shared<op::v4::Range>(channels_num_minus1, range_to, range_step, element::i32);

// Gather slices in reverse order (indexes are specified by 'range' operation)
auto constant_axis = op::v0::Constant::create(element::i32, {1}, {channels_idx});
auto gather = std::make_shared<op::v8::Gather>(node, range, constant_axis);
return std::make_tuple(gather, false);
}

} // namespace preprocess
} // namespace ov
22 changes: 13 additions & 9 deletions src/core/src/preprocess/preprocess_steps_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,19 @@ class PrePostProcessingContextBase {
return m_target_element_type;
}

const ColorFormat& color_format() const {
return m_color_format;
}

ColorFormat& color_format() {
return m_color_format;
}

protected:
Layout m_layout;
Layout m_target_layout;
element::Type m_target_element_type;
ColorFormat m_color_format = ColorFormat::UNDEFINED;
};

/// \brief Preprocessing context passed to each preprocessing operation.
Expand Down Expand Up @@ -126,18 +135,9 @@ class PreprocessingContext : public PrePostProcessingContextBase {
return model_shape()[model_width_idx].get_length();
}

const ColorFormat& color_format() const {
return m_color_format;
}

ColorFormat& color_format() {
return m_color_format;
}

private:
PartialShape m_model_shape;
Layout m_model_layout;
ColorFormat m_color_format = ColorFormat::UNDEFINED;
};

using InternalPreprocessOp =
Expand Down Expand Up @@ -219,6 +219,7 @@ class PostStepsList {
void add_convert_impl(const element::Type& type);
void add_convert_layout_impl(const Layout& layout);
void add_convert_layout_impl(const std::vector<uint64_t>& dims);
void add_convert_color_impl(const ColorFormat& dst_format);

const std::list<InternalPostprocessAction>& actions() const {
return m_actions;
Expand All @@ -227,6 +228,9 @@ class PostStepsList {
return m_actions;
}

private:
static std::tuple<Output<Node>, bool> reverse_channels(const Output<Node>& nodes, PostprocessingContext& context);

private:
std::list<InternalPostprocessAction> m_actions;
};
Expand Down
82 changes: 82 additions & 0 deletions src/core/tests/preprocess.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1782,6 +1782,88 @@ TEST(pre_post_process, postprocess_keep_friendly_names_compatibility_implicit) {
EXPECT_NE(node_before_result_old->get_friendly_name(), node_name);
}

// --- PostProcess - convert color format ---
TEST(pre_post_process, postprocess_convert_color_format_BGR_RGB) {
auto f = create_simple_function(element::f32, Shape{5, 30, 20, 3});
auto p = PrePostProcessor(f);
p.output().model().set_layout("NHWC").set_color_format(ColorFormat::BGR);
p.output().postprocess().convert_color(ColorFormat::RGB);
f = p.build();

EXPECT_EQ(f->get_results().size(), 1);
EXPECT_EQ(f->get_result()->get_output_partial_shape(0), (PartialShape{5, 30, 20, 3}));
}

TEST(pre_post_process, postprocess_convert_color_format_RGB_BGR) {
auto f = create_simple_function(element::f32, Shape{5, 30, 20, 3});
auto p = PrePostProcessor(f);
p.output().model().set_layout("NHWC").set_color_format(ColorFormat::RGB);
p.output().postprocess().convert_color(ColorFormat::BGR);
f = p.build();

EXPECT_EQ(f->get_results().size(), 1);
EXPECT_EQ(f->get_result()->get_output_partial_shape(0), (PartialShape{5, 30, 20, 3}));
}

TEST(pre_post_process, postprocess_convert_color_format_RGB_BGR_dynamic_batch) {
auto f = create_simple_function(element::f32, PartialShape{-1, 30, 20, 3});
auto p = PrePostProcessor(f);
p.output().model().set_layout("NHWC").set_color_format(ColorFormat::RGB);
p.output().postprocess().convert_color(ColorFormat::BGR);
f = p.build();

EXPECT_EQ(f->get_results().size(), 1);
EXPECT_EQ(f->get_result()->get_output_partial_shape(0), (PartialShape{-1, 30, 20, 3}));
}

TEST(pre_post_process, postprocess_convert_color_format_RGB_BGR_dynamic_shape) {
auto f = create_simple_function(element::f32, PartialShape{-1, -1, 20, 3});
auto p = PrePostProcessor(f);
p.output().model().set_layout("NHWC").set_color_format(ColorFormat::RGB);
p.output().postprocess().convert_color(ColorFormat::BGR);
f = p.build();

EXPECT_EQ(f->get_results().size(), 1);
EXPECT_EQ(f->get_result()->get_output_partial_shape(0), (PartialShape{-1, -1, 20, 3}));
}

TEST(pre_post_process, postprocess_convert_color_format_RGB_RGB) {
auto f = create_simple_function(element::f32, Shape{5, 30, 20, 3});
auto p = PrePostProcessor(f);
p.output().model().set_layout("NHWC").set_color_format(ColorFormat::RGB);
p.output().postprocess().convert_color(ColorFormat::RGB);
f = p.build();

EXPECT_EQ(f->get_results().size(), 1);
EXPECT_EQ(f->get_result()->get_output_partial_shape(0), (PartialShape{5, 30, 20, 3}));
}

TEST(pre_post_process, postprocess_convert_color_format_BGR_BGR) {
auto f = create_simple_function(element::f32, Shape{5, 30, 20, 3});
auto p = PrePostProcessor(f);
p.output().model().set_layout("NHWC").set_color_format(ColorFormat::BGR);
p.output().postprocess().convert_color(ColorFormat::BGR);
f = p.build();

EXPECT_EQ(f->get_results().size(), 1);
EXPECT_EQ(f->get_result()->get_output_partial_shape(0), (PartialShape{5, 30, 20, 3}));
}

TEST(pre_post_process, postprocess_convert_color_format_unsupported) {
auto f = create_simple_function(element::f32, Shape{5, 30, 20, 3});

EXPECT_THROW(auto p = PrePostProcessor(f); p.output().model().set_layout("NHWC").set_color_format(ColorFormat::RGB);
p.output().postprocess().convert_color(ColorFormat::GRAY);
f = p.build(), ov::AssertFailure);

EXPECT_THROW(auto p = PrePostProcessor(f); p.output().model().set_layout("NHWC").set_color_format(ColorFormat::RGB);
p.output().postprocess().convert_color(ColorFormat::UNDEFINED);
f = p.build(), ov::AssertFailure);
EXPECT_THROW(auto p = PrePostProcessor(f); p.output().model().set_color_format(ColorFormat::UNDEFINED);
p.output().postprocess().convert_color(ColorFormat::BGR);
f = p.build(), ov::AssertFailure);
}

// Postprocessing - other

TEST(pre_post_process, postprocess_preserve_rt_info) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,10 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Combine(::testing::ValuesIn(ov::builder::preprocess::generic_preprocess_functions()),
::testing::Values(ov::test::utils::DEVICE_CPU)),
PrePostProcessTest::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(
smoke_PostProcess,
PostProcessTest,
::testing::Combine(::testing::ValuesIn(ov::builder::preprocess::generic_postprocess_functions()),
::testing::Values(ov::test::utils::DEVICE_CPU)),
PostProcessTest::getTestCaseName);
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_PrePostProcess_GPU,
::testing::Values(ov::test::utils::DEVICE_GPU)),
PrePostProcessTest::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(
smoke_PostProcess,
PostProcessTest,
::testing::Combine(::testing::ValuesIn(ov::builder::preprocess::generic_postprocess_functions()),
::testing::Values(ov::test::utils::DEVICE_GPU)),
PostProcessTest::getTestCaseName);

} // namespace
Original file line number Diff line number Diff line change
Expand Up @@ -1159,6 +1159,38 @@ static RefPreprocessParams post_convert_layout_by_dims_multi() {
return res;
}

static RefPreprocessParams post_convert_color_rgb_to_bgr() {
RefPreprocessParams res("post_convert_color_rgb_to_bgr");
res.function = []() {
auto f = create_simple_function(element::f32, Shape{2, 1, 1, 3});
auto p = PrePostProcessor(f);
p.output().model().set_layout("NHWC").set_color_format(ColorFormat::RGB);
p.output().postprocess().convert_color(ColorFormat::BGR);
p.build();
return f;
};

res.inputs.emplace_back(Shape{2, 3, 1, 1}, element::f32, std::vector<float>{1, 2, 3, 4, 5, 6});
res.expected.emplace_back(Shape{2, 3, 1, 1}, element::f32, std::vector<float>{3, 2, 1, 6, 5, 4});
return res;
}

static RefPreprocessParams post_convert_color_bgr_to_rgb() {
RefPreprocessParams res("post_convert_color_bgr_to_rgb");
res.function = []() {
auto f = create_simple_function(element::f32, Shape{2, 1, 1, 3});
auto p = PrePostProcessor(f);
p.output().model().set_layout("NHWC").set_color_format(ColorFormat::BGR);
p.output().postprocess().convert_color(ColorFormat::RGB);
p.build();
return f;
};

res.inputs.emplace_back(Shape{2, 3, 1, 1}, element::f32, std::vector<float>{1, 2, 3, 4, 5, 6});
res.expected.emplace_back(Shape{2, 3, 1, 1}, element::f32, std::vector<float>{3, 2, 1, 6, 5, 4});
return res;
}

static RefPreprocessParams pre_and_post_processing() {
RefPreprocessParams res("pre_and_post_processing");
res.function = []() {
Expand Down Expand Up @@ -1382,6 +1414,8 @@ std::vector<RefPreprocessParams> allPreprocessTests() {
postprocess_2_inputs_basic(),
post_convert_layout_by_dims(),
post_convert_layout_by_dims_multi(),
post_convert_color_rgb_to_bgr(),
post_convert_color_bgr_to_rgb(),
pre_and_post_processing(),
rgb_to_bgr(),
bgr_to_rgb(),
Expand Down
Loading
Loading