[Retiarii] pytorch code converter (#3052)
QuanluZhang authored Nov 18, 2020
1 parent 002af91 commit 8af7314
Showing 27 changed files with 1,612 additions and 69 deletions.
33 changes: 33 additions & 0 deletions nni/experiment/nni_client.py
@@ -26,6 +26,7 @@
import re
import json
import requests
import yaml

__all__ = [
'Experiment',
@@ -265,6 +266,38 @@ def _exec_command(self, cmd, port=None):
self._endpoint = 'http://localhost:{}'.format(self._port)
self._exp_id = self.get_experiment_profile()['id']

def tmp_start_retiarii(self, graph_ir, training_approach,
applied_mutators, strategy, exp_config):
# prepare search space file which includes base graph IR and mutators
search_space = {}
search_space['base_model_ir'] = graph_ir
search_space['applied_mutators'] = applied_mutators
search_space['training_approach'] = training_approach
with open('search_space.json', 'w') as f:
json.dump(search_space, f)
# add advisor config to exp_config
exp_config['searchSpacePath'] = 'search_space.json'
exp_config['useAnnotation'] = False
exp_config['advisor'] = {
'codeDir': '.',
'classFileName': 'advisor_entry.py',
'className': 'RetiariiAdvisor',
'classArgs': {
'strategy': '{}.{}'.format(strategy['filename'], strategy['funcname'])
}
}
# add trial config to exp_config
exp_config['trial'] = {
'command': 'python3 -m nni.retiarii.trial_entry',
'codeDir': '../..',
'gpuNum': 0
}
# dump exp_config to nni.yml
with open('nni.yml', 'w') as f:
yaml.dump(exp_config, f)
# start experiment
self.start_experiment('nni.yml')

def start_experiment(self, config_file, port=None, debug=False):
"""
Start an experiment with specified configuration file and connect to it.
2 changes: 1 addition & 1 deletion nni/retiarii/__init__.py
@@ -1,4 +1,4 @@
from .execution import *
from .graph import *
from .mutator import *
from .operation import *
from .model_apis import nn
62 changes: 45 additions & 17 deletions nni/retiarii/codegen/pytorch.py
@@ -3,11 +3,27 @@
from ..graph import IllegalGraphError, Edge, Graph, Node, Model
from ..operation import Operation, Cell

# TODO: fix: inputs is a list; decide how to handle a single-element list vs. a single element

def model_to_pytorch_script(model: Model) -> str:
graphs = [graph_to_pytorch_model(name, cell) for name, cell in model.graphs.items()]
return _PyTorchScriptTemplate.format('\n\n'.join(graphs)).strip()

graphs = []
total_pkgs = set()
for name, cell in model.graphs.items():
import_pkgs, graph_code = graph_to_pytorch_model(name, cell)
graphs.append(graph_code)
total_pkgs.update(import_pkgs)
# TODO: set correct PATH for the packages (after launch refactor)
pkgs_code = '\n'.join(['import {}'.format(pkg) for pkg in total_pkgs])
return _PyTorchScriptTemplate.format(pkgs_code, '\n\n'.join(graphs)).strip()

def _convert_name(name: str) -> str:
"""
Convert a name that uses '.' as a separator into a valid variable name in code
"""
return name.replace('.', '__')

def _convert_names(names: List[str]) -> List[str]:
return [_convert_name(name) for name in names]

def _sorted_incoming_edges(node: Node) -> List[Edge]:
edges = [edge for edge in node.graph.edges if edge.tail is node]
@@ -21,8 +37,7 @@ def _sorted_incoming_edges(node: Node) -> List[Edge]:
return edges
raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))


def _format_inputs(node: Node) -> str:
def _format_inputs(node: Node) -> List[str]:
edges = _sorted_incoming_edges(node)
inputs = []
for edge in edges:
@@ -41,40 +56,48 @@ def _format_inputs(node: Node) -> str:
else:
# when the input comes from a multi-output operator: needs to know which one it comes from
inputs.append('{}[{}]'.format(edge.head.name, edge.head_slot))
return ', '.join(inputs)

return inputs

def graph_to_pytorch_model(graph_name: str, graph: Graph) -> str:
nodes = graph.nodes # FIXME: topological sort is needed here

# handle module node and function node differently
# only need to generate code for module here
import_pkgs = set()
node_codes = []
for node in nodes:
if node.operation:
node_codes.append(node.operation.to_init_code(node.name))
pkg_name = node.operation.get_import_pkg()
if pkg_name is not None:
import_pkgs.add(pkg_name)
node_code = node.operation.to_init_code(node.name)
if node_code is not None:
node_codes.append(node_code)

if graph.input_names is None:
input_code = '*_inputs'
else:
input_code = ', '.join(graph.input_names)
# TODO: remove _convert_names (after merging input_names and input_node)
input_code = ', '.join(_convert_names(graph.input_names))

edge_codes = []

for node in nodes:
sorted_nodes = graph.topo_sort()
for node in sorted_nodes:
if node.operation:
inputs = _format_inputs(node)
edge_codes.append(node.operation.to_forward_code(node.name, node.name, inputs))

output_code = _format_inputs(graph.output_node)
if not output_code:
output_code = 'None'
# TODO: refactor graph output_node
output_names = _format_inputs(graph.output_node)
output_names = _convert_names(output_names)
if not output_names:
output_names = ['None']

linebreak = '\n '
return _PyTorchModelTemplate.format(
graph_name=('Graph' if graph_name == '_graph' else graph_name),
return import_pkgs, _PyTorchModelTemplate.format(
graph_name=('Graph' if graph_name == '_graph' else _convert_name(graph_name)),
inputs=input_code,
outputs=output_code,
outputs=', '.join(output_names),
nodes=linebreak.join(node_codes),
edges=linebreak.join(edge_codes)
)
@@ -88,6 +111,11 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph) -> str:
import torch.nn.functional as F
import torch.optim as optim
import sys
sys.path.append("test/convert_test")
{}
{}
'''

37 changes: 37 additions & 0 deletions nni/retiarii/converter/README.md
@@ -0,0 +1,37 @@
# PyTorch Graph Converter

## Namespace for PyTorch Graph

We need a concrete rule for naming the nodes in a graph with a namespace.

Each node has a name, either specified or generated. The nodes in the same hierarchy cannot have the same name.

* The names of module nodes naturally follow this rule, because we use the variable names of the instantiated modules, just as the PyTorch graph does.

* For nodes created in the `forward` function, we use a global sequence number (see the sketch below).
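
A minimal sketch of how these naming rules might apply; the prefixes and counters below are illustrative, not the converter's actual output:

```python
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3)        # module node named by its variable: "conv"
        self.inner = nn.Sequential(nn.ReLU())  # nested modules get hierarchical names: "inner.0"

    def forward(self, x):
        x = self.conv(x)
        # the addition below has no variable name, so the converter assigns one
        # with a global sequence number, e.g. "aten::add_1" (illustrative name)
        return self.inner(x) + 1
```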

### Namespace for mutated (new) nodes

TBD

## Graph Simplification

TBD

## Node Types

We define a concrete type string for each node type.
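
As a rough illustration, the type strings could follow TorchScript's node kinds for function nodes and the module class name for module nodes; the exact strings are an assumption here, not taken from this commit:

```python
# Hypothetical node-name -> type-string mapping, for illustration only.
EXAMPLE_NODE_TYPES = {
    'conv1': 'torch.nn.Conv2d',    # module node: the module's class name
    'relu_1': 'aten::relu',        # function node created in forward
    'const_1': 'prim::Constant',   # constant node
}
```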

## Module's Input Arguments

We use a wrapper to obtain the input arguments of modules. Users need to use our wrapped `nn` and wrapped `Module`.
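
A minimal sketch of what such a wrapper could look like; the names and mechanism are assumptions for illustration, not the actual wrapped API:

```python
import torch.nn as nn

def _wrap(module_cls):
    """Return a subclass that records the arguments it was constructed with."""
    class Wrapped(module_cls):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # the graph converter can later read these to reconstruct the init call
            self._init_args = args
            self._init_kwargs = kwargs
    Wrapped.__name__ = module_cls.__name__
    return Wrapped

# a hypothetical wrapped "nn": users would import Conv2d from here instead of torch.nn
Conv2d = _wrap(nn.Conv2d)
Linear = _wrap(nn.Linear)
```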

## Control Flow

### for loop

Currently, we only support for loops based on `ModuleList` (or `ModuleDict`), which are automatically unrolled by TorchScript. That is to say, we do not support general loops in TorchScript for now.
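
For example, a loop over a `ModuleList` such as the one below can be unrolled by TorchScript, whereas a data-dependent Python loop cannot (illustrative sketch):

```python
import torch.nn as nn

class Stack(nn.Module):
    def __init__(self, depth: int = 3):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(depth)])

    def forward(self, x):
        # supported: iterating a ModuleList, which TorchScript unrolls statically
        for block in self.blocks:
            x = block(x)
        return x
```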

### if/else

For now, we only handle the case where the condition is a constant or an attribute. In this case, only one branch is kept when generating the graph.
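
For example, when the condition is a module attribute fixed at construction time, only the branch that is actually taken needs to appear in the generated graph (illustrative sketch):

```python
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, use_residual: bool = True):
        super().__init__()
        self.use_residual = use_residual   # attribute, constant after construction
        self.body = nn.Linear(8, 8)

    def forward(self, x):
        out = self.body(x)
        # the condition is an attribute rather than data-dependent, so only
        # one branch is kept when the graph is generated
        if self.use_residual:
            return out + x
        return out
```
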
2 changes: 2 additions & 0 deletions nni/retiarii/converter/__init__.py
@@ -0,0 +1,2 @@
from .graph_gen import convert_to_graph
from .visualize import visualize_model