This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[Retiarii] refactor based on the new launch approach #3185

Merged: 80 commits, Dec 14, 2020
80ca92e
init commit of pytorch code converter
QuanluZhang Nov 1, 2020
d08d0a1
update
QuanluZhang Nov 3, 2020
fde8565
Merge branch 'dev-retiarii' of https://github.com/microsoft/nni into …
QuanluZhang Nov 3, 2020
a01ecfe
refactor graph ir
QuanluZhang Nov 4, 2020
2639080
update
QuanluZhang Nov 5, 2020
de9e801
Merge branch 'dev-retiarii' of https://github.com/microsoft/nni into …
QuanluZhang Nov 5, 2020
8874051
resolve merge conflict
QuanluZhang Nov 5, 2020
0c54c4e
update
QuanluZhang Nov 5, 2020
4b9f0d6
Merge branch 'dev-retiarii' of https://github.com/microsoft/nni into …
QuanluZhang Nov 5, 2020
48cf595
pass mnist example without code converter
QuanluZhang Nov 6, 2020
169e6b2
end to end
QuanluZhang Nov 10, 2020
a79c3be
fix bugs in pytorch code generation
QuanluZhang Nov 12, 2020
ae33a0a
fix code generating
QuanluZhang Nov 13, 2020
f2ad754
support apply mutators
QuanluZhang Nov 13, 2020
4640932
implement topo_sort
QuanluZhang Nov 13, 2020
c1d8ba9
generated the correct model
QuanluZhang Nov 14, 2020
bcc3cf1
end2end passed, strategy generates only one trial
QuanluZhang Nov 14, 2020
b1a3228
pass end2end, multiple trials
QuanluZhang Nov 15, 2020
7a642a1
update
QuanluZhang Nov 15, 2020
92bdd7a
merge input_names to input_node, passed test
QuanluZhang Nov 19, 2020
2729846
remove convert_name from code gen, add the logic in graph gen
QuanluZhang Nov 19, 2020
c0ed256
update
QuanluZhang Nov 19, 2020
d9e9260
Merge branch 'dev-retiarii' of https://github.com/microsoft/nni into …
QuanluZhang Nov 19, 2020
ec19bec
fix conflict
QuanluZhang Nov 19, 2020
1b557a5
refactor io node
QuanluZhang Nov 20, 2020
c8ffed8
first draft
liuzhe-lz Nov 20, 2020
f2dd604
Merge branch 'dev-retiarii' of https://github.com/microsoft/nni into …
QuanluZhang Nov 20, 2020
13de09e
init commit, support layerchoice, inputchoice
QuanluZhang Nov 20, 2020
4873cde
support layerchoice and inputchoice
QuanluZhang Nov 24, 2020
bfbec3d
second ver
liuzhe-lz Nov 25, 2020
e8648eb
refactor logging
liuzhe-lz Nov 26, 2020
bacd496
fix cluster metadata
liuzhe-lz Nov 26, 2020
4307ee8
clean up
liuzhe-lz Nov 26, 2020
b177b79
use foreground in example
liuzhe-lz Nov 26, 2020
e38278d
Merge branch 'master' into exp
liuzhe-lz Nov 26, 2020
97d370c
add missing file
liuzhe-lz Nov 26, 2020
8c55d21
fix pylint
liuzhe-lz Nov 27, 2020
7f96326
update ts timestamp to match python format
liuzhe-lz Nov 27, 2020
b052411
try to fix ts version difference
liuzhe-lz Nov 27, 2020
697d8e6
generate darts code
QuanluZhang Nov 29, 2020
2c14267
Merge pull request #6 from liuzhe-lz/exp
QuanluZhang Nov 29, 2020
849d033
new launching approach
QuanluZhang Dec 1, 2020
07658b3
minor
QuanluZhang Dec 1, 2020
2dc6c5d
Merge branch 'dev-retiarii' of https://github.com/microsoft/nni into …
QuanluZhang Dec 1, 2020
c67a6ea
support instantiated trainer
QuanluZhang Dec 1, 2020
5deec9c
remove comments
QuanluZhang Dec 1, 2020
0f0db1f
refactor strategy
QuanluZhang Dec 1, 2020
67a161b
refactor user code, support with statement
QuanluZhang Dec 2, 2020
ef2fe7e
minor
QuanluZhang Dec 2, 2020
563a9c0
new experiment config for NAS, support tpe strategy for NAS
QuanluZhang Dec 5, 2020
32b9bec
Merge branch 'dev-retiarii' of https://github.com/microsoft/nni into …
QuanluZhang Dec 5, 2020
923ae26
refactor of code converter
QuanluZhang Dec 6, 2020
6e079fb
update code gen to shorten variable name for improving readability
QuanluZhang Dec 6, 2020
dd313e3
handle module list
QuanluZhang Dec 6, 2020
1ae325d
merge aten::slice
QuanluZhang Dec 6, 2020
4435973
deal with aten::append differently, as it has no output
QuanluZhang Dec 7, 2020
e4b94af
refactor
QuanluZhang Dec 10, 2020
e1a6b7b
resolve comments
QuanluZhang Dec 10, 2020
e6d3874
remove files
QuanluZhang Dec 11, 2020
9434511
minor
QuanluZhang Dec 11, 2020
0d12f34
Merge branch 'dev-retiarii' of https://github.com/microsoft/nni into …
QuanluZhang Dec 11, 2020
0ebae32
Merge branch 'dev-retiarii' of https://github.com/microsoft/nni into …
QuanluZhang Dec 11, 2020
c9b20cd
Merge branch 'dev-retiarii' of https://github.com/microsoft/nni into …
QuanluZhang Dec 11, 2020
1fdfce7
refactor based on the new launch method
QuanluZhang Dec 11, 2020
9744788
remove comments
QuanluZhang Dec 11, 2020
81c29ac
support decorator for trainer
QuanluZhang Dec 11, 2020
8a7707e
minor
QuanluZhang Dec 11, 2020
14aaf29
finalize package import based on new launch approach
QuanluZhang Dec 11, 2020
731543a
refactor of type class
QuanluZhang Dec 11, 2020
a4241a6
minor
QuanluZhang Dec 11, 2020
14abf88
support the whole nn.module
QuanluZhang Dec 12, 2020
256e89a
add user-friendly error message
QuanluZhang Dec 12, 2020
2e376ba
change folder name
QuanluZhang Dec 12, 2020
a44ec26
rename
QuanluZhang Dec 12, 2020
e0c82b4
update error message
QuanluZhang Dec 12, 2020
b5334b6
support darts (classic mode) with new launch approach, refactor layer…
QuanluZhang Dec 13, 2020
c02d620
minor
QuanluZhang Dec 13, 2020
cd0804e
add darts example
QuanluZhang Dec 13, 2020
e9fb681
layerchoice/inputchoice backward compatibility; improve handling of i…
QuanluZhang Dec 14, 2020
fb67d3a
resolve comments
QuanluZhang Dec 14, 2020
2 changes: 1 addition & 1 deletion nni/experiment/pipe.py
@@ -52,7 +52,7 @@ def __init__(self, experiment_id: str):

def connect(self) -> BufferedIOBase:
conn, _ = self._socket.accept()
self.file = conn.makefile('w+b')
self.file = conn.makefile('rwb')
Contributor:

Why is this?

Contributor Author:

#3183 @liuzhe-lz can answer this question

return self.file

def close(self) -> None:
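For context on the `makefile` mode change above: in Python 3, `socket.makefile` accepts only combinations of `'r'`, `'w'`, and `'b'`, so the stdio-style `'w+b'` mode is rejected while `'rwb'` yields a buffered read/write binary file object. A minimal standalone sketch:

```python
import socket

# Python 3's socket.makefile only accepts combinations of 'r', 'w', 'b';
# the stdio-style 'w+b' is rejected, which is presumably why the diff
# switches to 'rwb'.
a, b = socket.socketpair()
try:
    a.makefile('w+b')
except ValueError as exc:
    print('rejected:', exc)

# 'rwb' gives a buffered binary file usable in both directions.
fa, fb = a.makefile('rwb'), b.makefile('rwb')
fa.write(b'ping')
fa.flush()
print(fb.read(4))  # b'ping'
for f in (fa, fb, a, b):
    f.close()
```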
1 change: 1 addition & 0 deletions nni/retiarii/__init__.py
@@ -2,3 +2,4 @@
from .graph import *
from .execution import *
from .mutator import *
from .utils import register_module
7 changes: 1 addition & 6 deletions nni/retiarii/codegen/pytorch.py
@@ -15,7 +15,6 @@ def model_to_pytorch_script(model: Model, placement = None) -> str:
import_pkgs, graph_code = graph_to_pytorch_model(name, cell, placement = placement)
graphs.append(graph_code)
total_pkgs.update(import_pkgs)
# FIXME: set correct PATH for the packages (after launch refactor)
pkgs_code = '\n'.join(['import {}'.format(pkg) for pkg in total_pkgs])
return _PyTorchScriptTemplate.format(pkgs_code, '\n\n'.join(graphs)).strip()

@@ -71,7 +70,7 @@ def _remove_prefix(names, graph_name):
return names[len(graph_name):] if names.startswith(graph_name) else names

def graph_to_pytorch_model(graph_name: str, graph: Graph, placement = None) -> str:
nodes = graph.topo_sort() # FIXME: topological sort is needed here
nodes = graph.topo_sort()

# handle module node and function node differently
# only need to generate code for module here
@@ -130,10 +129,6 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement = None) -> s
import torch.nn.functional as F
import torch.optim as optim

# FIXME: remove these two lines
import sys
sys.path.append("test/convert_test")

{}

{}
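The codegen above relies on `graph.topo_sort()` so that each node's code is emitted only after the code of its inputs. A minimal sketch of such a sort (Kahn's algorithm, assuming a plain node list and `(head, tail)` edge pairs rather than the real `Graph` API):

```python
from collections import deque

def topo_sort(nodes, edges):
    """Kahn's algorithm: repeatedly emit nodes with no unprocessed predecessors."""
    indegree = {n: 0 for n in nodes}
    succs = {n: [] for n in nodes}
    for head, tail in edges:          # edge head -> tail
        indegree[tail] += 1
        succs[head].append(tail)
    queue = deque(n for n in nodes if indegree[n] == 0)
    order = []
    while queue:
        n = queue.popleft()
        order.append(n)
        for m in succs[n]:
            indegree[m] -= 1
            if indegree[m] == 0:
                queue.append(m)
    if len(order) != len(nodes):
        raise RuntimeError('graph contains a cycle')
    return order

print(topo_sort(['conv', 'bn', 'relu'], [('conv', 'bn'), ('bn', 'relu')]))
# ['conv', 'bn', 'relu']
```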
82 changes: 62 additions & 20 deletions nni/retiarii/converter/graph_gen.py
@@ -1,14 +1,16 @@
import json_tricks
import logging
import re
import torch

from ..graph import Graph, Node, Edge, Model
from ..operation import Cell, Operation
from ..nn.pytorch import Placeholder, LayerChoice, InputChoice

from .op_types import MODULE_EXCEPT_LIST, Type
from .op_types import MODULE_EXCEPT_LIST, OpTypeName, BasicOpsPT
from .utils import build_full_name, _convert_name

_logger = logging.getLogger(__name__)

global_seq = 0
global_graph_id = 0
@@ -80,7 +82,7 @@ def create_prim_constant_node(ir_graph, node, module_name):
if node.outputsAt(0).toIValue() is not None:
attrs = {'value': node.outputsAt(0).toIValue()}
global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, Type.Constant, global_seq),
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Constant, global_seq),
node.kind(), attrs)
return new_node

@@ -163,6 +165,33 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
# key: tensor (%out.1), value: node (this node)
output_remap = {}

def handle_if_condition(cond_tensor):
"""
to calculate the condition, we only deal with the following op types by tracing back
`prim::GetAttr`, `aten::__getitem__`, `prim::Constant`, `aten::eq`

generate the expression using recursive calls

NOTE: do not support dynamic graph
"""
def _generate_expr(tensor):
if tensor.node().kind() == 'prim::GetAttr':
return f'({getattr(module, tensor.node().s("name"))})'
elif tensor.node().kind() == 'aten::__getitem__':
t = _generate_expr(tensor.node().inputsAt(0))
idx = _generate_expr(tensor.node().inputsAt(1))
return f'({t}[{idx}])'
elif tensor.node().kind() == 'prim::Constant':
return f'{tensor.toIValue()}'
elif tensor.node().kind() == 'aten::eq':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} == {right})'
else:
raise RuntimeError(f'Unsupported op type {tensor.node().kind()} in if condition')
expr = _generate_expr(cond_tensor)
return eval(expr)
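The recursive tracing in `handle_if_condition` can be illustrated standalone. The toy IR below (tuples standing in for TorchScript nodes) is purely hypothetical; it only mirrors the four supported op kinds and the build-expression-then-`eval` flow:

```python
# Toy stand-in for TorchScript nodes: (kind, *payload) tuples. Hypothetical;
# only mirrors the four op kinds handle_if_condition supports.
class FakeModule:
    strides = [1, 2]

module = FakeModule()

def generate_expr(node):
    kind, payload = node[0], node[1:]
    if kind == 'prim::GetAttr':
        return f'({getattr(module, payload[0])})'
    if kind == 'aten::__getitem__':
        return f'({generate_expr(payload[0])}[{generate_expr(payload[1])}])'
    if kind == 'prim::Constant':
        return f'{payload[0]}'
    if kind == 'aten::eq':
        return f'({generate_expr(payload[0])} == {generate_expr(payload[1])})'
    raise RuntimeError(f'Unsupported op type {kind} in if condition')

# condition equivalent to: module.strides[1] == 2
cond = ('aten::eq',
        ('aten::__getitem__', ('prim::GetAttr', 'strides'), ('prim::Constant', 1)),
        ('prim::Constant', 2))
print(eval(generate_expr(cond)))  # True
```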

def handle_if_node(node):
"""
Parameters
@@ -179,19 +208,13 @@ def handle_if_node(node):
# will support constant expression in future
inputs = [i for i in node.inputs()]
assert len(inputs) == 1
if not inputs[0].node().kind() in ['prim::Constant', 'prim::GetAttr']:
raise RuntimeError('"if" whose condition is not constant or attribute has not been supported yet!')
chosen_block = None
if inputs[0].node().kind() == 'prim::Constant':
chosen_block = 0 if inputs[0].toIValue() else 1
if inputs[0].node().kind() == 'prim::GetAttr':
chosen_block = 0 if getattr(module, inputs[0].node().s('name')) else 1
cond = handle_if_condition(inputs[0])
chosen_block = 0 if cond else 1
blocks = [block for block in node.blocks()]
assert len(blocks) == 2
last_block_node = None
for node in blocks[chosen_block].nodes():
last_block_node = handle_single_node(node)
assert last_block_node is not None
return last_block_node

def handle_single_node(node):
@@ -287,29 +310,33 @@ def handle_single_node(node):
node_index[node] = new_node
elif node.kind() == 'prim::ListConstruct':
global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, Type.ListConstruct, global_seq), node.kind())
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.ListConstruct, global_seq), node.kind())
node_index[node] = new_node
_add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap)
elif node.kind() == 'aten::append':
global_seq += 1
aten_node = ir_graph.add_node(build_full_name(module_name, Type.BasicOpsPT[node.kind()], global_seq), node.kind())
aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], global_seq), node.kind())
node_index[node] = aten_node
_add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
output_remap[node.inputsAt(0)] = node
elif node.kind().startswith('aten::'):
# handle aten::XXX
global_seq += 1
aten_node = ir_graph.add_node(build_full_name(module_name, Type.BasicOpsPT[node.kind()], global_seq), node.kind())
aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], global_seq), node.kind())
node_index[node] = aten_node
_add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
elif node.kind() == 'prim::GetAttr':
node_type, attrs = handle_prim_attr_node(node)
global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, Type.Attr, global_seq),
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Attr, global_seq),
node_type, attrs)
node_index[node] = new_node
elif node.kind() == 'prim::min':
print('zql: ', sm_graph)
exit(1)
elif node.kind() == 'prim::If':
last_block_node = handle_if_node(node)
# last_block_node is None means no node in the branch block
node_index[node] = last_block_node
elif node.kind() == 'prim::Loop':
raise RuntimeError('Loop has not been supported yet!')
@@ -343,7 +370,10 @@ def merge_aten_slices(ir_graph):

for head_node in head_slice_nodes:
slot = 0
new_slice_node = ir_graph.add_node(build_full_name(head_node.name, 'merged'), Type.MergedSlice)
new_slice_node = ir_graph.add_node(build_full_name(head_node.name, 'merged'), OpTypeName.MergedSlice)
if len(head_node.incoming_edges) == 4:
# when slice is for one dimension list, there are only 4 inputs, thus merge is not needed
break
assert len(head_node.incoming_edges) == 5
for edge in head_node.incoming_edges:
edge.tail = new_slice_node
@@ -383,10 +413,13 @@ def _handle_layerchoice(module):

m_attrs = {}
candidates = module.candidate_ops
choices = []
for i, cand in enumerate(candidates):
assert id(cand) in modules_arg, 'id not exist: {}'.format(id(cand))
assert isinstance(modules_arg[id(cand)], dict)
m_attrs[f'choice_{i}'] = modules_arg[id(cand)]
cand_type = '__torch__.' + cand.__class__.__module__ + '.' + cand.__class__.__name__
choices.append({'type': cand_type, 'parameters': modules_arg[id(cand)]})
m_attrs[f'choices'] = choices
m_attrs['label'] = module.label
return m_attrs

@@ -425,17 +458,18 @@ def convert_module(script_module, module, module_name, ir_model):
# NOTE: have not supported nested LayerChoice, i.e., a candidate module
# also has LayerChoice or InputChoice or ValueChoice
original_type_name = script_module.original_name
if original_type_name == Type.LayerChoice:
if original_type_name == OpTypeName.LayerChoice:
m_attrs = _handle_layerchoice(module)
return None, m_attrs
if original_type_name == Type.InputChoice:
if original_type_name == OpTypeName.InputChoice:
m_attrs = _handle_inputchoice(module)
return None, m_attrs
if original_type_name in Type.Placeholder:
if original_type_name == OpTypeName.Placeholder:
m_attrs = modules_arg[id(module)]
return None, m_attrs
if original_type_name in torch.nn.__dict__ and original_type_name not in MODULE_EXCEPT_LIST:
# this is a basic module from pytorch, no need to parse its graph
assert id(module) in modules_arg, f'{original_type_name} arguments are not recorded'
m_attrs = modules_arg[id(module)]
return None, m_attrs

@@ -463,7 +497,15 @@ def convert_module(script_module, module, module_name, ir_model):

ir_graph._register()

return ir_graph, modules_arg[id(module)]
if id(module) not in modules_arg:
raise RuntimeError(f'{original_type_name} arguments are not recorded, \
you might have forgotten to decorate this class with @register_module()')
# TODO: if we parse this module, it means we will create a graph (module class)
# for this module. Then it is not necessary to record this module's arguments
# return ir_graph, modules_arg[id(module)].
# That is, we can refactor this part, to allow users to annotate which module
# should not be parsed further.
return ir_graph, {}

def convert_to_graph(script_module, module, recorded_modules_arg):
"""
38 changes: 21 additions & 17 deletions nni/retiarii/converter/op_types.py
@@ -1,8 +1,11 @@
from enum import Enum

MODULE_EXCEPT_LIST = ['Sequential']


class Type:
"""Node Type class
class OpTypeName(str, Enum):
"""
op type to its type name str
"""
Attr = 'Attr'
Constant = 'Constant'
@@ -11,21 +14,22 @@ class Type:
InputChoice = 'InputChoice'
ValueChoice = 'ValueChoice'
Placeholder = 'Placeholder'

MergedSlice = 'MergedSlice'

# deal with aten op
BasicOpsPT = {
'aten::mean': 'Mean',
'aten::relu': 'Relu',
'aten::add': 'Add',
'aten::__getitem__': 'getitem',
'aten::append': 'Append',
'aten::len': 'Len',
'aten::slice': 'Slice',
'aten::cat': 'Cat',
'aten::size': 'Size',
'aten::view': 'View'
}
# deal with aten op
BasicOpsPT = {
'aten::mean': 'Mean',
'aten::relu': 'Relu',
'aten::add': 'Add',
'aten::__getitem__': 'getitem',
'aten::append': 'Append',
'aten::len': 'Len',
'aten::slice': 'Slice',
'aten::cat': 'Cat',
'aten::size': 'Size',
'aten::view': 'View',
'aten::eq': 'Eq',
'aten::add_': 'Add_' # %out.3 : Tensor = aten::add_(%out.1, %connection.1, %4)
}

BasicOpsTF = {}
BasicOpsTF = {}
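Mixing `str` into the `Enum` (the `Type` to `OpTypeName(str, Enum)` change above) is what lets members keep flowing through string-building helpers such as `build_full_name`. A small sketch:

```python
from enum import Enum

class OpTypeName(str, Enum):
    Attr = 'Attr'
    Constant = 'Constant'
    MergedSlice = 'MergedSlice'

# Members are real str instances, so equality checks and concatenation
# against plain strings keep working after the switch.
assert OpTypeName.Constant == 'Constant'
assert isinstance(OpTypeName.MergedSlice, str)
name = 'model' + '/' + OpTypeName.Constant
print(name)  # model/Constant
```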
7 changes: 0 additions & 7 deletions nni/retiarii/execution/api.py
@@ -36,13 +36,6 @@ def get_and_register_default_listener(engine: AbstractExecutionEngine) -> Defaul
engine.register_graph_listener(_default_listener)
return _default_listener

def _get_search_space() -> 'Dict':
engine = get_execution_engine()
while True:
time.sleep(1)
if engine.get_search_space() is not None:
break
return engine.get_search_space()

def submit_models(*models: Model) -> None:
engine = get_execution_engine()
5 changes: 0 additions & 5 deletions nni/retiarii/execution/base.py
@@ -50,10 +50,6 @@ def __init__(self) -> None:

self._running_models: Dict[int, Model] = dict()

def get_search_space(self) -> 'JSON':
advisor = get_advisor()
return advisor.search_space

def submit_models(self, *models: Model) -> None:
for model in models:
data = BaseGraphData(codegen.model_to_pytorch_script(model),
@@ -106,7 +102,6 @@ def trial_execute_graph(cls) -> None:
Initialize the model, hand it over to trainer.
"""
graph_data = BaseGraphData.load(receive_trial_parameters())
# FIXME: update this part to dump code to a correct path!!!
with open('_generated_model.py', 'w') as f:
f.write(graph_data.model_script)
trainer_cls = utils.import_(graph_data.training_module)
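`trial_execute_graph` resolves the trainer class from a dotted string via `utils.import_`. A minimal sketch of such a helper (an assumption about its behavior, not NNI's actual implementation):

```python
import importlib

def import_(qualified_name):
    """Resolve 'pkg.mod.attr' to the attribute object (simplified sketch)."""
    module_name, _, attr = qualified_name.rpartition('.')
    return getattr(importlib.import_module(module_name), attr)

sqrt = import_('math.sqrt')
print(sqrt(16))  # 4.0
```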