From a090abbc92aad7569a691bee33370ac831f471dd Mon Sep 17 00:00:00 2001 From: Nikita Malinin Date: Thu, 17 Feb 2022 18:17:07 +0300 Subject: [PATCH] Update remove_converts pass with shape inference (#10474) --- tools/pot/openvino/tools/pot/graph/passes.py | 23 ++++++++++---------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/tools/pot/openvino/tools/pot/graph/passes.py b/tools/pot/openvino/tools/pot/graph/passes.py index d7011f58c4b301..2b2002250bd037 100644 --- a/tools/pot/openvino/tools/pot/graph/passes.py +++ b/tools/pot/openvino/tools/pot/graph/passes.py @@ -938,17 +938,18 @@ def find_shape_subgraph_endpoints(out_ports: List[Port], visited: set = None) -> def remove_converts(graph: Graph): - for op in graph.get_op_nodes(type='Convert'): - source_op = op.in_port(0).get_source().node - if source_op.type == 'Const' and source_op.data_type == np.float16: - # Get access to data node after Convert operation and set Insert_Convert_operation_after - # to restore Convert operation later - op.out_node(0)['Insert_Convert_operation_after'] = True - # Mark Const and Convert operation to fold them - source_op['need_shape_inference'] = True - op.out_node(0)['old_rt_info'] = op['rt_info'] - op['stop_value_propagation'] = False - op['need_shape_inference'] = True + for op in graph.get_op_nodes(): + if op.type == 'Convert': + source_op = op.in_port(0).get_source().node + if source_op.type == 'Const' and source_op.data_type == np.float16: + # Get access to data node after Convert operation and set Insert_Convert_operation_after + # to restore Convert operation later + op.out_node(0)['Insert_Convert_operation_after'] = True + # Mark Const and Convert operation to fold them + source_op['need_shape_inference'] = True + op.out_node(0)['old_rt_info'] = op['rt_info'] + op['stop_value_propagation'] = False + op['need_shape_inference'] = True graph.clean_up()