Skip to content

Commit

Permalink
[Snippets] Fixed LoopManager::update_loop_ports (openvinotoolkit#27300)
Browse files Browse the repository at this point in the history
### Details:
- *To remind, `LoopPort` is expression port connected to another
expression port which is not in the same Loop. It's like entry (of exit)
point of the Loop. It means that some expression port cannot be port of
the Loop if all consumers (or sources) are from the same Loop. However,
the method `LoopManager::update_loop_ports` sometimes creates these
situation. This PR fixes this method. The screenshot below describes
this situation: red loop is inner loop and blue loops is outer loop.
However, some of output ports of this Loop is inside (green question
sign) - invalid situation which is fixed by these changes.*
<img width="233" alt="image"
src="https://github.com/user-attachments/assets/88bc6faf-edeb-49c0-b262-55922a884725">

 - *Added the corresponding checks to validate pass*
- *Remove parts in `init_is_incremented` which handle invalid case by
`is_incremented=false`.*

### Tickets:
 - *CVS-156299*
  • Loading branch information
a-sidorova authored Nov 6, 2024
1 parent 362ebe9 commit b6fe65f
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class InitLoops : public Pass {
bool run(LinearIR& linear_ir) override;

private:
static void update_compile_parameters(const UnifiedLoopInfoPtr& loop_info, size_t loop_id);
static void update_compile_parameters(const UnifiedLoopInfoPtr& loop_info);
};

} // namespace pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

#include "pass.hpp"

#include "snippets/lowered/loop_manager.hpp"


namespace ov {
namespace snippets {
namespace lowered {
Expand All @@ -27,6 +30,10 @@ class ValidateUnifiedLoops : public Pass {
OPENVINO_RTTI("ValidateUnifiedLoops", "Pass")
ValidateUnifiedLoops() = default;
bool run(LinearIR& linear_ir) override;

private:
static void validate_loop_infos(const LoopManagerPtr& loop_manager);
static void validate_loop_port_presence(const LinearIR& linear_ir);
};

} // namespace pass
Expand Down
6 changes: 6 additions & 0 deletions src/common/snippets/include/snippets/utils/loop_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ void update_data_pointer_shifts(const ov::snippets::lowered::UnifiedLoopInfoPtr&
* @brief Updates work amount and updates data pointer shifts of the provided "loop_info"
*/
void update_runtime_parameters(const ov::snippets::lowered::UnifiedLoopInfoPtr& loop_info);
/**
* @brief Check if the passed expression port should be port of the Loop with ID `loop_id`:
* the target expression port should be connected to an expression from another Loop (missed in the loop with ID `loop_id`),
*/
bool should_be_loop_port(const ov::snippets::lowered::ExpressionPort& port, size_t loop_id);

} // namespace utils
} // namespace snippets
} // namespace ov
52 changes: 34 additions & 18 deletions src/common/snippets/src/lowered/loop_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "openvino/core/graph_util.hpp"
#include "openvino/core/type.hpp"

#include "snippets/utils/loop_utils.hpp"
#include "snippets/itt.hpp"


Expand Down Expand Up @@ -349,30 +350,45 @@ void LoopManager::fuse_loop_ports(std::vector<LoopPort>& output_ports,
}

