Skip to content

Commit

Permalink
Support nested ModuleList and fix an issue in list append (microsoft#…
Browse files Browse the repository at this point in the history
  • Loading branch information
QuanluZhang authored May 22, 2021
1 parent ac14b9e commit 9444e27
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 13 deletions.
41 changes: 28 additions & 13 deletions nni/retiarii/converter/graph_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@ def __init__(self):
self.global_graph_id = 0

def _add_edge_handle_source_node(self, _input, graph_inputs, ir_graph, output_remap, node_index):
if _input in graph_inputs:
idx = graph_inputs.index(_input)
src_node = ir_graph.input_node
src_node_idx = idx
elif _input in output_remap:
if _input in output_remap:
assert output_remap[_input].kind() == 'aten::append'
predecessor_node = output_remap[_input]
assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
src_node_idx = None
src_node = node_index[predecessor_node]
assert isinstance(src_node, Node)
elif _input in graph_inputs:
idx = graph_inputs.index(_input)
src_node = ir_graph.input_node
src_node_idx = idx
else:
predecessor_node = _input.node()
assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
Expand Down Expand Up @@ -315,16 +315,31 @@ def handle_function_callmethod(node):
if submodule.inputsAt(0).type().name() == 'ModuleList':
# handle ModuleList
predecessor = submodule.inputsAt(0).node()
module_name_space = [submodule_name]
while predecessor.inputsAt(0).debugName() != 'self':
# This handles nested ModuleList: walk up the chain of prim::GetAttr nodes until the one whose input is `self`. Below is an example:
# %3 : __torch__.torch.nn.modules.container.___torch_mangle_0.ModuleList = prim::GetAttr[name="ops"](%self)
# %5 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="0"](%3)
# %7 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="1"](%3)
# %9 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="2"](%3)
# %11 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="3"](%3)
# %14 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="0"](%5)
# %16 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="1"](%5)
# %state.2 : Tensor = prim::CallMethod[name="forward"](%14, %x.1) # modulelist.py:18:24
# %state.4 : Tensor = prim::CallMethod[name="forward"](%16, %state.2) # modulelist.py:18:24
assert predecessor.kind() == 'prim::GetAttr'
module_name_space.append(predecessor.s('name'))
predecessor = predecessor.inputsAt(0).node()
assert predecessor.kind() == 'prim::GetAttr'
assert predecessor.hasAttribute('name')
assert predecessor.inputsAt(0).debugName() == 'self'
predecessor_name = predecessor.s('name')
# TODO: exchange submodule_name and predecessor_name
submodule_full_name = build_full_name(module_name, [submodule_name, predecessor_name])
predecessor_obj = getattr(module, predecessor_name)
submodule_obj = getattr(predecessor_obj, submodule_name)
subgraph, sub_m_attrs = self.convert_module(script_module._modules[predecessor_name]._modules[submodule_name],
submodule_obj, submodule_full_name, ir_model)
module_name_space.append(predecessor.s('name'))
submodule_full_name = build_full_name(module_name, list(reversed(module_name_space)))
submodule_obj = module
script_submodule = script_module
for each_name in list(reversed(module_name_space)):
submodule_obj = getattr(submodule_obj, each_name)
script_submodule = script_submodule._modules[each_name]
subgraph, sub_m_attrs = self.convert_module(script_submodule, submodule_obj, submodule_full_name, ir_model)
else:
raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str()))

Expand Down
91 changes: 91 additions & 0 deletions test/ut/retiarii/test_convert_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import os
import sys
import unittest
from typing import (Dict)

import numpy as np
import torch
import torch.nn.functional as F
import torchvision

import nni.retiarii.nn.pytorch as nn
from nni.retiarii import serialize
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script


