-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[feat]: index_fill_ frontend pytorch op (#27420)
### Details: Add support for aten::index_fill_ for pytorch models ### Tickets: [#23326](#23326) --------- Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
- Loading branch information
Showing
3 changed files
with
154 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
// Copyright (C) 2018-2023 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include <memory> | ||
|
||
#include "openvino/core/node.hpp" | ||
#include "openvino/core/node_output.hpp" | ||
#include "openvino/core/type/element_type.hpp" | ||
#include "openvino/frontend/pytorch/node_context.hpp" | ||
#include "openvino/op/broadcast.hpp" | ||
#include "openvino/op/constant.hpp" | ||
#include "openvino/op/convert_like.hpp" | ||
#include "openvino/op/reshape.hpp" | ||
#include "openvino/op/scatter_elements_update.hpp" | ||
#include "openvino/op/shape_of.hpp" | ||
#include "openvino/op/slice.hpp" | ||
#include "utils.hpp" | ||
|
||
using namespace ov::op; | ||
|
||
namespace ov { | ||
namespace frontend { | ||
namespace pytorch { | ||
namespace op { | ||
|
||
OutputVector translate_index_fill_(const NodeContext& context) { | ||
// aten::index_fill_(self, dim, index, value) --> Tensor | ||
num_inputs_check(context, 4, 4); | ||
auto input = context.get_input(0); | ||
auto dim = context.get_input(1); | ||
auto index = context.get_input(2); | ||
auto value = context.get_input(3); | ||
|
||
auto const_1_vec = v0::Constant::create(element::i32, Shape{1}, {1}); | ||
|
||
auto tensor_rank = std::get<1>(get_shape_rank(context, input, false)); | ||
auto tensor_rank_correct_type = context.mark_node(std::make_shared<v1::ConvertLike>(tensor_rank, dim)); | ||
auto dim_vec = normalize_axis(context, dim, tensor_rank_correct_type); | ||
|
||
// scalar to vec | ||
auto value_vec = context.mark_node(std::make_shared<v1::Reshape>(value, const_1_vec, false)); | ||
|
||
auto input_shape = std::get<0>(get_shape_rank(context, input, false)); | ||
|
||
auto index_shape = std::get<0>(get_shape_rank(context, index, false)); | ||
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0})); | ||
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1})); | ||
auto index_len = context.mark_node(std::make_shared<v8::Slice>(index_shape, const_0, const_1, const_1)); | ||
|
||
// [A, B, ..., T, ..., K] --> [A, B, ..., len(index), ..., K] | ||
auto target_shape = std::make_shared<v12::ScatterElementsUpdate>(input_shape, | ||
dim_vec, | ||
index_len, | ||
v0::Constant::create(element::i32, Shape{}, {0})); | ||
|
||
// broadcast && index fill | ||
auto broadcasted_value = context.mark_node(std::make_shared<v1::Broadcast>(value_vec, target_shape, dim_vec)); | ||
auto broadcasted_index = context.mark_node(std::make_shared<v1::Broadcast>(index, target_shape, dim_vec)); | ||
auto result = context.mark_node( | ||
std::make_shared<v12::ScatterElementsUpdate>(input, broadcasted_index, broadcasted_value, dim)); | ||
|
||
return {result}; | ||
}; | ||
|
||
} // namespace op | ||
} // namespace pytorch | ||
} // namespace frontend | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# Copyright (C) 2018-2023 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
import numpy as np | ||
import pytest | ||
import torch | ||
|
||
from pytorch_layer_test_class import PytorchLayerTest | ||
|
||
|
||
class TestIndexFill(PytorchLayerTest): | ||
def _prepare_input(self): | ||
return (self.input_tensor,) | ||
|
||
def create_model(self, dim, index, values): | ||
class aten_index_fill_(torch.nn.Module): | ||
def __init__(self, dim, index, values): | ||
super().__init__() | ||
self.dim = dim | ||
self.index = index | ||
self.values = values | ||
|
||
def forward(self, input_tensor): | ||
input_tensor.index_fill_(self.dim, self.index, self.values) | ||
return input_tensor | ||
|
||
ref_net = None | ||
|
||
return aten_index_fill_(dim, index, values), ref_net, "aten::index_fill_" | ||
|
||
@pytest.mark.parametrize( | ||
"input_data", | ||
( | ||
{ | ||
"input_shape": [10], | ||
"dim": 0, | ||
"input_value": 5.6, | ||
"index": [5, 6, 7] | ||
}, | ||
{ | ||
"input_shape": [3, 3], | ||
"dim": 0, | ||
"input_value": 10.1, | ||
"index": [1, 0] | ||
}, | ||
{ | ||
"input_shape": [4, 3, 5], | ||
"dim": 1, | ||
"input_value": 1234.5, | ||
"index": [2, 0] | ||
}, | ||
{ | ||
"input_shape": [5, 6, 7, 8], | ||
"dim": -2, | ||
"input_value": 0.1234, | ||
"index": [6, 4, 2, 0] | ||
}, | ||
{ | ||
"input_shape": [5, 6, 7, 8], | ||
"dim": -3, | ||
"input_value": -4321234.5678765, | ||
"index": [5, 4, 3, 1] | ||
}, | ||
{ | ||
"input_shape": [5, 6, 7, 8], | ||
"dim": 3, | ||
"input_value": -1234.54321, | ||
"index": [6, 4, 7, 2, 1] | ||
}, | ||
), | ||
) | ||
@pytest.mark.nightly | ||
@pytest.mark.precommit | ||
def test_index_fill_single_index(self, ie_device, precision, ir_version, input_data): | ||
self.input_tensor = np.random.randn(*input_data["input_shape"]).astype(np.float32) | ||
values = torch.tensor(np.float32(input_data["input_value"])) | ||
dim = input_data["dim"] | ||
shape = self.input_tensor.shape | ||
max_idx = shape[dim] | ||
n_select = np.random.randint(1, max_idx + 1) | ||
index = torch.from_numpy(np.random.choice(np.arange(0, max_idx), n_select, replace=False)).to(torch.long) | ||
self._test(*self.create_model(dim, index, values), ie_device, precision, ir_version) |