[feat]: index_fill_ frontend pytorch op (#27420)
### Details:

Add support for the aten::index_fill_ in-place operation for PyTorch models by decomposing it into Broadcast and ScatterElementsUpdate nodes.
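For reference, the operation fills the entries of `self` selected by `index` along dimension `dim` with a scalar `value`. A minimal PyTorch example of the semantics (illustrative only, not part of this change):

```python
import torch

x = torch.zeros(3, 4)
x.index_fill_(1, torch.tensor([0, 2]), 7.0)  # fill columns 0 and 2 with 7.0, in place
# x is now:
# tensor([[7., 0., 7., 0.],
#         [7., 0., 7., 0.],
#         [7., 0., 7., 0.]])
```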

### Tickets:

#23326

---------

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
cocoshe and mvafin authored Nov 26, 2024
1 parent 7568d5e commit 17e17b2
Showing 3 changed files with 154 additions and 0 deletions.
69 changes: 69 additions & 0 deletions src/frontends/pytorch/src/op/index_fill_.cpp
@@ -0,0 +1,69 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <memory>

#include "openvino/core/node.hpp"
#include "openvino/core/node_output.hpp"
#include "openvino/core/type/element_type.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/scatter_elements_update.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "utils.hpp"

using namespace ov::op;

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_index_fill_(const NodeContext& context) {
    // aten::index_fill_(self, dim, index, value) --> Tensor
    num_inputs_check(context, 4, 4);
    auto input = context.get_input(0);
    auto dim = context.get_input(1);
    auto index = context.get_input(2);
    auto value = context.get_input(3);

    auto const_1_vec = v0::Constant::create(element::i32, Shape{1}, {1});

    auto tensor_rank = std::get<1>(get_shape_rank(context, input, false));
    auto tensor_rank_correct_type = context.mark_node(std::make_shared<v1::ConvertLike>(tensor_rank, dim));
    auto dim_vec = normalize_axis(context, dim, tensor_rank_correct_type);

    // reshape the scalar fill value to a one-element vector so it can be broadcast
    auto value_vec = context.mark_node(std::make_shared<v1::Reshape>(value, const_1_vec, false));

    auto input_shape = std::get<0>(get_shape_rank(context, input, false));

    auto index_shape = std::get<0>(get_shape_rank(context, index, false));
    auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
    auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
    auto index_len = context.mark_node(std::make_shared<v8::Slice>(index_shape, const_0, const_1, const_1));

    // target shape: [A, B, ..., T, ..., K] --> [A, B, ..., len(index), ..., K]
    auto target_shape = std::make_shared<v12::ScatterElementsUpdate>(input_shape,
                                                                     dim_vec,
                                                                     index_len,
                                                                     v0::Constant::create(element::i32, Shape{}, {0}));

    // broadcast the fill value and the index to the target shape, then scatter along dim
    auto broadcasted_value = context.mark_node(std::make_shared<v1::Broadcast>(value_vec, target_shape, dim_vec));
    auto broadcasted_index = context.mark_node(std::make_shared<v1::Broadcast>(index, target_shape, dim_vec));
    auto result = context.mark_node(
        std::make_shared<v12::ScatterElementsUpdate>(input, broadcasted_index, broadcasted_value, dim));

    return {result};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
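The converter reshapes the scalar value, computes a target shape equal to the input shape with the indexed dimension resized to len(index), broadcasts the value and the index to that shape, and applies ScatterElementsUpdate along dim. A NumPy sketch of the same decomposition (illustrative only; index_fill_reference is a hypothetical helper, not part of the diff):

```python
import numpy as np

def index_fill_reference(data, dim, index, value):
    """Mimic the Broadcast + ScatterElementsUpdate decomposition above."""
    dim = dim % data.ndim
    index = np.asarray(index)

    # Target shape: data's shape with the size along `dim` replaced by len(index).
    target_shape = list(data.shape)
    target_shape[dim] = index.size

    # Broadcast the scalar fill value over the target shape.
    updates = np.full(target_shape, value, dtype=data.dtype)

    # Broadcast the 1-D index over the target shape, placing it along `dim`
    # (the explicit axes mapping used by the Broadcast nodes above).
    idx_shape = [1] * data.ndim
    idx_shape[dim] = index.size
    indices = np.broadcast_to(index.reshape(idx_shape), target_shape)

    # ScatterElementsUpdate: write each update element into the position
    # selected by `indices` along `dim`.
    result = data.copy()
    np.put_along_axis(result, indices, updates, axis=dim)
    return result

# Example: matches torch.Tensor.index_fill_ on the same inputs.
data = np.arange(12, dtype=np.float32).reshape(3, 4)
print(index_fill_reference(data, 1, [0, 2], 7.0))
```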
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
@@ -113,6 +113,7 @@ OP_CONVERTER(translate_im2col);
OP_CONVERTER(translate_index);
OP_CONVERTER(translate_index_add);
OP_CONVERTER(translate_index_copy_);
OP_CONVERTER(translate_index_fill_);
OP_CONVERTER(translate_index_put_);
OP_CONVERTER(translate_index_select);
OP_CONVERTER(translate_instance_norm);
@@ -496,6 +497,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
        // aten::imag - Supported in limited set of patterns
        // aten::index - Supported in limited set of patterns
        {"aten::index_copy_", op::inplace_op<op::translate_index_copy_>},
        {"aten::index_fill_", op::inplace_op<op::translate_index_fill_>},
        {"aten::index_put_", op::inplace_op<op::translate_index_put_>},
        {"aten::index_add", op::translate_index_add},
        {"aten::index_select", op::translate_index_select},
83 changes: 83 additions & 0 deletions tests/layer_tests/pytorch_tests/test_index_fill_.py
@@ -0,0 +1,83 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


import numpy as np
import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest


class TestIndexFill(PytorchLayerTest):
    def _prepare_input(self):
        return (self.input_tensor,)

    def create_model(self, dim, index, values):
        class aten_index_fill_(torch.nn.Module):
            def __init__(self, dim, index, values):
                super().__init__()
                self.dim = dim
                self.index = index
                self.values = values

            def forward(self, input_tensor):
                input_tensor.index_fill_(self.dim, self.index, self.values)
                return input_tensor

        ref_net = None

        return aten_index_fill_(dim, index, values), ref_net, "aten::index_fill_"

    @pytest.mark.parametrize(
        "input_data",
        (
            {
                "input_shape": [10],
                "dim": 0,
                "input_value": 5.6,
                "index": [5, 6, 7]
            },
            {
                "input_shape": [3, 3],
                "dim": 0,
                "input_value": 10.1,
                "index": [1, 0]
            },
            {
                "input_shape": [4, 3, 5],
                "dim": 1,
                "input_value": 1234.5,
                "index": [2, 0]
            },
            {
                "input_shape": [5, 6, 7, 8],
                "dim": -2,
                "input_value": 0.1234,
                "index": [6, 4, 2, 0]
            },
            {
                "input_shape": [5, 6, 7, 8],
                "dim": -3,
                "input_value": -4321234.5678765,
                "index": [5, 4, 3, 1]
            },
            {
                "input_shape": [5, 6, 7, 8],
                "dim": 3,
                "input_value": -1234.54321,
                "index": [6, 4, 7, 2, 1]
            },
        ),
    )
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_index_fill_single_index(self, ie_device, precision, ir_version, input_data):
        self.input_tensor = np.random.randn(*input_data["input_shape"]).astype(np.float32)
        values = torch.tensor(np.float32(input_data["input_value"]))
        dim = input_data["dim"]
        # Draw a random set of unique, in-range positions along `dim`;
        # the "index" entries in the parametrization above are illustrative only.
        shape = self.input_tensor.shape
        max_idx = shape[dim]
        n_select = np.random.randint(1, max_idx + 1)
        index = torch.from_numpy(np.random.choice(np.arange(0, max_idx), n_select, replace=False)).to(torch.long)
        self._test(*self.create_model(dim, index, values), ie_device, precision, ir_version)
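For context, a model that uses this op can be converted through the PyTorch frontend with the usual flow. A minimal sketch (assumes a recent openvino Python package where convert_model accepts torch modules with example_input; FillModel is a hypothetical module, not part of the diff):

```python
import numpy as np
import openvino as ov
import torch


class FillModel(torch.nn.Module):
    def forward(self, x):
        # in-place fill of rows 0 and 2 -> lowered to aten::index_fill_
        return x.index_fill_(0, torch.tensor([0, 2]), 5.0)


ov_model = ov.convert_model(FillModel(), example_input=torch.randn(4, 3))
compiled = ov.compile_model(ov_model, "CPU")

data = np.random.randn(4, 3).astype(np.float32)
expected = FillModel()(torch.from_numpy(data.copy())).numpy()
print(np.allclose(compiled([data])[0], expected))
```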
