Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
[BUG] finding leaf modules (#2241)
Browse files Browse the repository at this point in the history
  • Loading branch information
QuanluZhang authored Mar 27, 2020
1 parent 5c8cb25 commit 6e62990
Showing 1 changed file with 28 additions and 31 deletions.
59 changes: 28 additions & 31 deletions src/sdk/pynni/nni/compression/speedup/torch/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,42 +229,39 @@ def _extract_leaf_modules(self, graph):
list
a list of scope name of all the leaf modules
"""
pieces = [] # each element is a dict
class SNode:
def __init__(self, name):
self.sname = name
self.childs = {}

root = None
for node in graph.nodes():
scope_name = node.scopeName()
if scope_name == '':
continue
segs = scope_name.split('/')
segs_len = len(segs)
# increase the length of `pieces` if not enough
for _ in range(segs_len - len(pieces)):
pieces.append({})
# process internal segments of the scope name
# 'L' means leaf segment
# 'I' means internal segment
# internal segment can replace leaf segment at the same position of `pieces`
for i, seg in enumerate(segs[:-1]):
seg_name_dict = pieces[i]
if seg in seg_name_dict:
if seg_name_dict[seg][0] == 'L':
seg_name_dict[seg] = ('I', node)
else:
seg_name_dict[seg] = ('I', node)
# process the leaf segment of the scope name
last_segs_dict = pieces[len(segs) - 1]
if not segs[-1] in last_segs_dict:
last_segs_dict[segs[-1]] = ('L', node)
# traverse `pieces` to obtain all the leaf modules which are labeled with 'L'
leaf_modules = []
for piece in pieces:
for _, value in piece.items():
if value[0] == 'L':
assert value[1].scopeName() not in leaf_modules
# if this is a leaf module, the last segment of its scope name
# must be in pattern `xxx[xxx]`
if value[1].scopeName()[-1] == ']':
leaf_modules.append(value[1].scopeName())
return leaf_modules
if root is None:
root = SNode(segs[0])
curr = root
for seg in segs[1:]:
if not seg in curr.childs:
curr.childs[seg] = SNode(seg)
curr = curr.childs[seg]

leaf_nodes = []
def traverse_tree(node, scope_name):
if scope_name == '':
sn = node.sname
else:
sn = scope_name + '/' + node.sname
if not node.childs:
if node.sname[-1] == ']':
leaf_nodes.append(sn)
else:
for key in node.childs:
traverse_tree(node.childs[key], sn)
traverse_tree(root, '')
return leaf_nodes

def _build_graph(self):
"""
Expand Down

0 comments on commit 6e62990

Please sign in to comment.