Skip to content

Commit

Permalink
[PT FE] Support setting types/shapes for nested structures (#25064)
Browse files Browse the repository at this point in the history
### Details:
- *Some PyTorch models may have `dicts`, `lists`, `tuples` as inputs. We
unpack them on `normalize` step in PT FE, but this means that those
unpacked inputs do not exist before conversion and `convert_model` can't
set shapes/types for them. This PR adds the ability for `convert_model` to
set shapes/types for any input, addressed by index or name.*

### Tickets:
 - *CVS-142939*
  • Loading branch information
mvafin authored Jul 1, 2024
1 parent 1d7daae commit a3d2b6a
Show file tree
Hide file tree
Showing 23 changed files with 537 additions and 117 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(self, pt_module, fx_gm=None, nodes=None, mark_node_callback=None, i
self.input_shapes = input_shapes

self._input_signature = []
self._example_input = None

if issubclass(type(pt_module), torch.fx.graph_module.GraphModule):

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def _get_scripted_model(self, pt_module, example_inputs=None, skip_freeze=False)
if isinstance(pt_module, torch.nn.Module):
pt_module.eval()
input_signature = None
input_parameters = None
if isinstance(pt_module, torch.nn.Module) and not isinstance(pt_module, (torch.jit._trace.TopLevelTracedModule, torch.jit._script.RecursiveScriptModule)):
# input params is dictionary contains input names and their signature values (type hints and default values if any)
input_params = inspect.signature(pt_module.forward if hasattr(
Expand Down Expand Up @@ -150,8 +151,10 @@ def _get_scripted_model(self, pt_module, example_inputs=None, skip_freeze=False)
scripted, preserved_attrs=preserved_attrs)
else:
f_model = scripted
self._example_input = input_parameters["example_inputs"] if input_parameters else None
else:
f_model = pt_module
self._example_input = example_inputs

self._input_signature = input_signature
return f_model
Expand Down
12 changes: 12 additions & 0 deletions src/bindings/python/src/pyopenvino/frontend/input_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,18 @@ void regclass_frontend_InputModel(py::module m) {
:rtype: openvino.frontend.Place
)");

im.def("get_place_by_input_index",
&ov::frontend::InputModel::get_place_by_input_index,
py::arg("input_idx"),
R"(
Returns a tensor place by an input index.
:param input_idx: Index of model input.
:type input_idx: int
:return: Tensor place corresponding to specified input index or nullptr.
:rtype: openvino.frontend.Place
)");

im.def("get_place_by_operation_name",
&ov::frontend::InputModel::get_place_by_operation_name,
py::arg("operation_name"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ class FRONTEND_API InputModel {
/// \return Tensor place corresponding to specified tensor name or nullptr if not exists
virtual Place::Ptr get_place_by_tensor_name(const std::string& tensor_name) const;

/// \brief Returns a tensor place by an input index.
/// \param input_idx Index of model input
/// \return Tensor place corresponding to specified input index or nullptr
virtual Place::Ptr get_place_by_input_index(size_t input_idx) const;

/// \brief Returns an operation place by an operation name following framework
/// conventions, or nullptr if an operation with this name doesn't exist.
/// \param operation_name Name of operation
Expand Down
7 changes: 7 additions & 0 deletions src/frontends/common/src/input_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ Place::Ptr InputModel::get_place_by_tensor_name(const std::string& tensor_name)
FRONTEND_RETURN_STATEMENT("get_place_by_tensor_name", m_actual->get_place_by_tensor_name(tensor_name))
}

// Returns a tensor place for the given input index by delegating to the
// frontend-specific implementation (m_actual); returns an empty Place::Ptr
// when no concrete frontend model is attached.
Place::Ptr InputModel::get_place_by_input_index(size_t input_idx) const {
if (!m_actual) {
// No underlying frontend model to query.
return {};
}
// FRONTEND_RETURN_STATEMENT presumably wraps the delegated call with the
// common frontend error-translation machinery — see the macro definition.
FRONTEND_RETURN_STATEMENT("get_place_by_input_index", m_actual->get_place_by_input_index(input_idx))
}

Place::Ptr InputModel::get_place_by_operation_name(const std::string& operation_name) const {
if (!m_actual) {
return {};
Expand Down
4 changes: 4 additions & 0 deletions src/frontends/onnx/frontend/src/input_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ ov::frontend::Place::Ptr InputModel::get_place_by_tensor_name(const std::string&
return nullptr;
}

// Input lookup by index is not supported by the ONNX frontend: this always
// raises a "not implemented" frontend error via FRONT_END_NOT_IMPLEMENTED.
ov::frontend::Place::Ptr InputModel::get_place_by_input_index(size_t input_idx) const {
FRONT_END_NOT_IMPLEMENTED(get_place_by_input_index);
}

ov::frontend::Place::Ptr InputModel::get_place_by_operation_name(const std::string& operation_name) const {
if (m_editor->is_correct_and_unambiguous_node(operation_name)) {
const auto node_index = m_editor->get_node_index(EditorNode{operation_name});
Expand Down
1 change: 1 addition & 0 deletions src/frontends/onnx/frontend/src/input_model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class InputModel : public ov::frontend::InputModel {
std::vector<ov::frontend::Place::Ptr> get_inputs() const override;
std::vector<ov::frontend::Place::Ptr> get_outputs() const override;
ov::frontend::Place::Ptr get_place_by_tensor_name(const std::string& tensor_name) const override;
ov::frontend::Place::Ptr get_place_by_input_index(size_t input_idx) const override;
ov::frontend::Place::Ptr get_place_by_operation_name(const std::string& operation_name) const override;
ov::frontend::Place::Ptr get_place_by_operation_name_and_input_port(const std::string& operation_name,
int input_port_index) override;
Expand Down
4 changes: 4 additions & 0 deletions src/frontends/paddle/src/input_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,10 @@ Place::Ptr InputModel::get_place_by_tensor_name(const std::string& tensorName) c
return _impl->get_place_by_tensor_name(tensorName);
}

// Input lookup by index is not supported by the Paddle frontend: this always
// raises a "not implemented" frontend error via FRONT_END_NOT_IMPLEMENTED.
Place::Ptr InputModel::get_place_by_input_index(size_t input_idx) const {
FRONT_END_NOT_IMPLEMENTED(get_place_by_input_index);
}

void InputModel::override_all_outputs(const std::vector<Place::Ptr>& outputs) {
_impl->override_all_outputs(outputs);
}
Expand Down
1 change: 1 addition & 0 deletions src/frontends/paddle/src/input_model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class InputModel : public ov::frontend::InputModel {
std::vector<Place::Ptr> get_inputs() const override;
std::vector<Place::Ptr> get_outputs() const override;
Place::Ptr get_place_by_tensor_name(const std::string& tensorName) const override;
Place::Ptr get_place_by_input_index(size_t input_idx) const override;
void override_all_outputs(const std::vector<Place::Ptr>& outputs) override;
void override_all_inputs(const std::vector<Place::Ptr>& inputs) override;
void extract_subgraph(const std::vector<Place::Ptr>& inputs, const std::vector<Place::Ptr>& outputs) override;
Expand Down
73 changes: 71 additions & 2 deletions src/frontends/pytorch/src/frontend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "openvino/op/util/multi_subgraph_base.hpp"
#include "openvino/pass/constant_folding.hpp"
#include "openvino/util/log.hpp"
#include "place.hpp"
#include "pt_framework_node.hpp"
#include "transformations/common_optimizations/push_constant_to_subgraph.hpp"
#include "transformations/common_optimizations/remove_multi_subgraph_op_dangling_params.hpp"
Expand Down Expand Up @@ -113,15 +114,37 @@ std::string pack_detailed_failure_report(const std::map<std::string, std::string
error_msg << '\n' << failed_ops_short.str();
return error_msg.str();
}

// Applies the element type and partial shape requested on `place` to `param`.
// The type is only overridden when the place carries a static type; the shape
// is merged with the parameter's current shape and must be compatible.
// `ov_model` is used solely to report the parameter index on failure.
// Throws (via FRONT_END_GENERAL_CHECK) on a wrong place type or shape mismatch.
void update_parameter_info(std::shared_ptr<ov::op::v0::Parameter>& param,
                           const Place::Ptr& place,
                           const std::shared_ptr<Model>& ov_model) {  // by const& — avoid refcount copy
    const auto& pt_place = std::dynamic_pointer_cast<pytorch::Place>(place);
    FRONT_END_GENERAL_CHECK(pt_place, "Wrong place type.");
    if (pt_place->get_element_type().is_static())
        param->set_element_type(pt_place->get_element_type());
    auto pshape = param->get_partial_shape();
    // NOTE(review): merge_into may modify `pshape` before reporting failure, so
    // the "Original shape" printed below can be partially merged — confirm this
    // is acceptable for the error message.
    FRONT_END_GENERAL_CHECK(PartialShape::merge_into(pshape, pt_place->get_partial_shape()),
                            "Incompatible shape was requested for input #",
                            ov_model->get_parameter_index(param),
                            ", ",
                            param,
                            "\nOriginal shape: ",
                            pshape,
                            " update shape: ",
                            pt_place->get_partial_shape());
    param->set_partial_shape(pshape);
}
} // namespace

// Default constructor: the PyTorch frontend requires no explicit initialization.
FrontEnd::FrontEnd() {}

std::shared_ptr<Model> FrontEnd::convert(const ov::frontend::InputModel::Ptr& model) const {
FRONT_END_GENERAL_CHECK(std::dynamic_pointer_cast<pytorch::InputModel>(model), "Invalid input model");
auto pt_model = std::dynamic_pointer_cast<pytorch::InputModel>(model);
FRONT_END_GENERAL_CHECK(pt_model, "Invalid input model");
std::map<std::string, CreatorFunction> supported_ops = get_supported_ops(model);
std::shared_ptr<Model> converted_model;
{
pt_model->flush_places();
TranslateSession translate_session(model, supported_ops, m_telemetry);
converted_model = translate_session.get_converted_model();
}
Expand All @@ -141,6 +164,50 @@ std::shared_ptr<Model> FrontEnd::convert(const ov::frontend::InputModel::Ptr& mo
}
bool is_conversion_successful = unconverted_ops.size() == 0 && norm_err.empty();
FRONT_END_OP_CONVERSION_CHECK(is_conversion_successful, pack_detailed_failure_report(unconverted_ops, norm_err));

if (pt_model->m_requested_places.size() != 0) {
// Fake places stand for inputs that did not exist before conversion but had
// types/shapes requested for them. Resolve those requests here; if such an
// input still does not exist after conversion, an exception is raised.
auto parameters = converted_model->get_parameters();
std::set<std::string> input_names;
for (const auto& param : parameters)
for (const auto& name : param->get_output_tensor(0).get_names())
input_names.insert(name);
const auto& inputs = pt_model->get_inputs();
for (size_t i = 0; i < inputs.size(); i++) {
auto place = inputs[i];
if (place->get_names().size() != 0 && input_names.find(place->get_names().at(0)) != input_names.end()) {
auto input = converted_model->input(place->get_names().at(0));
auto param = std::dynamic_pointer_cast<ov::op::v0::Parameter>(input.get_node_shared_ptr());
FRONT_END_GENERAL_CHECK(param, "Input is not a Parameter.");
update_parameter_info(param, place, converted_model);
} else {
FRONT_END_OP_CONVERSION_CHECK(i < parameters.size(),
"Type/shape was set to non-existent input. Converted model:\n",
converted_model);
update_parameter_info(parameters[i], inputs[i], converted_model);
}
}
for (const auto& fplace : pt_model->m_requested_places) {
const auto& pt_place = std::dynamic_pointer_cast<pytorch::Place>(fplace);
FRONT_END_GENERAL_CHECK(pt_place, "Wrong place type.");
if (fplace->get_names().size() == 0) {
// Fake place was set by input index
auto idx = pt_place->get_input_index();
FRONT_END_GENERAL_CHECK(idx < parameters.size(),
"Type/shape was set to non-existent input. Converted model:\n",
converted_model);
update_parameter_info(parameters[idx], fplace, converted_model);
} else {
auto input = converted_model->input(fplace->get_names().at(0));
auto param = std::dynamic_pointer_cast<ov::op::v0::Parameter>(input.get_node_shared_ptr());
FRONT_END_GENERAL_CHECK(param, "Input is not a Parameter.");
update_parameter_info(param, fplace, converted_model);
}
}
converted_model->validate_nodes_and_infer_types();
}

return converted_model;
}

Expand All @@ -149,10 +216,12 @@ void FrontEnd::convert(const std::shared_ptr<Model>& partiallyConverted) const {
}

std::shared_ptr<Model> FrontEnd::convert_partially(const ov::frontend::InputModel::Ptr& model) const {
FRONT_END_GENERAL_CHECK(std::dynamic_pointer_cast<pytorch::InputModel>(model), "Invalid input model");
auto pt_model = std::dynamic_pointer_cast<pytorch::InputModel>(model);
FRONT_END_GENERAL_CHECK(pt_model, "Invalid input model");
std::map<std::string, CreatorFunction> supported_ops = get_supported_ops(model);
std::shared_ptr<Model> partial_model;
{
pt_model->flush_places();
TranslateSession translate_session(model, supported_ops, m_telemetry);
partial_model = translate_session.get_converted_model();
}
Expand Down
56 changes: 55 additions & 1 deletion src/frontends/pytorch/src/input_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,41 @@ std::vector<ov::frontend::Place::Ptr> InputModel::get_outputs() const {
}

// Returns the place registered under `tensor_name`, or — when the name is not
// (yet) known — a "fake" place that can still be used to request shape/type
// changes for an input that will only exist after conversion (e.g. an element
// of a dict/list/tuple input unpacked during the normalize step).
// An empty name never identifies a tensor and yields an empty Place::Ptr.
Place::Ptr InputModel::get_place_by_tensor_name(const std::string& tensor_name) const {
    if (tensor_name.empty())
        return {};
    auto place_it = m_name_to_place.find(tensor_name);
    if (place_it != m_name_to_place.end()) {
        return place_it->second;
    }
    // Unknown name: return a fake place keyed by the name (index 0 is unused).
    // (The trailing `return nullptr;` in the previous version was unreachable
    // once both branches above returned, so it has been removed.)
    return std::make_shared<pytorch::Place>(*this, tensor_name, 0);
}

// Returns a "fake" place keyed by input index; it can be used to request
// shape/type changes for an input that will only exist after conversion.
Place::Ptr InputModel::get_place_by_input_index(size_t input_idx) const {
    return std::make_shared<pytorch::Place>(*this, "", input_idx);
}

// Records the requested partial shape on the given input place.
// For fake places (inputs that do not exist yet) the place is first
// deduplicated against previously requested places, so that shape and type
// requests addressing the same future input accumulate on one place object.
// Throws (via FRONT_END_GENERAL_CHECK) if the place is not a PyTorch input place.
void InputModel::set_partial_shape(const Place::Ptr& place, const ov::PartialShape& shape) {
    FRONT_END_GENERAL_CHECK(place && place->is_input(),
                            "Provided place is invalid, only inputs are supported for setting shape.");
    auto pytorch_place = std::dynamic_pointer_cast<pytorch::Place>(place);
    FRONT_END_GENERAL_CHECK(pytorch_place, "Only place produced by PyTorch Frontend is supported");
    if (pytorch_place->m_is_fake) {
        bool is_new = true;
        for (auto& p : m_requested_places) {
            if (p->is_equal(pytorch_place)) {
                is_new = false;
                // Redirect the update onto the already-registered equivalent place.
                pytorch_place = std::dynamic_pointer_cast<pytorch::Place>(p);
                FRONT_END_GENERAL_CHECK(pytorch_place, "Only place produced by PyTorch Frontend is supported");
                // Entries are unique (insertion below is guarded by this same
                // equality check), so the first match is the only match.
                break;
            }
        }
        if (is_new)
            m_requested_places.push_back(place);
    }
    pytorch_place->m_pshape = shape;
}

Expand All @@ -77,6 +100,18 @@ void InputModel::set_element_type(const Place::Ptr& place, const ov::element::Ty
"Provided place is invalid, only inputs are supported for setting element type.");
auto pytorch_place = std::dynamic_pointer_cast<pytorch::Place>(place);
FRONT_END_GENERAL_CHECK(pytorch_place, "Only place produced by PyTorch Frontend is supported");
if (pytorch_place->m_is_fake) {
bool is_new = true;
for (auto& p : m_requested_places) {
if (p->is_equal(pytorch_place)) {
is_new = false;
pytorch_place = std::dynamic_pointer_cast<pytorch::Place>(p);
FRONT_END_GENERAL_CHECK(pytorch_place, "Only place produced by PyTorch Frontend is supported");
}
}
if (is_new)
m_requested_places.push_back(place);
}
pytorch_place->m_type = type;
}

Expand Down Expand Up @@ -171,6 +206,25 @@ std::shared_ptr<TorchDecoder> InputModel::get_decoder() const {
return m_model_decoder;
}

// Tries to transfer shape/type requests made on index-addressed fake places
// onto the real input places, then clears the request list.
// Aborts early (keeping m_requested_places intact) if any request cannot be
// applied: more requests than inputs, a name-addressed request, an
// out-of-range index, or a target input whose type or rank is already set.
// NOTE(review): requests applied before an aborting one are not rolled back;
// since m_requested_places is kept, they are re-applied after conversion by
// FrontEnd::convert — confirm the double application is intended.
void InputModel::flush_places() {
auto input_places = get_inputs();
if (m_requested_places.size() > input_places.size())
return;
for (auto place : m_requested_places) {
auto pt_place = std::dynamic_pointer_cast<pytorch::Place>(place);
// Only anonymous (index-addressed) fake places within range can be flushed.
if (!pt_place || pt_place->get_input_index() >= input_places.size() || pt_place->m_names.size() != 0)
return;
auto to_update_place = std::dynamic_pointer_cast<pytorch::Place>(input_places[pt_place->get_input_index()]);
// Do not overwrite an input that already carries a static type or rank.
if (!to_update_place || !to_update_place->m_type.is_dynamic() || !to_update_place->m_pshape.rank().is_dynamic())
return;
if (pt_place->m_type.is_static())
to_update_place->m_type = pt_place->m_type;
if (pt_place->m_pshape.rank().is_static())
to_update_place->m_pshape = pt_place->m_pshape;
}
// Everything was applied to real inputs — drop the pending requests.
m_requested_places = {};
}

} // namespace pytorch
} // namespace frontend
} // namespace ov
6 changes: 6 additions & 0 deletions src/frontends/pytorch/src/input_model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ namespace ov {
namespace frontend {
namespace pytorch {

class FrontEnd;
class TranslateSession;
class Place;
class TorchDecoder;
Expand All @@ -23,13 +24,15 @@ struct PlaceDesc {

class InputModel : public ov::frontend::InputModel {
friend class ::ov::frontend::pytorch::TranslateSession;
friend class ::ov::frontend::pytorch::FrontEnd;

public:
explicit InputModel(const std::shared_ptr<TorchDecoder>& model_decoder);

std::vector<frontend::Place::Ptr> get_inputs() const override;
std::vector<frontend::Place::Ptr> get_outputs() const override;
frontend::Place::Ptr get_place_by_tensor_name(const std::string& tensor_name) const override;
frontend::Place::Ptr get_place_by_input_index(size_t input_idx) const override;
void set_partial_shape(const frontend::Place::Ptr& place, const ov::PartialShape& shape) override;
ov::PartialShape get_partial_shape(const frontend::Place::Ptr& place) const override;
void set_element_type(const frontend::Place::Ptr& place, const ov::element::Type& type) override;
Expand All @@ -39,12 +42,15 @@ class InputModel : public ov::frontend::InputModel {
void override_all_inputs(const std::vector<frontend::Place::Ptr>& inputs) override;
const std::string& decoder_type_name() const;
std::shared_ptr<TorchDecoder> get_decoder() const;
// update input places and erase requested places if possible
void flush_places();

private:
std::shared_ptr<TorchDecoder> m_model_decoder;
std::unordered_map<std::string, std::shared_ptr<frontend::Place>> m_name_to_place;
std::vector<std::shared_ptr<frontend::Place>> m_inputs;
std::vector<std::shared_ptr<frontend::Place>> m_outputs;
std::vector<std::shared_ptr<frontend::Place>> m_requested_places;
std::unordered_map<size_t, PlaceDesc> m_descriptors;
};

Expand Down
26 changes: 26 additions & 0 deletions src/frontends/pytorch/src/place.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Place::Place(const ov::frontend::InputModel& input_model, size_t tensor_index)
auto in_it = std::find(inputs.begin(), inputs.end(), tensor_index);
if (in_it != inputs.end()) {
m_is_input = true;
m_input_index = std::distance(inputs.begin(), in_it);
auto idx = std::distance(inputs.begin(), in_it);
const auto& signature_name = decoder->get_input_signature_name(idx);
m_names.push_back(signature_name);
Expand Down Expand Up @@ -55,6 +56,31 @@ Place::Place(const ov::frontend::InputModel& input_model, size_t tensor_index)
}
}

// Constructs a "fake" place: a placeholder for a model input that does not yet
// exist and will only appear after conversion (e.g. once nested
// dict/list/tuple inputs are unpacked). The place is identified either by
// `name` (when non-empty) or by `input_index`; its shape and type start fully
// dynamic until set via set_partial_shape/set_element_type.
Place::Place(const ov::frontend::InputModel& input_model, const std::string& name, size_t input_index)
: m_input_model(input_model),
m_tensor_index(0),  // fake places are not backed by a real tensor index
m_is_fake(true),
m_input_index(input_index),
m_pshape(PartialShape::dynamic()),
m_type(element::dynamic),
m_is_input(true) {
if (!name.empty())
m_names = {name};
}

// Two places are equal when they are the same object, or — for index-addressed
// fake places — when their input indices match. A named fake place is only
// ever equal to itself.
bool Place::is_equal(const Ptr& another) const {
    const auto other_pt = std::dynamic_pointer_cast<pytorch::Place>(another);
    if (other_pt && (m_is_fake || other_pt->m_is_fake)) {
        const bool self_named_fake = m_is_fake && !m_names.empty();
        const bool other_named_fake = other_pt->m_is_fake && !other_pt->m_names.empty();
        if (!self_named_fake && !other_named_fake)
            return m_input_index == other_pt->get_input_index();
    }
    // Fall back to pointer identity for real places, named fakes, and
    // non-PyTorch places.
    return this == another.get();
}

} // namespace pytorch
} // namespace frontend
} // namespace ov
Loading

0 comments on commit a3d2b6a

Please sign in to comment.