[Retiarii] Coding style improvements for pylint and flake8 (#3190)
ultmaster authored Dec 14, 2020
1 parent 593a275 commit 59cd398
Showing 34 changed files with 221 additions and 199 deletions.
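The changes below are mechanical style fixes rather than behavioral ones. The recurring patterns: explicit `typing` imports instead of wildcards (flake8 F403), lazy `%`-style arguments in logging calls instead of eager `str.format()` (pylint W1202), PEP 8 spacing for keyword arguments and defaults (E251), two blank lines between top-level definitions (E302), sorted imports, and removal of unused variables and imports. The sketch below condenses these patterns into one place; the names in it are illustrative, not taken from the NNI codebase.

```python
# A minimal before/after sketch of the patterns fixed throughout this commit.
import logging
from typing import List  # was: from typing import *  (flake8 F403)

_logger = logging.getLogger(__name__)


# E302: two blank lines before every top-level definition.
def render(items: List[str], sep=', ') -> str:  # E251: no spaces around '=' in defaults
    # W1202: pass arguments lazily instead of calling str.format() eagerly;
    # the message is only interpolated if the log level is enabled.
    _logger.info('rendering %s items', len(items))
    return sep.join(items)
```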
19 changes: 10 additions & 9 deletions nni/retiarii/codegen/pytorch.py
@@ -1,29 +1,28 @@
 import logging
-from typing import *
+from typing import List
 
 from ..graph import IllegalGraphError, Edge, Graph, Node, Model
 from ..operation import Operation, Cell
 
 _logger = logging.getLogger(__name__)
 
 
-
-def model_to_pytorch_script(model: Model, placement = None) -> str:
+def model_to_pytorch_script(model: Model, placement=None) -> str:
     graphs = []
     total_pkgs = set()
     for name, cell in model.graphs.items():
-        import_pkgs, graph_code = graph_to_pytorch_model(name, cell, placement = placement)
+        import_pkgs, graph_code = graph_to_pytorch_model(name, cell, placement=placement)
         graphs.append(graph_code)
         total_pkgs.update(import_pkgs)
     pkgs_code = '\n'.join(['import {}'.format(pkg) for pkg in total_pkgs])
     return _PyTorchScriptTemplate.format(pkgs_code, '\n\n'.join(graphs)).strip()
 
 
 def _sorted_incoming_edges(node: Node) -> List[Edge]:
     edges = [edge for edge in node.graph.edges if edge.tail is node]
-    _logger.info('sorted_incoming_edges: {}'.format(edges))
+    _logger.info('sorted_incoming_edges: %s', str(edges))
     if not edges:
         return []
-    _logger.info(f'all tail_slots are None: {[edge.tail_slot for edge in edges]}')
+    _logger.info('all tail_slots are None: %s', str([edge.tail_slot for edge in edges]))
     if all(edge.tail_slot is None for edge in edges):
         return edges
     if all(isinstance(edge.tail_slot, int) for edge in edges):
@@ -32,6 +31,7 @@ def _sorted_incoming_edges(node: Node) -> List[Edge]:
         return edges
     raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))
 
+
 def _format_inputs(node: Node) -> List[str]:
     edges = _sorted_incoming_edges(node)
     inputs = []
@@ -53,6 +53,7 @@ def _format_inputs(node: Node) -> List[str]:
             inputs.append('{}[{}]'.format(edge.head.name, edge.head_slot))
     return inputs
 
+
 def _remove_prefix(names, graph_name):
     """
     variables name (full name space) is too long,
@@ -69,14 +70,14 @@ def _remove_prefix(names, graph_name):
     else:
         return names[len(graph_name):] if names.startswith(graph_name) else names
 
-def graph_to_pytorch_model(graph_name: str, graph: Graph, placement = None) -> str:
+
+def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str:
     nodes = graph.topo_sort()
 
     # handle module node and function node differently
     # only need to generate code for module here
     import_pkgs = set()
     node_codes = []
-    placement_codes = []
     for node in nodes:
         if node.operation:
             pkg_name = node.operation.get_import_pkg()
1 change: 0 additions & 1 deletion nni/retiarii/converter/__init__.py
@@ -1,2 +1 @@
 from .graph_gen import convert_to_graph
-from .visualize import visualize_model
47 changes: 29 additions & 18 deletions nni/retiarii/converter/graph_gen.py
@@ -1,21 +1,21 @@
 import json_tricks
 import logging
 import re
+import torch
 
-from ..graph import Graph, Node, Edge, Model
-from ..operation import Cell, Operation
-from ..nn.pytorch import Placeholder, LayerChoice, InputChoice
-import torch
-
-from .op_types import MODULE_EXCEPT_LIST, OpTypeName, BasicOpsPT
-from .utils import build_full_name, _convert_name
+from ..graph import Graph, Model, Node
+from ..nn.pytorch import InputChoice, LayerChoice, Placeholder
+from ..operation import Cell
+from .op_types import MODULE_EXCEPT_LIST, BasicOpsPT, OpTypeName
+from .utils import _convert_name, build_full_name
 
 _logger = logging.getLogger(__name__)
 
 global_seq = 0
 global_graph_id = 0
 modules_arg = None
 
+
 def _add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap, ignore_first=False):
     """
     Parameters
@@ -76,6 +76,7 @@ def _add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap, ignore_first=False):
 
     new_node_input_idx += 1
 
+
 def create_prim_constant_node(ir_graph, node, module_name):
     global global_seq
     attrs = {}
@@ -86,14 +87,17 @@ def create_prim_constant_node(ir_graph, node, module_name):
                                  node.kind(), attrs)
     return new_node
 
+
 def handle_prim_attr_node(node):
     assert node.hasAttribute('name')
     attrs = {'name': node.s('name'), 'input': node.inputsAt(0).debugName()}
     return node.kind(), attrs
 
+
 def _remove_mangle(module_type_str):
     return re.sub('\\.___torch_mangle_\\d+', '', module_type_str)
 
+
 def remove_unconnected_nodes(ir_graph, targeted_type=None):
     """
     Parameters
@@ -122,6 +126,7 @@ def remove_unconnected_nodes(ir_graph, targeted_type=None):
     for hidden_node in to_removes:
         hidden_node.remove()
 
+
 def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, ir_graph):
     """
     Convert torch script node to our node ir, and build our graph ir
@@ -156,7 +161,7 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, ir_graph):
         # TODO: add scope name
         ir_graph._add_input(_convert_name(_input.debugName()))
 
-    node_index = {} # graph node to graph ir node
+    node_index = {}  # graph node to graph ir node
 
     # some node does not have output but it modifies a variable, for example aten::append
     # %17 : Tensor[] = aten::append(%out.1, %16)
@@ -248,13 +253,14 @@ def handle_single_node(node):
             # therefore, we do this check for a module. example below:
             # %25 : __torch__.xxx = prim::GetAttr[name="input_switch"](%self)
             # %27 : Tensor = prim::CallMethod[name="forward"](%25, %out.1)
-            assert submodule_name in script_module._modules, "submodule_name: {} not in script_module {}".format(submodule_name, script_module._modules.keys())
+            assert submodule_name in script_module._modules, "submodule_name: {} not in script_module {}".format(
+                submodule_name, script_module._modules.keys())
 
             submodule_full_name = build_full_name(module_name, submodule_name)
             submodule_obj = getattr(module, submodule_name)
             subgraph, sub_m_attrs = convert_module(script_module._modules[submodule_name],
-                                     submodule_obj,
-                                     submodule_full_name, ir_model)
+                                                   submodule_obj,
+                                                   submodule_full_name, ir_model)
         else:
             # %8 : __torch__.nni.retiarii.model_apis.nn.___torch_mangle_37.ModuleList = prim::GetAttr[name="cells"](%self)
             # %10 : __torch__.darts_model.Cell = prim::GetAttr[name="0"](%8)
@@ -271,7 +277,7 @@ def handle_single_node(node):
             predecessor_obj = getattr(module, predecessor_name)
             submodule_obj = getattr(predecessor_obj, submodule_name)
             subgraph, sub_m_attrs = convert_module(script_module._modules[predecessor_name]._modules[submodule_name],
-                                     submodule_obj, submodule_full_name, ir_model)
+                                                   submodule_obj, submodule_full_name, ir_model)
         else:
             raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str()))
 
@@ -329,7 +335,7 @@ def handle_single_node(node):
             node_type, attrs = handle_prim_attr_node(node)
             global_seq += 1
             new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Attr, global_seq),
-                                 node_type, attrs)
+                                         node_type, attrs)
             node_index[node] = new_node
         elif node.kind() == 'prim::min':
             print('zql: ', sm_graph)
@@ -350,6 +356,7 @@ def handle_single_node(node):
 
     return node_index
 
+
 def merge_aten_slices(ir_graph):
     """
     if there is aten::slice node, merge the consecutive ones together.
@@ -367,7 +374,7 @@ def merge_aten_slices(ir_graph):
             break
     if has_slice_node:
         assert head_slice_nodes
-
+
     for head_node in head_slice_nodes:
         slot = 0
         new_slice_node = ir_graph.add_node(build_full_name(head_node.name, 'merged'), OpTypeName.MergedSlice)
@@ -391,11 +398,11 @@ def merge_aten_slices(ir_graph):
             slot += 4
             ir_graph.hidden_nodes.remove(node)
             node = suc_node
-
         for edge in node.outgoing_edges:
             edge.head = new_slice_node
         ir_graph.hidden_nodes.remove(node)
 
+
 def refine_graph(ir_graph):
     """
@@ -408,13 +415,14 @@ def refine_graph(ir_graph):
     remove_unconnected_nodes(ir_graph, targeted_type='prim::GetAttr')
     merge_aten_slices(ir_graph)
 
+
 def _handle_layerchoice(module):
     global modules_arg
 
     m_attrs = {}
     candidates = module.candidate_ops
     choices = []
-    for i, cand in enumerate(candidates):
+    for cand in candidates:
         assert id(cand) in modules_arg, 'id not exist: {}'.format(id(cand))
         assert isinstance(modules_arg[id(cand)], dict)
         cand_type = '__torch__.' + cand.__class__.__module__ + '.' + cand.__class__.__name__
@@ -423,13 +431,15 @@ def _handle_layerchoice(module):
     m_attrs['label'] = module.label
     return m_attrs
 
+
 def _handle_inputchoice(module):
     m_attrs = {}
     m_attrs['n_chosen'] = module.n_chosen
     m_attrs['reduction'] = module.reduction
     m_attrs['label'] = module.label
     return m_attrs
 
+
 def convert_module(script_module, module, module_name, ir_model):
     """
     Convert a module to its graph ir (i.e., Graph) along with its input arguments
@@ -503,10 +513,11 @@ def convert_module(script_module, module, module_name, ir_model):
     # TODO: if we parse this module, it means we will create a graph (module class)
     # for this module. Then it is not necessary to record this module's arguments
     # return ir_graph, modules_arg[id(module)].
-    # That is, we can refactor this part, to allow users to annotate which module
+    # That is, we can refactor this part, to allow users to annotate which module
     # should not be parsed further.
     return ir_graph, {}
 
+
 def convert_to_graph(script_module, module, recorded_modules_arg):
     """
     Convert module to our graph ir, i.e., build a ```Model``` type
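Before moving on: the import reordering at the top of graph_gen.py follows a consistent convention, reproduced here in annotated form. The comments are an editorial reading aid, not part of the commit; the ordering shown is consistent with isort's defaults, though the tool actually used is not stated in the diff.

```python
# The import block of graph_gen.py after this commit, annotated.
# Plain imports first, alphabetized:
import json_tricks
import logging
import re
import torch

# Relative imports next, alphabetized by module path. Within one line the
# names are ordered the way isort orders them by default (constants, then
# classes, then functions, each group alphabetically). The unused Edge and
# Operation names from the old imports are dropped along the way.
from ..graph import Graph, Model, Node
from ..nn.pytorch import InputChoice, LayerChoice, Placeholder
from ..operation import Cell
from .op_types import MODULE_EXCEPT_LIST, BasicOpsPT, OpTypeName
from .utils import _convert_name, build_full_name
```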
5 changes: 3 additions & 2 deletions nni/retiarii/converter/op_types.py
@@ -16,6 +16,7 @@ class OpTypeName(str, Enum):
     Placeholder = 'Placeholder'
     MergedSlice = 'MergedSlice'
 
+
 # deal with aten op
 BasicOpsPT = {
     'aten::mean': 'Mean',
@@ -29,7 +30,7 @@ class OpTypeName(str, Enum):
     'aten::size': 'Size',
     'aten::view': 'View',
     'aten::eq': 'Eq',
-    'aten::add_': 'Add_' # %out.3 : Tensor = aten::add_(%out.1, %connection.1, %4)
+    'aten::add_': 'Add_'  # %out.3 : Tensor = aten::add_(%out.1, %connection.1, %4)
 }
 
-BasicOpsTF = {}
\ No newline at end of file
+BasicOpsTF = {}
1 change: 1 addition & 0 deletions nni/retiarii/converter/utils.py
@@ -6,6 +6,7 @@ def build_full_name(prefix, name, seq=None):
     else:
         return '{}__{}{}'.format(prefix, name, str(seq))
 
+
 def _convert_name(name: str) -> str:
     """
     Convert the names using separator '.' to valid variable name in code
4 changes: 3 additions & 1 deletion nni/retiarii/converter/visualize.py
@@ -1,5 +1,6 @@
 import graphviz
 
+
 def convert_to_visualize(graph_ir, vgraph):
     for name, graph in graph_ir.items():
         if name == '_training_config':
@@ -33,7 +34,8 @@ def convert_to_visualize(graph_ir, vgraph):
             dst = cell_node[dst][0]
             subgraph.edge(src, dst)
 
+
 def visualize_model(graph_ir):
     vgraph = graphviz.Digraph('G', filename='vgraph', format='jpg')
     convert_to_visualize(graph_ir, vgraph)
-    vgraph.render()
\ No newline at end of file
+    vgraph.render()
5 changes: 2 additions & 3 deletions nni/retiarii/execution/api.py
@@ -1,12 +1,11 @@
 import time
-import os
 import importlib.util
-from typing import *
+from typing import List
 
 from ..graph import Model, ModelStatus
 from .base import BaseExecutionEngine
 from .cgo_engine import CGOExecutionEngine
-from .interface import *
+from .interface import AbstractExecutionEngine, WorkerInfo
 from .listener import DefaultListener
 
 _execution_engine = None
Expand Down
8 changes: 4 additions & 4 deletions nni/retiarii/execution/base.py
@@ -1,5 +1,5 @@
 import logging
-from typing import *
+from typing import Dict, Any, List
 
 from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
 from .. import codegen, utils
@@ -61,16 +61,16 @@ def register_graph_listener(self, listener: AbstractGraphListener) -> None:
 
     def _send_trial_callback(self, paramater: dict) -> None:
         for listener in self._listeners:
-            _logger.warning('resources: {}'.format(listener.resources))
+            _logger.warning('resources: %s', listener.resources)
             if not listener.has_available_resource():
                 _logger.warning('There is no available resource, but trial is submitted.')
             listener.on_resource_used(1)
-            _logger.warning('on_resource_used: {}'.format(listener.resources))
+            _logger.warning('on_resource_used: %s', listener.resources)
 
     def _request_trial_jobs_callback(self, num_trials: int) -> None:
         for listener in self._listeners:
             listener.on_resource_available(1 * num_trials)
-            _logger.warning('on_resource_available: {}'.format(listener.resources))
+            _logger.warning('on_resource_available: %s', listener.resources)
 
     def _trial_end_callback(self, trial_id: int, success: bool) -> None:
         model = self._running_models[trial_id]
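The logging changes in base.py and codegen/pytorch.py all follow one pylint rule (W1202, logging-format-interpolation): pass values as arguments and let the logging module interpolate them, instead of calling `str.format()` up front. The difference matters when a level is disabled, because the eager call formats a string that is then thrown away. A minimal, self-contained illustration (not code from this repository):

```python
import logging

logging.basicConfig(level=logging.ERROR)  # WARNING records will be dropped
_logger = logging.getLogger(__name__)

resources = list(range(1000))

# Eager: str.format() runs unconditionally, even though the record is discarded.
_logger.warning('resources: {}'.format(resources))

# Lazy: the format string and arguments are stored on the LogRecord and only
# interpolated if a handler actually emits the message.
_logger.warning('resources: %s', resources)
```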
Diffs for the remaining changed files are not rendered here.
