Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
Fix bug in graph converter from Graph_gen (#4092)
Browse files Browse the repository at this point in the history
  • Loading branch information
JiahangXu authored Sep 6, 2021
1 parent bf18854 commit 5f0a7c9
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 5 deletions.
19 changes: 15 additions & 4 deletions nni/retiarii/converter/graph_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,17 @@ def handle_single_node(node):
for node in sm_graph.nodes():
handle_single_node(node)

if node_index == {}:
# here is an example where the ir_graph is empty, e.g.,
# graph(%self : __torch__.torchmodels.googlenet.GoogLeNet,
#       %x.1 : Tensor): return (%x.1)
# add a noop_identity node to handle this situation
self.global_seq += 1
ni_node = ir_graph.add_node(build_full_name(module_name, 'noop_identity', self.global_seq), 'noop_identity')
ir_graph.add_edge(head=(ir_graph.input_node, 0), tail=(ni_node, None))
ir_graph.add_edge(head=(ni_node, None), tail=(ir_graph.output_node, None))
for _output in sm_graph.outputs():
node_index[_output.node()] = ni_node
return node_index

def merge_aten_slices(self, ir_graph):
Expand Down Expand Up @@ -575,9 +586,7 @@ def _convert_module(self, script_module, module, module_name, ir_model):
# also has LayerChoice or InputChoice or ValueChoice
original_type_name = script_module.original_name
m_attrs = None
if original_type_name in MODULE_EXCEPT_LIST:
pass # do nothing
elif original_type_name == OpTypeName.LayerChoice:
if original_type_name == OpTypeName.LayerChoice:
graph = Graph(ir_model, -100, module_name, _internal=True) # graph_id is not used now
candidate_name_list = []
for cand_name in module.names:
Expand All @@ -599,7 +608,9 @@ def _convert_module(self, script_module, module, module_name, ir_model):
m_attrs = self._handle_valuechoice(module)
elif original_type_name == OpTypeName.Placeholder:
m_attrs = get_init_parameters_or_fail(module)
elif module.__class__.__module__.startswith('torch.nn') and original_type_name in torch.nn.__dict__:
elif module.__class__.__module__.startswith('torch.nn') and \
original_type_name in torch.nn.__dict__ and \
original_type_name not in MODULE_EXCEPT_LIST:
# this is a basic module from pytorch, no need to parse its graph
m_attrs = get_init_parameters_or_fail(module)
elif getattr(module, '_stop_parsing', False):
Expand Down
1 change: 1 addition & 0 deletions nni/retiarii/converter/op_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from enum import Enum

# special cases that cannot be treated as basic modules from pytorch
MODULE_EXCEPT_LIST = ['Sequential']


Expand Down
2 changes: 1 addition & 1 deletion nni/retiarii/nn/pytorch/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

Module = nn.Module

Sequential = transparent_serialize(nn.Sequential)
Sequential = nn.Sequential
ModuleList = transparent_serialize(nn.ModuleList)

Identity = basic_unit(nn.Identity)
Expand Down

0 comments on commit 5f0a7c9

Please sign in to comment.