void LoopManager::update_loop_ports(const ExpressionPtr& expr) {
auto output_ports = expr->get_output_ports();
for (size_t i = 0; i < expr->get_input_count(); ++i) {
const auto& source = expr->get_input_port_connector(i)->get_source();
const auto common_outer_loop_ids = get_common_outer_loops(expr, source.get_expr());
// The source output port can have several consumers (including the current expr) that can be potential output ports
// So we should verify on the possible future output ports
size_t count_of_common_outer_loops = common_outer_loop_ids.size();
for (const auto& source_consumer : source.get_connected_ports()) {
if (source_consumer.get_expr() == expr)
auto update_ports = [&](const ov::snippets::lowered::ExpressionPort& connected_port) {
const auto is_output = connected_port.get_type() == ExpressionPort::Output;
// Iterate through all Loops of the connected expression
for (const auto& loop_id : connected_port.get_expr()->get_loop_ids()) {
const auto& loop_info = get_loop_info(loop_id);
// If the connected expression port is not Loop port - nothing to update
// If the target expression is not from the same Loop - nothing to update
if (!loop_info->is_loop_port(connected_port) || !is_loop_id_found(expr, loop_id))
continue;
count_of_common_outer_loops = std::min(count_of_common_outer_loops, get_common_outer_loops(source.get_expr(), source_consumer.get_expr()).size());
}
replace_loop_ports({common_outer_loop_ids.cbegin(), common_outer_loop_ids.cbegin() + count_of_common_outer_loops}, source, output_ports);
// Save previous port
if (count_of_common_outer_loops != common_outer_loop_ids.size()) {
output_ports.insert(output_ports.begin(), source);
replace_loop_ports({common_outer_loop_ids.cbegin() + count_of_common_outer_loops, common_outer_loop_ids.cend()}, source, output_ports);

std::vector<ExpressionPort> new_ports;
// Check if some ports of target expression must be Loop port
const auto target_expr_ports = is_output ? expr->get_output_ports() : expr->get_input_ports();
for (const auto& port : target_expr_ports) {
if (utils::should_be_loop_port(port, loop_id))
new_ports.push_back(port);
}
// Leave the connected expression port as Loop port if needed
if (utils::should_be_loop_port(connected_port, loop_id))
new_ports.push_back(connected_port);

// Nothing should be updated
if (new_ports.size() == 1 && new_ports.front() == connected_port)
continue;

loop_info->replace_with_new_ports(connected_port, new_ports);
}
};

// The case with parent loops: source -> target expr
for (size_t i = 0; i < expr->get_input_count(); ++i) {
update_ports(expr->get_input_port_connector(i)->get_source());
}
const auto input_ports = expr->get_input_ports();

// The case with child loops: target expr -> consumers
for (size_t i = 0; i < expr->get_output_count(); ++i) {
const auto& consumers = expr->get_output_port_connector(i)->get_consumers();
for (const auto& consumer : consumers) {
replace_loop_ports(get_common_outer_loops(expr, consumer.get_expr()), consumer, input_ports);
update_ports(consumer);
}
}
}
Expand Down
50 changes: 6 additions & 44 deletions src/common/snippets/src/lowered/pass/init_loops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,46 +16,10 @@ namespace lowered {
namespace pass {

namespace {
inline void init_is_incremented(LoopPort& port, size_t loop_id) {
inline void init_is_incremented(LoopPort& port) {
const auto& expr = port.expr_port->get_expr();
const auto& expr_loops = expr->get_loop_ids();
if (!std::dynamic_pointer_cast<modifier::MemoryAccess>(expr->get_node())) {
port.is_incremented = false;
} else if (expr_loops.back() != loop_id) {
// Note: LoopPort connected to Buffer between two loops should not be incremented in the outermost loop
// Consider the example below:
// Store; Loop ids [0,1,2,3]
// Buffer; Loop ids [0,1]
// Load; Loop ids [0,1,4,5]
// Store is output port of Loop-1, but it should be incremented only in Loop-2 and Loop-3. Similar with Load.
auto is_ignored = [=](const ExpressionPtr& target_expr) {
if (ov::is_type<BufferExpression>(target_expr)) {
const auto& target_loops = target_expr->get_loop_ids();
const auto i_max = std::min(expr_loops.size(), target_loops.size());
for (size_t i = 0; i < i_max && expr_loops[i] == target_loops[i]; i++) {
if (target_loops[i] == loop_id)
return true;
}
}
return false;
};
if (port.expr_port->get_type() == ExpressionPort::Type::Output) {
const auto& out_connector = expr->get_output_port_connector(port.expr_port->get_index());
for (const auto& consumer : out_connector->get_consumers()) {
if (is_ignored(consumer.get_expr())) {
port.is_incremented = false;
return;
}
}
} else if (port.expr_port->get_type() == ExpressionPort::Type::Input) {
const auto& in_connector = expr->get_input_port_connector(port.expr_port->get_index());
if (is_ignored(in_connector->get_source().get_expr())) {
port.is_incremented = false;
return;
}
} else {
OPENVINO_THROW("Unexpected LoopPort type");
}
}
}

Expand All @@ -71,11 +35,11 @@ inline int64_t get_data_size(const LoopPort& loop_port) {
}
} // namespace

void InitLoops::update_compile_parameters(const UnifiedLoopInfoPtr& loop_info, size_t loop_id) {
void InitLoops::update_compile_parameters(const UnifiedLoopInfoPtr& loop_info) {
OPENVINO_ASSERT(loop_info != nullptr, "UnifiedLoopInfo is nullptr, nothing to update");
loop_info->iterate_through_infos(
[loop_id](LoopPort& loop_port, UnifiedLoopInfo::LoopPortDesc& ptr_shifts_params) {
init_is_incremented(loop_port, loop_id);
[](LoopPort& loop_port, UnifiedLoopInfo::LoopPortDesc& ptr_shifts_params) {
init_is_incremented(loop_port);
ptr_shifts_params.data_size = get_data_size(loop_port);
});
}
Expand All @@ -85,12 +49,10 @@ bool InitLoops::run(LinearIR& linear_ir) {
if (linear_ir.empty())
return false;

const auto& loop_manager = linear_ir.get_loop_manager();
const auto& loops = loop_manager->get_map();
const auto& loops = linear_ir.get_loop_manager()->get_map();
for (const auto& loop : loops) {
const auto& loop_id = loop.first;
const auto& loop_info = ov::as_type_ptr<UnifiedLoopInfo>(loop.second);
update_compile_parameters(loop_info, loop_id);
update_compile_parameters(loop_info);
ov::snippets::utils::update_runtime_parameters(loop_info);
}

Expand Down
55 changes: 44 additions & 11 deletions src/common/snippets/src/lowered/pass/validate_unified_loops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,15 @@
#include "snippets/itt.hpp"
#include "snippets/lowered/linear_ir.hpp"
#include "snippets/lowered/loop_manager.hpp"
#include "snippets/utils/loop_utils.hpp"
#include "snippets/utils/utils.hpp"

namespace ov {
namespace snippets {
namespace lowered {
namespace pass {

bool ValidateUnifiedLoops::run(LinearIR& linear_ir) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::ValidateUnifiedLoops")
if (linear_ir.empty())
return false;

const auto& loop_manager = linear_ir.get_loop_manager();
const auto& loops = loop_manager->get_map();

void ValidateUnifiedLoops::validate_loop_infos(const LoopManagerPtr& loop_manager) {
// Already validated vectors of Loop IDs
std::set<std::vector<size_t>> validated_nested_loops;
auto is_already_verified = [&validated_nested_loops](const std::vector<size_t>& ids) {
Expand Down Expand Up @@ -66,10 +60,9 @@ bool ValidateUnifiedLoops::run(LinearIR& linear_ir) {
validated_nested_loops.insert(loop_ids);
};

for (const auto& pair : loops) {
for (const auto& pair : loop_manager->get_map()) {
const auto& loop_info = ov::as_type_ptr<UnifiedLoopInfo>(pair.second);
OPENVINO_ASSERT(loop_info,
"ValidateUnifiedLoops expects only UnifiedLoopInfo in LoopManager");
OPENVINO_ASSERT(loop_info, "ValidateUnifiedLoops expects only UnifiedLoopInfo in LoopManager");
loop_info->iterate_through_ports(validate_loop_port);

// Validate that iteration dimnsion is broadcastable
Expand All @@ -88,6 +81,46 @@ bool ValidateUnifiedLoops::run(LinearIR& linear_ir) {
OPENVINO_ASSERT(unique_dimensions.size() <= 1,
"Loop ports have incompatible dimensions, by which the loop iterates");
}
}

void ValidateUnifiedLoops::validate_loop_port_presence(const LinearIR& linear_ir) {
auto validate_loop_port = [](const ExpressionPort& expr_port, const LoopInfoPtr& loop_info, size_t loop_id) {
if (utils::should_be_loop_port(expr_port, loop_id)) {
OPENVINO_ASSERT(loop_info->is_loop_port(expr_port),
"Expression port with idx ", expr_port.get_index(), " with node ",
expr_port.get_expr()->get_node()->get_friendly_name(), " is not Loop port but should be!");
} else {
OPENVINO_ASSERT(!loop_info->is_loop_port(expr_port),
"Expression port with idx ", expr_port.get_index(), " with node ",
expr_port.get_expr()->get_node()->get_friendly_name(), " is Loop port but should not be!");
}
};

const auto& loop_manager = linear_ir.get_loop_manager();
for (const auto& expr : linear_ir) {
const auto& op = expr->get_node();
if (ov::is_type<ov::snippets::op::LoopBase>(op))
continue;

for (const auto& loop_id : expr->get_loop_ids()) {
const auto& loop_info = loop_manager->get_loop_info(loop_id);

for (size_t i = 0; i < expr->get_input_count(); ++i)
validate_loop_port(expr->get_input_port(i), loop_info, loop_id);

for (size_t i = 0; i < expr->get_output_count(); ++i)
validate_loop_port(expr->get_output_port(i), loop_info, loop_id);
}
}
}

bool ValidateUnifiedLoops::run(LinearIR& linear_ir) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::ValidateUnifiedLoops")
if (linear_ir.empty())
return false;

validate_loop_infos(linear_ir.get_loop_manager());
validate_loop_port_presence(linear_ir);

return true;
}
Expand Down
9 changes: 9 additions & 0 deletions src/common/snippets/src/utils/loop_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,15 @@ void update_runtime_parameters(const UnifiedLoopInfoPtr& loop_info) {
update_data_pointer_shifts(loop_info);
}

bool should_be_loop_port(const ov::snippets::lowered::ExpressionPort& port, size_t loop_id) {
const auto& connected_ports = port.get_connected_ports();
return std::any_of(connected_ports.cbegin(), connected_ports.cend(),
[&](const ExpressionPort& connected_port) {
const auto& loops = connected_port.get_expr()->get_loop_ids();
return std::find(loops.cbegin(), loops.cend(), loop_id) == loops.cend();
});
}

} // namespace utils
} // namespace snippets
} // namespace ov

0 comments on commit b6fe65f

Please sign in to comment.