Skip to content

Commit

Permalink
back port fix
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed Feb 17, 2025
1 parent 3f57265 commit 22568aa
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ class InsertTailLoop : public Pass {
public:
OPENVINO_RTTI("InsertTailLoop", "Pass")
bool run(LinearIR& linear_ir) override;
static LinearIR::container copy_loop(const LinearIR& linear_ir, const size_t loop_id);
static LinearIR::constExprIt insert_copy_loop(LinearIR& linear_ir,
const size_t loop_id,
const LinearIR::constExprIt& insert_pos);

static constexpr size_t existing_subtensor_value = SIZE_MAX;
static void propagate_updated_subtensor_through_loop(const LinearIR& linear_ir,
Expand Down
9 changes: 9 additions & 0 deletions src/common/snippets/src/lowered/loop_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,15 @@ size_t LinearIR::LoopManager::replace_with_new_loop(const LinearIR& linear_ir,
for (auto expr_it = loop_begin_pos; expr_it != loop_end_pos; ++expr_it) {
replace_loop_id(*expr_it, old_id, loop_id);
}

const auto old_loop_info = this->get_loop_info(old_id);
const auto old_loop_begin_pos = linear_ir.find(old_loop_info->get_entry_points().front().expr_port->get_expr());
const auto old_loop_end_pos = linear_ir.find(old_loop_info->get_exit_points().back().expr_port->get_expr());
// If new bounds are equal to old loop bounds, this means that old Loop is removed totally from LIR
// In this case old loop info must be completely removed from loop manager
if (loop_begin_pos == old_loop_begin_pos && loop_end_pos == old_loop_end_pos) {
this->remove_loop_info(old_id);
}
return loop_id;
}

Expand Down
25 changes: 12 additions & 13 deletions src/common/snippets/src/lowered/pass/insert_tail_loop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,14 @@ void InsertTailLoop::propagate_updated_subtensor_through_loop(const LinearIR& li
(*expr_it)->updateShapes();
}

LinearIR::container InsertTailLoop::copy_loop(const LinearIR& linear_ir, const size_t loop_id) {
LinearIR::constExprIt InsertTailLoop::insert_copy_loop(LinearIR& linear_ir, const size_t loop_id, const LinearIR::constExprIt& insert_pos) {
const auto& loop_manager = linear_ir.get_loop_manager();
LinearIR::constExprIt loop_begin_pos, loop_end_pos;
loop_manager->get_loop_bounds(linear_ir, loop_id, loop_begin_pos, loop_end_pos, true);
ExressionMap expression_map;
const auto& loop_copy_range = LinearIR::deep_copy_range(loop_begin_pos, std::next(loop_end_pos), expression_map);
const auto new_loop_begin_pos = linear_ir.insert(insert_pos, loop_copy_range.begin(), loop_copy_range.end());
const auto new_loop_end_pos = insert_pos;

const auto original_loop_info = loop_manager->get_loop_info(loop_id);
std::vector<LinearIR::LoopManager::LoopPort> new_entry_points, new_exit_points;
Expand All @@ -156,11 +158,9 @@ LinearIR::container InsertTailLoop::copy_loop(const LinearIR& linear_ir, const s
loop_manager->update_loops_port(outer_loop_ids, expr->get_output_port(i), {expr->get_output_port(i), new_expr->get_output_port(i)}, false);
}

const auto new_loop_begin_pos = loop_copy_range.begin();
const auto new_loop_end_pos = loop_copy_range.end();
const auto new_id = loop_manager->replace_with_new_loop(linear_ir,
std::next(new_loop_begin_pos),
std::prev(new_loop_end_pos),
new_loop_begin_pos,
new_loop_end_pos,
original_loop_info->get_work_amount(),
original_loop_info->get_increment(),
new_entry_points,
Expand All @@ -169,7 +169,7 @@ LinearIR::container InsertTailLoop::copy_loop(const LinearIR& linear_ir, const s
const auto loop_end = ov::as_type_ptr<op::LoopEnd>(std::prev(new_loop_end_pos)->get()->get_node());
OPENVINO_ASSERT(loop_end, "Cloned Loop does not contain LoopEnd op at the expected place.");
loop_end->set_id(new_id);
return loop_copy_range;
return new_loop_begin_pos;
}

void InsertTailLoop::create_tail_loop(LinearIR& linear_ir,
Expand All @@ -186,17 +186,16 @@ void InsertTailLoop::create_tail_loop(LinearIR& linear_ir,
auto original_loop_info = loop_manager->get_loop_info(original_loop_id);
auto tail_loop_info = original_loop_info;
if (need_vector_loop) {
const auto new_loop_range = copy_loop(linear_ir, original_loop_id);
const auto new_loop_end = ov::as_type_ptr<op::LoopEnd>(std::prev(new_loop_range.end())->get()->get_node());
OPENVINO_ASSERT(new_loop_end, "Cloned Loop does not contain LoopEnd op at the expected place.");
tail_loop_info = original_loop_info;
original_loop_info = loop_manager->get_loop_info(new_loop_end->get_id());

// Note: new loop body is inserted before the original loop
// So new loop becomes a main vector loop, the original loop becomes tail loop
// This is done in such way to have original ops from the main body at the end:
// this allows us to conveniently interact with outer loops in further passes
linear_ir.insert(begin, new_loop_range.begin(), new_loop_range.end());
const auto new_loop_begin_pos = insert_copy_loop(linear_ir, original_loop_id, begin);
const auto new_loop_begin = ov::as_type_ptr<op::LoopBegin>(new_loop_begin_pos->get()->get_node());
OPENVINO_ASSERT(new_loop_begin, "Cloned Loop does not contain LoopBegin op at the expected place.");
const auto new_loop_end = new_loop_begin->get_loop_end();
tail_loop_info = original_loop_info;
original_loop_info = loop_manager->get_loop_info(new_loop_end->get_id());

const auto new_vector_loop_wa = original_loop_info->get_work_amount() - tail_size;
original_loop_info->set_work_amount(new_vector_loop_wa);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,16 +160,18 @@ bool BrgemmBlocking::run(LinearIR& linear_ir) {
if (work_amount <= increment)
return false;

auto new_loop_range = snippets::lowered::pass::InsertTailLoop::copy_loop(linear_ir, loop_id);
const auto firt_iter_loop_end = ov::as_type_ptr<snippets::op::LoopEnd>(std::prev(new_loop_range.end())->get()->get_node());
const auto loop_begin_it = linear_ir.find(linear_ir.get_expr_by_node(loop_end->get_loop_begin()));
const auto new_loop_begin_pos =
snippets::lowered::pass::InsertTailLoop::insert_copy_loop(linear_ir, loop_id, loop_begin_it);
const auto new_loop_begin =
ov::as_type_ptr<snippets::op::LoopBegin>(new_loop_begin_pos->get()->get_node());
OPENVINO_ASSERT(new_loop_begin, "Cloned Loop does not contain LoopBegin op at the expected place.");
const auto firt_iter_loop_end = new_loop_begin->get_loop_end();
auto first_iter_loop_info = loop_manager->get_loop_info(firt_iter_loop_end->get_id());
firt_iter_loop_end->set_work_amount(increment);
first_iter_loop_info->set_work_amount(increment);
firt_iter_loop_end->set_finalization_offsets(std::vector<int64_t>(loop_end->get_finalization_offsets().size(), 0));

const auto loop_begin_it = linear_ir.find(linear_ir.get_expr_by_node(loop_end->get_loop_begin()));
linear_ir.insert(loop_begin_it, new_loop_range.begin(), new_loop_range.end());

const auto new_work_amount = work_amount - increment;
loop_info->set_work_amount(new_work_amount);
loop_end->set_work_amount(new_work_amount);
Expand Down

0 comments on commit 22568aa

Please sign in to comment.