class TestModels(unittest.TestCase):
    """Round-trip tests: PyTorch module -> graph IR -> generated PyTorch code.

    Each test builds a small model, converts it with ``convert_to_graph``,
    regenerates code with ``model_to_pytorch_script``, and compares the
    regenerated model's output against the original model's output.
    """

    @staticmethod
    def _match_state_dict(current_values, expected_format):
        """Map source tensors onto the converted model's state-dict keys by shape.

        For every key of ``expected_format`` (in iteration order) the first
        remaining tensor in ``current_values`` with an identical shape is
        assigned to that key and removed from ``current_values``.

        Args:
            current_values: list of tensors from the original model; it is
                mutated in place (matched tensors are popped).
            expected_format: mapping of parameter name to a tensor whose shape
                is the one to look for.

        Returns:
            dict of parameter name to matched tensor. Keys with no tensor of
            a matching shape are omitted.
        """
        result = {}
        for k, v in expected_format.items():
            for idx, cv in enumerate(current_values):
                if cv.shape == v.shape:
                    result[k] = cv
                    # Safe to pop while enumerating: we leave the loop right away.
                    current_values.pop(idx)
                    break
        return result

    def _assert_outputs_match(self, converted_output, expected_output, atol=1e-4):
        """Assert that two model outputs hold the same values.

        Tensors are compared by shape and then by value (``torch.allclose``
        with ``atol`` for floating-point tensors, ``torch.equal`` otherwise).
        Lists and tuples are compared element by element, recursively.
        Anything else falls back to ``assertEqual``.
        """
        if torch.is_tensor(converted_output) and torch.is_tensor(expected_output):
            self.assertEqual(converted_output.shape, expected_output.shape)
            if expected_output.is_floating_point():
                matched = torch.allclose(converted_output, expected_output, atol=atol)
            else:
                matched = torch.equal(converted_output, expected_output)
            self.assertTrue(matched, 'converted model output differs from the original model output')
        elif isinstance(expected_output, (list, tuple)):
            self.assertIsInstance(converted_output, (list, tuple))
            self.assertEqual(len(converted_output), len(expected_output))
            for converted_item, expected_item in zip(converted_output, expected_output):
                self._assert_outputs_match(converted_item, expected_item, atol=atol)
        else:
            self.assertEqual(converted_output, expected_output)

    def run_test(self, model, input, check_value=True):
        """Convert ``model`` to IR and back to code, then compare outputs.

        Args:
            model: the original module; it is scripted with ``torch.jit.script``.
            input: tuple of positional arguments for ``forward``.
            check_value: when true, assert that the regenerated model produces
                the same output as ``model``.

        Returns:
            The model instantiated from the generated code.
        """
        # Local import keeps this block self-contained (the file already
        # imports locally inside tests).
        import copy

        script_module = torch.jit.script(model)
        model_ir = convert_to_graph(script_module, model)
        model_code = model_to_pytorch_script(model_ir)
        print(model_code)

        # NOTE: exec() runs code generated from our own model, never external input.
        exec_vars = {}
        exec(model_code + '\n\nconverted_model = _model()', exec_vars)
        converted_model = exec_vars['converted_model']
        converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
                                                      dict(converted_model.state_dict()))
        converted_model.load_state_dict(converted_state_dict)
        with torch.no_grad():
            # Each model gets its own copy of the inputs: a forward() that
            # mutates its arguments (e.g. appends to an input list) must not
            # change what the other model receives.
            expected_output = model.eval()(*copy.deepcopy(input))
            converted_output = converted_model.eval()(*copy.deepcopy(input))
        if check_value:
            # The previous check called torch.eq() and discarded the result,
            # so values were never compared; assert on them explicitly.
            self._assert_outputs_match(converted_output, expected_output)
        return converted_model

    def test_nested_modulelist(self):
        """A ModuleList of ModuleLists iterated with nested loops in forward()."""
        class Net(nn.Module):
            def __init__(self, num_nodes, num_ops_per_node):
                super().__init__()
                self.ops = nn.ModuleList()
                self.num_nodes = num_nodes
                self.num_ops_per_node = num_ops_per_node
                for _ in range(num_nodes):
                    self.ops.append(nn.ModuleList([nn.Linear(16, 16) for __ in range(num_ops_per_node)]))

            def forward(self, x):
                state = x
                for ops in self.ops:
                    for op in ops:
                        state = op(state)
                return state

        model = Net(4, 2)
        x = torch.rand((16, 16), dtype=torch.float)
        self.run_test(model, (x, ))

    def test_append_input_tensor(self):
        """forward() appends to a list that is itself a graph input."""
        from typing import List

        class Net(nn.Module):
            def __init__(self, num_nodes):
                super().__init__()
                self.ops = nn.ModuleList()
                self.num_nodes = num_nodes
                for _ in range(num_nodes):
                    self.ops.append(nn.Linear(16, 16))

            def forward(self, x: List[torch.Tensor]):
                state = x
                for ops in self.ops:
                    state.append(ops(state[-1]))
                return state[-1]

        model = Net(4)
        x = torch.rand((1, 16), dtype=torch.float)
        self.run_test(model, ([x], ))

0 comments on commit 9444e27

Please sign in to comment.