From 59cd3982997acfe554249abda25a11ec588b7ec1 Mon Sep 17 00:00:00 2001
From: Yuge Zhang
Date: Mon, 14 Dec 2020 18:37:39 +0800
Subject: [PATCH] [Retiarii] Coding style improvements for pylint and flake8
 (#3190)
---
 nni/retiarii/codegen/pytorch.py                    | 19 +++----
 nni/retiarii/converter/__init__.py                 |  1 -
 nni/retiarii/converter/graph_gen.py                | 47 ++++++++++-------
 nni/retiarii/converter/op_types.py                 |  5 +-
 nni/retiarii/converter/utils.py                    |  1 +
 nni/retiarii/converter/visualize.py                |  4 +-
 nni/retiarii/execution/api.py                      |  5 +-
 nni/retiarii/execution/base.py                     |  8 +--
 nni/retiarii/execution/cgo_engine.py               | 33 ++++++------
 nni/retiarii/execution/interface.py                |  4 +-
 nni/retiarii/execution/listener.py                 |  6 +--
 .../execution/logical_optimizer/interface.py       |  6 +--
 .../logical_optimizer/logical_plan.py              | 27 +++++----
 .../logical_optimizer/opt_batching.py              | 10 ----
 .../logical_optimizer/opt_dedup_input.py           | 51 ++++++++++---------
 .../logical_optimizer/opt_weight_sharing.py        | 10 ----
 nni/retiarii/experiment.py                         | 27 ++++++----
 nni/retiarii/graph.py                              | 11 ++--
 nni/retiarii/integration.py                        | 18 +++----
 nni/retiarii/mutator.py                            |  7 ++-
 nni/retiarii/nn/pytorch/nn.py                      | 25 ++++++---
 nni/retiarii/operation.py                          |  5 ++
 nni/retiarii/operation_def/tf_op_def.py            |  3 +-
 nni/retiarii/operation_def/torch_op_def.py         |  2 +
 nni/retiarii/strategies/strategy.py                |  6 ++-
 nni/retiarii/strategies/tpe_strategy.py            | 14 +++--
 nni/retiarii/trainer/interface.py                  |  3 +-
 nni/retiarii/trainer/pytorch/base.py               | 17 +++----
 nni/retiarii/trainer/pytorch/darts.py              |  1 -
 nni/retiarii/trainer/pytorch/enas.py               |  4 +-
 nni/retiarii/trainer/pytorch/random.py             |  8 +--
 nni/retiarii/trainer/pytorch/utils.py              |  4 +-
 nni/retiarii/utils.py                              | 26 +++++++---
 pipelines/fast-test.yml                            |  2 +-
 34 files changed, 221 insertions(+), 199 deletions(-)
 delete mode 100644 nni/retiarii/execution/logical_optimizer/opt_batching.py
 delete mode 100644 nni/retiarii/execution/logical_optimizer/opt_weight_sharing.py

diff --git a/nni/retiarii/codegen/pytorch.py b/nni/retiarii/codegen/pytorch.py
index 71c37a3fe5..ad061cae9b 100644
--- a/nni/retiarii/codegen/pytorch.py
+++ b/nni/retiarii/codegen/pytorch.py
@@ -1,29 +1,28 @@
 import logging
-from typing import *
+from typing import List
 
 from ..graph import IllegalGraphError, Edge, Graph, Node, Model
-from ..operation import Operation, Cell
 
 _logger = logging.getLogger(__name__)
 
-
-def model_to_pytorch_script(model: Model, placement = None) -> str:
+def model_to_pytorch_script(model: Model, placement=None) -> str:
     graphs = []
     total_pkgs = set()
     for name, cell in model.graphs.items():
-        import_pkgs, graph_code = graph_to_pytorch_model(name, cell, placement = placement)
+        import_pkgs, graph_code = graph_to_pytorch_model(name, cell, placement=placement)
         graphs.append(graph_code)
         total_pkgs.update(import_pkgs)
     pkgs_code = '\n'.join(['import {}'.format(pkg) for pkg in total_pkgs])
     return _PyTorchScriptTemplate.format(pkgs_code, '\n\n'.join(graphs)).strip()
 
+
 def _sorted_incoming_edges(node: Node) -> List[Edge]:
     edges = [edge for edge in node.graph.edges if edge.tail is node]
-    _logger.info('sorted_incoming_edges: {}'.format(edges))
+    _logger.info('sorted_incoming_edges: %s', str(edges))
     if not edges:
         return []
-    _logger.info(f'all tail_slots are None: {[edge.tail_slot for edge in edges]}')
+    _logger.info('all tail_slots are None: %s', str([edge.tail_slot for edge in edges]))
     if all(edge.tail_slot is None for edge in edges):
         return edges
     if all(isinstance(edge.tail_slot, int) for edge in edges):
@@ -32,6 +31,7 @@ def _sorted_incoming_edges(node: Node) -> List[Edge]:
         return edges
     raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))
 
+
 def _format_inputs(node: Node) -> List[str]:
     edges = _sorted_incoming_edges(node)
     inputs = []
@@ -53,6 +53,7 @@ def _format_inputs(node: Node) -> List[str]:
             inputs.append('{}[{}]'.format(edge.head.name, edge.head_slot))
     return inputs
 
+
 def _remove_prefix(names, graph_name):
     """
     variables name (full name space) is too long,
@@ -69,14 +70,14 @@ def _remove_prefix(names, graph_name):
     else:
         return names[len(graph_name):] if names.startswith(graph_name) else names
 
-def graph_to_pytorch_model(graph_name: str, graph: Graph, placement = None) -> str:
+
+def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str:
     nodes = graph.topo_sort()
     # handle module node and function node differently
     # only need to generate code for module here
     import_pkgs = set()
     node_codes = []
-    placement_codes = []
     for node in nodes:
         if node.operation:
             pkg_name = node.operation.get_import_pkg()
diff --git a/nni/retiarii/converter/__init__.py b/nni/retiarii/converter/__init__.py
index d9af675c12..e0fff09f2d 100644
--- a/nni/retiarii/converter/__init__.py
+++ b/nni/retiarii/converter/__init__.py
@@ -1,2 +1 @@
 from .graph_gen import convert_to_graph
-from .visualize import visualize_model
\ No newline at end of file
diff --git a/nni/retiarii/converter/graph_gen.py b/nni/retiarii/converter/graph_gen.py
index 28d51378bd..3f16aabb70 100644
--- a/nni/retiarii/converter/graph_gen.py
+++ b/nni/retiarii/converter/graph_gen.py
@@ -1,14 +1,13 @@
-import json_tricks
 import logging
 import re
-import torch
 
-from ..graph import Graph, Node, Edge, Model
-from ..operation import Cell, Operation
-from ..nn.pytorch import Placeholder, LayerChoice, InputChoice
+import torch
 
-from .op_types import MODULE_EXCEPT_LIST, OpTypeName, BasicOpsPT
-from .utils import build_full_name, _convert_name
+from ..graph import Graph, Model, Node
+from ..nn.pytorch import InputChoice, LayerChoice, Placeholder
+from ..operation import Cell
+from .op_types import MODULE_EXCEPT_LIST, BasicOpsPT, OpTypeName
+from .utils import _convert_name, build_full_name
 
 _logger = logging.getLogger(__name__)
@@ -16,6 +15,7 @@
 global_graph_id = 0
 modules_arg = None
 
+
 def _add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap, ignore_first=False):
     """
     Parameters
@@ -76,6 +76,7 @@ def _add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap,
 
         new_node_input_idx += 1
 
+
 def create_prim_constant_node(ir_graph, node, module_name):
     global global_seq
     attrs = {}
@@ -86,14 +87,17 @@ def create_prim_constant_node(ir_graph, node, module_name):
                                  node.kind(), attrs)
     return new_node
 
+
 def handle_prim_attr_node(node):
     assert node.hasAttribute('name')
    attrs = {'name': node.s('name'), 'input': node.inputsAt(0).debugName()}
     return node.kind(), attrs
 
+
 def _remove_mangle(module_type_str):
     return re.sub('\\.___torch_mangle_\\d+', '', module_type_str)
 
+
 def remove_unconnected_nodes(ir_graph, targeted_type=None):
     """
     Parameters
@@ -122,6 +126,7 @@ def remove_unconnected_nodes(ir_graph, targeted_type=None):
     for hidden_node in to_removes:
         hidden_node.remove()
 
+
 def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, ir_graph):
     """
     Convert torch script node to our node ir, and build our graph ir
@@ -156,7 +161,7 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
             # TODO: add scope name
             ir_graph._add_input(_convert_name(_input.debugName()))
 
-    node_index = {} # graph node to graph ir node
+    node_index = {}  # graph node to graph ir node
 
     # some node does not have output but it modifies a variable, for example aten::append
    # %17 : Tensor[] = aten::append(%out.1, %16)
@@ -248,13 +253,14 @@ def handle_single_node(node):
                 # therefore, we do this check for a module. example below:
                 # %25 : __torch__.xxx = prim::GetAttr[name="input_switch"](%self)
                 # %27 : Tensor = prim::CallMethod[name="forward"](%25, %out.1)
-                assert submodule_name in script_module._modules, "submodule_name: {} not in script_module {}".format(submodule_name, script_module._modules.keys())
+                assert submodule_name in script_module._modules, "submodule_name: {} not in script_module {}".format(
+                    submodule_name, script_module._modules.keys())
 
                 submodule_full_name = build_full_name(module_name, submodule_name)
                 submodule_obj = getattr(module, submodule_name)
                 subgraph, sub_m_attrs = convert_module(script_module._modules[submodule_name],
-                                                        submodule_obj,
-                                                        submodule_full_name, ir_model)
+                                                       submodule_obj,
+                                                       submodule_full_name, ir_model)
             else:
                 # %8 : __torch__.nni.retiarii.model_apis.nn.___torch_mangle_37.ModuleList = prim::GetAttr[name="cells"](%self)
                 # %10 : __torch__.darts_model.Cell = prim::GetAttr[name="0"](%8)
@@ -271,7 +277,7 @@ def handle_single_node(node):
                     predecessor_obj = getattr(module, predecessor_name)
                     submodule_obj = getattr(predecessor_obj, submodule_name)
                     subgraph, sub_m_attrs = convert_module(script_module._modules[predecessor_name]._modules[submodule_name],
-                                                            submodule_obj, submodule_full_name, ir_model)
+                                                           submodule_obj, submodule_full_name, ir_model)
                 else:
                     raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str()))
@@ -329,7 +335,7 @@ def handle_single_node(node):
             node_type, attrs = handle_prim_attr_node(node)
             global_seq += 1
             new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Attr, global_seq),
-                                        node_type, attrs)
+                                         node_type, attrs)
             node_index[node] = new_node
         elif node.kind() == 'prim::min':
             print('zql: ', sm_graph)
@@ -350,6 +356,7 @@ def handle_single_node(node):
 
     return node_index
 
+
 def merge_aten_slices(ir_graph):
     """
     if there is aten::slice node, merge the consecutive ones together.
@@ -367,7 +374,7 @@ def merge_aten_slices(ir_graph):
                 break
     if has_slice_node:
         assert head_slice_nodes
-    
+
     for head_node in head_slice_nodes:
         slot = 0
         new_slice_node = ir_graph.add_node(build_full_name(head_node.name, 'merged'), OpTypeName.MergedSlice)
@@ -391,11 +398,11 @@ def merge_aten_slices(ir_graph):
             slot += 4
             ir_graph.hidden_nodes.remove(node)
             node = suc_node
-        
+
         for edge in node.outgoing_edges:
             edge.head = new_slice_node
         ir_graph.hidden_nodes.remove(node)
-    
+
 
 def refine_graph(ir_graph):
     """
@@ -408,13 +415,14 @@ def refine_graph(ir_graph):
     remove_unconnected_nodes(ir_graph, targeted_type='prim::GetAttr')
     merge_aten_slices(ir_graph)
 
+
 def _handle_layerchoice(module):
     global modules_arg
 
     m_attrs = {}
     candidates = module.candidate_ops
     choices = []
-    for i, cand in enumerate(candidates):
+    for cand in candidates:
         assert id(cand) in modules_arg, 'id not exist: {}'.format(id(cand))
         assert isinstance(modules_arg[id(cand)], dict)
         cand_type = '__torch__.' + cand.__class__.__module__ + '.' + cand.__class__.__name__
@@ -423,6 +431,7 @@ def _handle_layerchoice(module):
     m_attrs['label'] = module.label
     return m_attrs
 
+
 def _handle_inputchoice(module):
     m_attrs = {}
     m_attrs['n_chosen'] = module.n_chosen
@@ -430,6 +439,7 @@ def _handle_inputchoice(module):
     m_attrs['label'] = module.label
     return m_attrs
 
+
 def convert_module(script_module, module, module_name, ir_model):
     """
     Convert a module to its graph ir (i.e., Graph) along with its input arguments
@@ -503,10 +513,11 @@ def convert_module(script_module, module, module_name, ir_model):
     # TODO: if we parse this module, it means we will create a graph (module class)
     # for this module. Then it is not necessary to record this module's arguments
     # return ir_graph, modules_arg[id(module)].
-    # That is, we can refactor this part, to allow users to annotate which module 
+    # That is, we can refactor this part, to allow users to annotate which module
     # should not be parsed further.
     return ir_graph, {}
 
+
 def convert_to_graph(script_module, module, recorded_modules_arg):
     """
     Convert module to our graph ir, i.e., build a ```Model``` type
diff --git a/nni/retiarii/converter/op_types.py b/nni/retiarii/converter/op_types.py
index 3fe4df9b94..a8240fa654 100644
--- a/nni/retiarii/converter/op_types.py
+++ b/nni/retiarii/converter/op_types.py
@@ -16,6 +16,7 @@ class OpTypeName(str, Enum):
     Placeholder = 'Placeholder'
     MergedSlice = 'MergedSlice'
 
+
 # deal with aten op
 BasicOpsPT = {
     'aten::mean': 'Mean',
@@ -29,7 +30,7 @@ class OpTypeName(str, Enum):
     'aten::size': 'Size',
     'aten::view': 'View',
     'aten::eq': 'Eq',
-    'aten::add_': 'Add_' # %out.3 : Tensor = aten::add_(%out.1, %connection.1, %4)
+    'aten::add_': 'Add_'  # %out.3 : Tensor = aten::add_(%out.1, %connection.1, %4)
 }
 
-BasicOpsTF = {}
\ No newline at end of file
+BasicOpsTF = {}
diff --git a/nni/retiarii/converter/utils.py b/nni/retiarii/converter/utils.py
index 9346b53f55..a4c392d617 100644
--- a/nni/retiarii/converter/utils.py
+++ b/nni/retiarii/converter/utils.py
@@ -6,6 +6,7 @@ def build_full_name(prefix, name, seq=None):
     else:
         return '{}__{}{}'.format(prefix, name, str(seq))
 
+
 def _convert_name(name: str) -> str:
     """
     Convert the names using separator '.' to valid variable name in code
diff --git a/nni/retiarii/converter/visualize.py b/nni/retiarii/converter/visualize.py
index 3ff445a8da..31e29f3d6b 100644
--- a/nni/retiarii/converter/visualize.py
+++ b/nni/retiarii/converter/visualize.py
@@ -1,5 +1,6 @@
 import graphviz
 
+
 def convert_to_visualize(graph_ir, vgraph):
     for name, graph in graph_ir.items():
         if name == '_training_config':
@@ -33,7 +34,8 @@ def convert_to_visualize(graph_ir, vgraph):
                     dst = cell_node[dst][0]
                 subgraph.edge(src, dst)
 
+
 def visualize_model(graph_ir):
     vgraph = graphviz.Digraph('G', filename='vgraph', format='jpg')
     convert_to_visualize(graph_ir, vgraph)
-    vgraph.render()
\ No newline at end of file
+    vgraph.render()
diff --git a/nni/retiarii/execution/api.py b/nni/retiarii/execution/api.py
index 0d9580392a..c8d1d25253 100644
--- a/nni/retiarii/execution/api.py
+++ b/nni/retiarii/execution/api.py
@@ -1,12 +1,11 @@
 import time
 import os
-import importlib.util
-from typing import *
+from typing import List
 
 from ..graph import Model, ModelStatus
 from .base import BaseExecutionEngine
 from .cgo_engine import CGOExecutionEngine
-from .interface import *
+from .interface import AbstractExecutionEngine, WorkerInfo
 from .listener import DefaultListener
 
 _execution_engine = None
diff --git a/nni/retiarii/execution/base.py b/nni/retiarii/execution/base.py
index 695f96fbc9..51be1db360 100644
--- a/nni/retiarii/execution/base.py
+++ b/nni/retiarii/execution/base.py
@@ -1,5 +1,5 @@
 import logging
-from typing import *
+from typing import Dict, Any, List
 
 from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
 from .. import codegen, utils
@@ -61,16 +61,16 @@ def register_graph_listener(self, listener: AbstractGraphListener) -> None:
 
     def _send_trial_callback(self, paramater: dict) -> None:
         for listener in self._listeners:
-            _logger.warning('resources: {}'.format(listener.resources))
+            _logger.warning('resources: %s', listener.resources)
             if not listener.has_available_resource():
                 _logger.warning('There is no available resource, but trial is submitted.')
             listener.on_resource_used(1)
-            _logger.warning('on_resource_used: {}'.format(listener.resources))
+            _logger.warning('on_resource_used: %s', listener.resources)
 
     def _request_trial_jobs_callback(self, num_trials: int) -> None:
         for listener in self._listeners:
             listener.on_resource_available(1 * num_trials)
-            _logger.warning('on_resource_available: {}'.format(listener.resources))
+            _logger.warning('on_resource_available: %s', listener.resources)
 
     def _trial_end_callback(self, trial_id: int, success: bool) -> None:
         model = self._running_models[trial_id]
diff --git a/nni/retiarii/execution/cgo_engine.py b/nni/retiarii/execution/cgo_engine.py
index cad2d9de9f..ad13a0d594 100644
--- a/nni/retiarii/execution/cgo_engine.py
+++ b/nni/retiarii/execution/cgo_engine.py
@@ -1,6 +1,5 @@
 import logging
-import json
-from typing import *
+from typing import List, Dict, Tuple
 
 from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
 from .. import codegen, utils
@@ -12,8 +11,10 @@
 from .base import BaseGraphData
 
 _logger = logging.getLogger(__name__)
+
+
 class CGOExecutionEngine(AbstractExecutionEngine):
-    def __init__(self, n_model_per_graph = 4) -> None:
+    def __init__(self, n_model_per_graph=4) -> None:
         self._listeners: List[AbstractGraphListener] = []
         self._running_models: Dict[int, Model] = dict()
         self.logical_plan_counter = 0
@@ -30,38 +31,37 @@ def __init__(self, n_model_per_graph = 4) -> None:
         advisor.intermediate_metric_callback = self._intermediate_metric_callback
         advisor.final_metric_callback = self._final_metric_callback
 
-
     def add_optimizer(self, opt):
         self._optimizers.append(opt)
 
     def submit_models(self, *models: List[Model]) -> None:
-        _logger.info(f'{len(models)} Models are submitted')
+        _logger.info('%d models are submitted', len(models))
         logical = self._build_logical(models)
-        
+
         for opt in self._optimizers:
             opt.convert(logical)
-        
+
         phy_models_and_placements = self._assemble(logical)
         for model, placement, grouped_models in phy_models_and_placements:
             data = BaseGraphData(codegen.model_to_pytorch_script(model, placement=placement),
-                                model.training_config.module, model.training_config.kwargs)
+                                 model.training_config.module, model.training_config.kwargs)
             for m in grouped_models:
                 self._original_models[m.model_id] = m
                 self._original_model_to_multi_model[m.model_id] = model
             self._running_models[send_trial(data.dump())] = model
-        
+
         # for model in models:
         #     data = BaseGraphData(codegen.model_to_pytorch_script(model),
         #                          model.config['trainer_module'], model.config['trainer_kwargs'])
         #     self._running_models[send_trial(data.dump())] = model
-        
-    def _assemble(self, logical_plan : LogicalPlan) -> List[Tuple[Model, PhysicalDevice]]:
+
+    def _assemble(self, logical_plan: LogicalPlan) -> List[Tuple[Model, PhysicalDevice]]:
         # unique_models = set()
         # for node in logical_plan.graph.nodes:
         #     if node.graph.model not in unique_models:
         #         unique_models.add(node.graph.model)
         # return [m for m in unique_models]
-        grouped_models : List[Dict[Model, PhysicalDevice]] = AssemblePolicy().group(logical_plan)
+        grouped_models: List[Dict[Model, PhysicalDevice]] = AssemblePolicy().group(logical_plan)
         phy_models_and_placements = []
         for multi_model in grouped_models:
             model, model_placement = logical_plan.assemble(multi_model)
@@ -69,7 +69,7 @@ def _assemble(self, logical_plan : LogicalPlan) -> List[Tuple[Model, PhysicalDev
         return phy_models_and_placements
 
     def _build_logical(self, models: List[Model]) -> LogicalPlan:
-        logical_plan = LogicalPlan(id = self.logical_plan_counter)
+        logical_plan = LogicalPlan(plan_id=self.logical_plan_counter)
         for model in models:
             logical_plan.add_model(model)
         self.logical_plan_counter += 1
@@ -108,7 +108,7 @@ def _intermediate_metric_callback(self, trial_id: int, metrics: MetricData) -> N
         for model_id in merged_metrics:
             int_model_id = int(model_id)
             self._original_models[int_model_id].intermediate_metrics.append(merged_metrics[model_id])
-            #model.intermediate_metrics.append(metrics)
+            # model.intermediate_metrics.append(metrics)
             for listener in self._listeners:
                 listener.on_intermediate_metric(self._original_models[int_model_id], merged_metrics[model_id])
@@ -117,10 +117,9 @@ def _final_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
         for model_id in merged_metrics:
             int_model_id = int(model_id)
             self._original_models[int_model_id].intermediate_metrics.append(merged_metrics[model_id])
-            #model.intermediate_metrics.append(metrics)
+            # model.intermediate_metrics.append(metrics)
             for listener in self._listeners:
                 listener.on_metric(self._original_models[int_model_id], merged_metrics[model_id])
 
-
     def query_available_resource(self) -> List[WorkerInfo]:
         raise NotImplementedError  # move the method from listener to here?
@@ -141,6 +140,7 @@ def trial_execute_graph(cls) -> None:
         trainer_instance = trainer_cls(model_cls(), graph_data.training_kwargs)
         trainer_instance.fit()
 
+
 class AssemblePolicy:
     @staticmethod
     def group(logical_plan):
@@ -148,4 +148,3 @@ def group(logical_plan):
         for idx, m in enumerate(logical_plan.models):
             group_model[m] = PhysicalDevice('server', f'cuda:{idx}')
         return [group_model]
-        
\ No newline at end of file
diff --git a/nni/retiarii/execution/interface.py b/nni/retiarii/execution/interface.py
index 1b3ea9e330..d71abc0c98 100644
--- a/nni/retiarii/execution/interface.py
+++ b/nni/retiarii/execution/interface.py
@@ -1,5 +1,5 @@
-from abc import *
-from typing import *
+from abc import ABC, abstractmethod, abstractclassmethod
+from typing import Any, NewType, List
 
 from ..graph import Model, MetricData
diff --git a/nni/retiarii/execution/listener.py b/nni/retiarii/execution/listener.py
index 56514b3f0b..d51de03915 100644
--- a/nni/retiarii/execution/listener.py
+++ b/nni/retiarii/execution/listener.py
@@ -1,7 +1,5 @@
-from typing import *
-
-from ..graph import *
-from .interface import *
+from ..graph import Model, ModelStatus
+from .interface import MetricData, AbstractGraphListener
 
 
 class DefaultListener(AbstractGraphListener):
diff --git a/nni/retiarii/execution/logical_optimizer/interface.py b/nni/retiarii/execution/logical_optimizer/interface.py
index 0a7d39c130..2bd23a0d4c 100644
--- a/nni/retiarii/execution/logical_optimizer/interface.py
+++ b/nni/retiarii/execution/logical_optimizer/interface.py
@@ -1,8 +1,8 @@
-from abc import *
-from typing import *
+from abc import ABC
 
 from .logical_plan import LogicalPlan
-  
+
+
 class AbstractOptimizer(ABC):
     def __init__(self) -> None:
         pass
diff --git a/nni/retiarii/execution/logical_optimizer/logical_plan.py b/nni/retiarii/execution/logical_optimizer/logical_plan.py
index fd2dec9996..06ca3ef7c8 100644
--- a/nni/retiarii/execution/logical_optimizer/logical_plan.py
+++ b/nni/retiarii/execution/logical_optimizer/logical_plan.py
@@ -1,9 +1,8 @@
-from nni.retiarii.operation import Operation
-from nni.retiarii.graph import Model, Graph, Edge, Node, Cell
-from typing import *
-import logging
-from nni.retiarii.operation import _IOPseudoOperation
 import copy
+from typing import Dict, Tuple, List, Any
+
+from ...graph import Cell, Edge, Graph, Model, Node
+from ...operation import Operation, _IOPseudoOperation
 
 
 class PhysicalDevice:
@@ -108,11 +107,11 @@ def _fork_to(self, graph: Graph):
 
 
 class LogicalPlan:
-    def __init__(self, id=0) -> None:
+    def __init__(self, plan_id=0) -> None:
         self.lp_model = Model(_internal=True)
-        self.id = id
+        self.id = plan_id
         self.logical_graph = LogicalGraph(
-            self.lp_model, id, name=f'{id}', _internal=True)._register()
+            self.lp_model, self.id, name=f'{self.id}', _internal=True)._register()
         self.lp_model._root_graph_name = self.logical_graph.name
         self.models = []
@@ -148,7 +147,7 @@ def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) \
         phy_model.training_config.kwargs['is_multi_model'] = True
         phy_model.training_config.kwargs['model_cls'] = phy_graph.name
         phy_model.training_config.kwargs['model_kwargs'] = []
-        #FIXME: allow user to specify
+        # FIXME: allow user to specify
         phy_model.training_config.module = 'nni.retiarii.trainer.PyTorchMultiModelTrainer'
 
         # merge sub-graphs
@@ -158,10 +157,9 @@ def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) \
                 model.graphs[graph_name]._fork_to(
                     phy_model, name_prefix=f'M_{model.model_id}_')
 
-
         # When replace logical nodes, merge the training configs when
         # input/output nodes are replaced.
-        training_config_slot = {} # Model ID -> Slot ID
+        training_config_slot = {}  # Model ID -> Slot ID
         input_slot_mapping = {}
         output_slot_mapping = {}
         # Replace all logical nodes to executable physical nodes
@@ -230,7 +228,7 @@ def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) \
                         to_node = copied_op[(edge.head, tail_placement)]
                     else:
                         to_operation = Operation.new(
-                            'ToDevice', {"device":tail_placement.device})
+                            'ToDevice', {"device": tail_placement.device})
                         to_node = Node(phy_graph, phy_model._uid(), edge.head.name+"_to_"+edge.tail.name,
                                        to_operation)._register()
                         Edge((edge.head, edge.head_slot),
@@ -249,19 +247,18 @@ def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) \
             if edge.head in input_nodes:
                 edge.head_slot = input_slot_mapping[edge.head]
                 edge.head = phy_graph.input_node
-
 
         # merge all output nodes into one with multiple slots
         output_nodes = []
         for node in phy_graph.hidden_nodes:
             if isinstance(node.operation, _IOPseudoOperation) and node.operation.type == '_outputs':
                 output_nodes.append(node)
-        
+
         for edge in phy_graph.edges:
             if edge.tail in output_nodes:
                 edge.tail_slot = output_slot_mapping[edge.tail]
                 edge.tail = phy_graph.output_node
-        
+
         for node in input_nodes:
             node.remove()
         for node in output_nodes:
diff --git a/nni/retiarii/execution/logical_optimizer/opt_batching.py b/nni/retiarii/execution/logical_optimizer/opt_batching.py
deleted file mode 100644
index 55bd4fcb40..0000000000
--- a/nni/retiarii/execution/logical_optimizer/opt_batching.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from .base_optimizer import BaseOptimizer
-from .logical_plan import LogicalPlan
-
-
-class BatchingOptimizer(BaseOptimizer):
-    def __init__(self) -> None:
-        pass
-
-    def convert(self, logical_plan: LogicalPlan) -> None:
-        pass
diff --git a/nni/retiarii/execution/logical_optimizer/opt_dedup_input.py b/nni/retiarii/execution/logical_optimizer/opt_dedup_input.py
index 3064023ee1..4b50346f0b 100644
--- a/nni/retiarii/execution/logical_optimizer/opt_dedup_input.py
+++ b/nni/retiarii/execution/logical_optimizer/opt_dedup_input.py
@@ -1,32 +1,33 @@
-from .interface import AbstractOptimizer
-from .logical_plan import LogicalPlan, AbstractLogicalNode, LogicalGraph, OriginNode, PhysicalDevice
-from nni.retiarii import Graph, Node, Model
-from typing import *
-from nni.retiarii.operation import _IOPseudoOperation
+from typing import List, Dict, Tuple
 
+from ...graph import Graph, Model, Node
+from .interface import AbstractOptimizer
+from .logical_plan import (AbstractLogicalNode, LogicalGraph, LogicalPlan,
+                           OriginNode, PhysicalDevice)
 
 _supported_training_modules = ['nni.retiarii.trainer.PyTorchImageClassificationTrainer']
 
+
 class DedupInputNode(AbstractLogicalNode):
-    def __init__(self, logical_graph : LogicalGraph, id : int, \
-        nodes_to_dedup : List[Node], _internal=False):
-        super().__init__(logical_graph, id, \
-            "Dedup_"+nodes_to_dedup[0].name, \
-            nodes_to_dedup[0].operation)
-        self.origin_nodes : List[OriginNode] = nodes_to_dedup.copy()
-
+    def __init__(self, logical_graph: LogicalGraph, node_id: int,
+                 nodes_to_dedup: List[Node], _internal=False):
+        super().__init__(logical_graph, node_id,
+                         "Dedup_"+nodes_to_dedup[0].name,
+                         nodes_to_dedup[0].operation)
+        self.origin_nodes: List[OriginNode] = nodes_to_dedup.copy()
+
     def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) -> Tuple[Node, PhysicalDevice]:
         for node in self.origin_nodes:
             if node.original_graph.model in multi_model_placement:
-                new_node = Node(node.original_graph, node.id, \
-                    f'M_{node.original_graph.model.model_id}_{node.name}', \
-                    node.operation)
+                new_node = Node(node.original_graph, node.id,
+                                f'M_{node.original_graph.model.model_id}_{node.name}',
+                                node.operation)
                 return new_node, multi_model_placement[node.original_graph.model]
         raise ValueError(f'DedupInputNode {self.name} does not contain nodes from multi_model')
-    
+
     def _fork_to(self, graph: Graph):
         DedupInputNode(graph, self.id, self.origin_nodes)._register()
 
-
     def __repr__(self) -> str:
         return f'DedupNode(id={self.id}, name={self.name}, \
             len(nodes_to_dedup)={len(self.origin_nodes)}'
@@ -35,6 +36,7 @@ def __repr__(self) -> str:
 class DedupInputOptimizer(AbstractOptimizer):
     def __init__(self) -> None:
         pass
+
     def _check_deduplicate_by_node(self, root_node, node_to_check):
         if root_node == node_to_check:
             return True
@@ -50,13 +52,12 @@ def _check_deduplicate_by_node(self, root_node, node_to_check):
                 return False
         else:
             return False
-    
-
+
     def convert(self, logical_plan: LogicalPlan) -> None:
         nodes_to_skip = set()
-        while True: # repeat until the logical_graph converges
+        while True:  # repeat until the logical_graph converges
             input_nodes = logical_plan.logical_graph.get_nodes_by_type("_inputs")
-            #_PseudoOperation(type_name="_inputs"))
+            # _PseudoOperation(type_name="_inputs"))
             root_node = None
             for node in input_nodes:
                 if node in nodes_to_skip:
@@ -64,21 +65,21 @@ def convert(self, logical_plan: LogicalPlan) -> None:
                     root_node = node
                     break
             if root_node == None:
-                break # end of convert
+                break  # end of convert
             else:
                 nodes_to_dedup = []
                 for node in input_nodes:
                     if node in nodes_to_skip:
                         continue
                     if self._check_deduplicate_by_node(root_node, node):
-                        nodes_to_dedup.append(node) 
+                        nodes_to_dedup.append(node)
                 assert(len(nodes_to_dedup) >= 1)
                 if len(nodes_to_dedup) == 1:
                     assert(nodes_to_dedup[0] == root_node)
                     nodes_to_skip.add(root_node)
                 else:
-                    dedup_node = DedupInputNode(logical_plan.logical_graph, \
-                        logical_plan.lp_model._uid(), nodes_to_dedup)._register()
+                    dedup_node = DedupInputNode(logical_plan.logical_graph,
+                                                logical_plan.lp_model._uid(), nodes_to_dedup)._register()
                     for edge in logical_plan.logical_graph.edges:
                         if edge.head in nodes_to_dedup:
                             edge.head = dedup_node
diff --git a/nni/retiarii/execution/logical_optimizer/opt_weight_sharing.py b/nni/retiarii/execution/logical_optimizer/opt_weight_sharing.py
deleted file mode 100644
index c3f9fa744d..0000000000
--- a/nni/retiarii/execution/logical_optimizer/opt_weight_sharing.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from .base_optimizer import BaseOptimizer
-from .logical_plan import LogicalPlan
-
-
-class WeightSharingOptimizer(BaseOptimizer):
-    def __init__(self) -> None:
-        pass
-
-    def convert(self, logical_plan: LogicalPlan) -> None:
-        pass
diff --git a/nni/retiarii/experiment.py b/nni/retiarii/experiment.py
index e15b263401..46af1729e2 100644
--- a/nni/retiarii/experiment.py
+++ b/nni/retiarii/experiment.py
@@ -1,27 +1,31 @@
-import dataclasses
 import logging
 import time
 from dataclasses import dataclass
 from pathlib import Path
+from subprocess import Popen
 from threading import Thread
-from typing import Any, List, Optional
+from typing import Any, Optional
 
-from ..experiment import Experiment, TrainingServiceConfig
-from ..experiment import launcher, rest
+from ..experiment import Experiment, TrainingServiceConfig, launcher, rest
 from ..experiment.config.base import ConfigBase, PathLike
 from ..experiment.config import util
+from ..experiment.pipe import Pipe
+from .graph import Model
 from .utils import get_records
 from .integration import RetiariiAdvisor
-from .converter.graph_gen import convert_to_graph
-from .mutator import LayerChoiceMutator, InputChoiceMutator
+from .converter import convert_to_graph
+from .mutator import Mutator, LayerChoiceMutator, InputChoiceMutator
+from .trainer.interface import BaseTrainer
+from .strategies.strategy import BaseStrategy
 
 _logger = logging.getLogger(__name__)
 
+
 @dataclass(init=False)
 class RetiariiExeConfig(ConfigBase):
     experiment_name: Optional[str] = None
-    search_space: Any = ''   # TODO: remove
+    search_space: Any = ''  # TODO: remove
     trial_command: str = 'python3 -m nni.retiarii.trial_entry'
     trial_code_directory: PathLike = '.'
     trial_concurrency: int
@@ -52,6 +56,7 @@ def _canonical_rules(self):
     def _validation_rules(self):
         return _validation_rules
 
+
 _canonical_rules = {
     'trial_code_directory': util.canonical_path,
     'max_experiment_duration': lambda value: f'{util.parse_time(value)}s' if value is not None else None,
@@ -70,8 +75,8 @@ def _validation_rules(self):
 
 
 class RetiariiExperiment(Experiment):
-    def __init__(self, base_model: 'nn.Module', trainer: 'BaseTrainer',
-                 applied_mutators: List['Mutator'], strategy: 'BaseStrategy'):
+    def __init__(self, base_model: Model, trainer: BaseTrainer,
+                 applied_mutators: Mutator, strategy: BaseStrategy):
         self.config: RetiariiExeConfig = None
         self.port: Optional[int] = None
@@ -139,7 +144,7 @@ def start(self, config: RetiariiExeConfig, port: int = 8080, debug: bool = False
         debug
             Whether to start in debug mode.
         """
-        # FIXME: 
+        # FIXME:
         if debug:
             logging.getLogger('nni').setLevel(logging.DEBUG)
@@ -189,4 +194,4 @@ def get_status(self) -> str:
         if self.port is None:
             raise RuntimeError('Experiment is not running')
         resp = rest.get(self.port, '/check-status')
-        return resp['status']
\ No newline at end of file
+        return resp['status']
diff --git a/nni/retiarii/graph.py b/nni/retiarii/graph.py
index 0ddd185841..fee7336261 100644
--- a/nni/retiarii/graph.py
+++ b/nni/retiarii/graph.py
@@ -5,7 +5,6 @@
 import copy
 from enum import Enum
 import json
-from collections import defaultdict
 from typing import (Any, Dict, List, Optional, Tuple, Union, overload)
 
 from .operation import Cell, Operation, _IOPseudoOperation
@@ -329,12 +328,12 @@ def get_nodes_by_type(self, operation_type: str) -> List['Node']:
         Returns nodes whose operation is specified typed.
         """
         return [node for node in self.hidden_nodes if node.operation.type == operation_type]
-        
-    def get_node_by_id(self, id: int) -> Optional['Node']:
+
+    def get_node_by_id(self, node_id: int) -> Optional['Node']:
         """
         Returns the node which has specified name;
         or returns `None` if no node has this name.
         """
-        found = [node for node in self.nodes if node.id == id]
+        found = [node for node in self.nodes if node.id == node_id]
         return found[0] if found else None
 
     def get_nodes_by_label(self, label: str) -> List['Node']:
@@ -365,7 +364,8 @@ def topo_sort(self) -> List['Node']:
                     curr_nodes.append(successor)
 
         for key in node_to_fanin:
-            assert node_to_fanin[key] == 0, '{}, fanin: {}, predecessor: {}, edges: {}, fanin: {}, keys: {}'.format(key,
+            assert node_to_fanin[key] == 0, '{}, fanin: {}, predecessor: {}, edges: {}, fanin: {}, keys: {}'.format(
+                key,
                 node_to_fanin[key],
                 key.predecessors[0],
                 self.edges,
@@ -587,6 +587,7 @@ def _dump(self) -> Any:
             ret['label'] = self.label
         return ret
 
+
 class Edge:
     """
     A tensor, or "data flow", between two nodes.
diff --git a/nni/retiarii/integration.py b/nni/retiarii/integration.py
index 19cd21f451..bcd76096a9 100644
--- a/nni/retiarii/integration.py
+++ b/nni/retiarii/integration.py
@@ -1,17 +1,14 @@
 import logging
-import threading
-from typing import *
+from typing import Any, Callable
 
 import json_tricks
 import nni
 from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
-from nni.runtime.protocol import send, CommandType
+from nni.runtime.protocol import CommandType, send
 from nni.utils import MetricType
 
-from . import utils
 from .graph import MetricData
 
-
 _logger = logging.getLogger('nni.msg_dispatcher_base')
@@ -44,6 +41,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
     final_metric_callback
 
     """
+
     def __init__(self):
         super(RetiariiAdvisor, self).__init__()
         register_advisor(self)  # register the current advisor as the "global only" advisor
@@ -88,28 +86,28 @@ def send_trial(self, parameters):
             'parameters': parameters,
             'parameter_source': 'algorithm'
         }
-        _logger.info('New trial sent: {}'.format(new_trial))
+        _logger.info('New trial sent: %s', new_trial)
         send(CommandType.NewTrialJob, json_tricks.dumps(new_trial))
         if self.send_trial_callback is not None:
             self.send_trial_callback(parameters)  # pylint: disable=not-callable
         return self.parameters_count
 
     def handle_request_trial_jobs(self, num_trials):
-        _logger.info('Request trial jobs: {}'.format(num_trials))
+        _logger.info('Request trial jobs: %s', num_trials)
         if self.request_trial_jobs_callback is not None:
             self.request_trial_jobs_callback(num_trials)  # pylint: disable=not-callable
 
     def handle_update_search_space(self, data):
-        _logger.info('Received search space: {}'.format(data))
+        _logger.info('Received search space: %s', data)
         self.search_space = data
 
     def handle_trial_end(self, data):
-        _logger.info('Trial end: {}'.format(data))  # do nothing
+        _logger.info('Trial end: %s', data)
         self.trial_end_callback(json_tricks.loads(data['hyper_params'])['parameter_id'],  # pylint: disable=not-callable
                                 data['event'] == 'SUCCEEDED')
 
     def handle_report_metric_data(self, data):
-        _logger.info('Metric reported: {}'.format(data))
+        _logger.info('Metric reported: %s', data)
         if data['type'] == MetricType.REQUEST_PARAMETER:
             raise ValueError('Request parameter not supported')
         elif data['type'] == MetricType.PERIODICAL:
diff --git a/nni/retiarii/mutator.py b/nni/retiarii/mutator.py
index e9b3d63873..a51bd3f1f8 100644
--- a/nni/retiarii/mutator.py
+++ b/nni/retiarii/mutator.py
@@ -13,6 +13,7 @@ class Sampler:
     """
     Handles `Mutator.choice()` calls.
     """
+
     def choice(self, candidates: List[Choice], mutator: 'Mutator', model: Model, index: int) -> Choice:
         raise NotImplementedError()
@@ -35,6 +36,7 @@ class Mutator:
     For certain mutator subclasses, strategy or sampler can use `Mutator.dry_run()` to predict choice candidates.
    # Method names are open for discussion.
     """
+
     def __init__(self, sampler: Optional[Sampler] = None):
         self.sampler: Optional[Sampler] = sampler
         self._cur_model: Optional[Model] = None
@@ -77,7 +79,6 @@ def dry_run(self, model: Model) -> List[List[Choice]]:
         self.sampler = sampler_backup
         return recorder.recorded_candidates, new_model
 
-
     def mutate(self, model: Model) -> None:
         """
         Abstract method to be implemented by subclass.
@@ -105,6 +106,7 @@ def choice(self, candidates: List[Choice], *args) -> Choice:
 
 # the following is for inline mutation
 
+
 class LayerChoiceMutator(Mutator):
     def __init__(self, node_name: str, candidates: List):
         super().__init__()
@@ -118,6 +120,7 @@ def mutate(self, model):
         chosen_cand = self.candidates[chosen_index]
         target.update_operation(chosen_cand['type'], chosen_cand['parameters'])
 
+
 class InputChoiceMutator(Mutator):
     def __init__(self, node_name: str, n_chosen: int):
         super().__init__()
@@ -129,4 +132,4 @@ def mutate(self, model):
         candidates = [i for i in range(self.n_chosen)]
         chosen = self.choice(candidates)
         target.update_operation('__torch__.nni.retiarii.nn.pytorch.nn.ChosenInputs',
-            {'chosen': chosen})
+                                {'chosen': chosen})
diff --git a/nni/retiarii/nn/pytorch/nn.py b/nni/retiarii/nn/pytorch/nn.py
index 8eca487fee..1a0629787c 100644
--- a/nni/retiarii/nn/pytorch/nn.py
+++ b/nni/retiarii/nn/pytorch/nn.py
@@ -1,8 +1,9 @@
 import inspect
 import logging
+from typing import Any, List
+
 import torch
 import torch.nn as nn
-from typing import (Any, Tuple, List, Optional)
 
 from ...utils import add_record
@@ -10,7 +11,7 @@
 
 __all__ = [
     'LayerChoice', 'InputChoice', 'Placeholder',
-    'Module', 'Sequential', 'ModuleList', # TODO: 'ModuleDict', 'ParameterList', 'ParameterDict',
+    'Module', 'Sequential', 'ModuleList',  # TODO: 'ModuleDict', 'ParameterList', 'ParameterDict',
     'Identity', 'Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d',
     'ConvTranspose2d', 'ConvTranspose3d', 'Threshold', 'ReLU', 'Hardtanh', 'ReLU6',
     'Sigmoid', 'Tanh', 'Softmax', 'Softmax2d', 'LogSoftmax', 'ELU', 'SELU', 'CELU', 'GLU', 'GELU', 'Hardshrink',
@@ -30,7 +31,7 @@
     'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
     #'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d',
     #'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
-    #'Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss', 'ChannelShuffle', 
+    #'Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss', 'ChannelShuffle',
     'Flatten', 'Hardsigmoid', 'Hardswish'
 ]
@@ -57,9 +58,10 @@ def __init__(self, n_candidates=None, choose_from=None, n_chosen=1,
         if n_candidates or choose_from or return_mask:
             _logger.warning('input arguments `n_candidates`, `choose_from` and `return_mask` are deprecated!')
 
-    def forward(self, candidate_inputs: List['Tensor']) -> 'Tensor':
+    def forward(self, candidate_inputs: List[torch.Tensor]) -> torch.Tensor:
         # fake return
-        return torch.tensor(candidate_inputs)
+        return torch.tensor(candidate_inputs)  # pylint: disable=not-callable
+
 
 class ValueChoice:
     """
@@ -67,6 +69,7 @@ class ValueChoice:
     when instantiating a pytorch module.
     TODO: can also be used in training approach
     """
+
     def __init__(self, candidate_values: List[Any]):
         self.candidate_values = candidate_values
@@ -81,6 +84,7 @@ def __init__(self, label, related_info):
     def forward(self, x):
         return x
 
+
 class ChosenInputs(nn.Module):
     def __init__(self, chosen: int):
         super().__init__()
@@ -92,20 +96,24 @@ def forward(self, candidate_inputs):
 
 # the following are pytorch modules
 
+
 class Module(nn.Module):
     def __init__(self):
         super(Module, self).__init__()
 
+
 class Sequential(nn.Sequential):
     def __init__(self, *args):
         add_record(id(self), {})
         super(Sequential, self).__init__(*args)
 
+
 class ModuleList(nn.ModuleList):
     def __init__(self, *args):
         add_record(id(self), {})
         super(ModuleList, self).__init__(*args)
 
+
 def wrap_module(original_class):
     orig_init = original_class.__init__
     argname_list = list(inspect.signature(original_class).parameters.keys())
@@ -115,14 +123,15 @@ def __init__(self, *args, **kws):
         full_args = {}
         full_args.update(kws)
         for i, arg in enumerate(args):
-            full_args[argname_list[i]] = args[i]
+            full_args[argname_list[i]] = arg
         add_record(id(self), full_args)
 
-        orig_init(self, *args, **kws) # Call the original __init__
+        orig_init(self, *args, **kws)  # Call the original __init__
 
-    original_class.__init__ = __init__ # Set the class' __init__ to the new one
+    original_class.__init__ = __init__  # Set the class' __init__ to the new one
     return original_class
 
+
 # TODO: support different versions of pytorch
 Identity = wrap_module(nn.Identity)
 Linear = wrap_module(nn.Linear)
diff --git a/nni/retiarii/operation.py b/nni/retiarii/operation.py
index 1ad7e01602..cd8b6d3a6f 100644
--- a/nni/retiarii/operation.py
+++ b/nni/retiarii/operation.py
@@ -4,12 +4,14 @@
 
 __all__ = ['Operation', 'Cell']
 
+
 def _convert_name(name: str) -> str:
     """
     Convert the names using separator '.' to valid variable name in code
     """
     return name.replace('.', '__')
 
+
 class Operation:
     """
     Calculation logic of a graph node.
@@ -152,6 +154,7 @@ def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
         else:
             raise RuntimeError(f'unsupported operation type: {self.type} ? {self._to_class_name()}')
 
+
 class TensorFlowOperation(Operation):
     def _to_class_name(self) -> str:
         return 'K.layers.' + self.type
@@ -191,6 +194,7 @@ def forward(...):
     framework
         No real usage. Exists for compatibility with base class.
     """
+
     def __init__(self, cell_name: str, parameters: Dict[str, Any] = {}):
         self.type = '_cell'
         self.cell_name = cell_name
@@ -207,6 +211,7 @@ class _IOPseudoOperation(Operation):
     The benefit is that users no longer need to verify `Node.operation is not None`,
     especially in static type checking.
     """
+
     def __init__(self, type_name: str, io_names: List = None):
         assert type_name.startswith('_')
         super(_IOPseudoOperation, self).__init__(type_name, {}, True)
diff --git a/nni/retiarii/operation_def/tf_op_def.py b/nni/retiarii/operation_def/tf_op_def.py
index d891a580ba..f4664e122e 100644
--- a/nni/retiarii/operation_def/tf_op_def.py
+++ b/nni/retiarii/operation_def/tf_op_def.py
@@ -1,7 +1,8 @@
 from ..operation import TensorFlowOperation
 
+
 class Conv2D(TensorFlowOperation):
     def __init__(self, type_name, parameters, _internal):
         if 'padding' not in parameters:
             parameters['padding'] = 'same'
-        super().__init__(type_name, parameters, _internal)
\ No newline at end of file
+        super().__init__(type_name, parameters, _internal)
diff --git a/nni/retiarii/operation_def/torch_op_def.py b/nni/retiarii/operation_def/torch_op_def.py
index 92fef3d6e0..f691c11fe9 100644
--- a/nni/retiarii/operation_def/torch_op_def.py
+++ b/nni/retiarii/operation_def/torch_op_def.py
@@ -1,5 +1,6 @@
 from ..operation import PyTorchOperation
 
+
 class relu(PyTorchOperation):
     def to_init_code(self, field):
         return ''
@@ -17,6 +18,7 @@ def to_forward_code(self, field, output, *inputs) -> str:
         assert len(inputs) == 1
         return f'{output} = {inputs[0]}.view({inputs[0]}.size(0), -1)'
 
+
 class ToDevice(PyTorchOperation):
     def to_init_code(self, field):
         return ''
diff --git a/nni/retiarii/strategies/strategy.py b/nni/retiarii/strategies/strategy.py
index d5041c985a..d89dba673b 100644
--- a/nni/retiarii/strategies/strategy.py
+++ b/nni/retiarii/strategies/strategy.py
@@ -1,8 +1,12 @@
 import abc
 from typing import List
 
+from ..graph import Model
+from ..mutator import Mutator
+
+
 class BaseStrategy(abc.ABC):
 
     @abc.abstractmethod
-    def run(self, base_model: 'Model', applied_mutators: List['Mutator']) -> None:
+    def run(self, base_model: Model, applied_mutators: List[Mutator]) -> None:
         pass
diff --git a/nni/retiarii/strategies/tpe_strategy.py b/nni/retiarii/strategies/tpe_strategy.py
index 8a760eeca4..21fb814c84 100644
--- a/nni/retiarii/strategies/tpe_strategy.py
+++ b/nni/retiarii/strategies/tpe_strategy.py
@@ -1,16 +1,13 @@
-import json
 import logging
-import random
-import os
 
-from .. import Model, submit_models, wait_models
-from .. import Sampler
+from .. import Sampler, submit_models, wait_models
 from .strategy import BaseStrategy
 
 from ...algorithms.hpo.hyperopt_tuner.hyperopt_tuner import HyperoptTuner
 
 _logger = logging.getLogger(__name__)
 
+
 class TPESampler(Sampler):
     def __init__(self, optimize_mode='minimize'):
         self.tpe_tuner = HyperoptTuner('tpe', optimize_mode)
@@ -37,6 +34,7 @@ def choice(self, candidates, mutator, model, index):
         self.index += 1
         return chosen
 
+
 class TPEStrategy(BaseStrategy):
     def __init__(self):
         self.tpe_sampler = TPESampler()
@@ -55,7 +53,7 @@ def run(self, base_model, applied_mutators):
         while True:
             model = base_model
             _logger.info('apply mutators...')
-            _logger.info('mutators: {}'.format(applied_mutators))
+            _logger.info('mutators: %s', str(applied_mutators))
             self.tpe_sampler.generate_samples(self.model_id)
             for mutator in applied_mutators:
                 _logger.info('mutate model...')
@@ -66,6 +64,6 @@ def run(self, base_model, applied_mutators):
                 wait_models(model)
                 self.tpe_sampler.receive_result(self.model_id, model.metric)
                 self.model_id += 1
-                _logger.info('Strategy says:', model.metric)
-            except Exception as e:
+                _logger.info('Strategy says: %s', model.metric)
+            except Exception:
                 _logger.error(logging.exception('message'))
diff --git a/nni/retiarii/trainer/interface.py b/nni/retiarii/trainer/interface.py
index e91f618a7b..1f3c108e38 100644
--- a/nni/retiarii/trainer/interface.py
+++ b/nni/retiarii/trainer/interface.py
@@ -1,6 +1,5 @@
 import abc
-import inspect
-from typing import *
+from typing import Any
 
 
 class BaseTrainer(abc.ABC):
diff --git a/nni/retiarii/trainer/pytorch/base.py b/nni/retiarii/trainer/pytorch/base.py
index 3124a9f70c..6d2156c0b5 100644
--- a/nni/retiarii/trainer/pytorch/base.py
+++ b/nni/retiarii/trainer/pytorch/base.py
@@ -1,5 +1,4 @@
-import abc
-from typing import *
+from typing import Any, List, Dict, Tuple
 
 import numpy as np
 import torch
@@ -42,6 +41,7 @@ def get_default_transform(dataset: str) -> Any:
     # unsupported dataset, return None
     return None
 
+
 @register_trainer()
 class PyTorchImageClassificationTrainer(BaseTrainer):
     """
@@ -94,7 +94,7 @@ def __init__(self, model,
         self._dataloader = DataLoader(
             self._dataset, **(dataloader_kwargs or {}))
 
-    def _accuracy(self, input, target):
+    def _accuracy(self, input, target):  # pylint: disable=redefined-builtin
         _, predict = torch.max(input.data, 1)
         correct = predict.eq(target.data).cpu().sum().item()
         return correct / input.size(0)
@@ -176,7 +176,7 @@ def __init__(self, multi_model, kwargs=[]):
             dataloader = DataLoader(dataset, **(dataloader_kwargs or {}))
             self._datasets.append(dataset)
             self._dataloaders.append(dataloader)
-            
+
             if m['use_output']:
                 optimizer_cls = m['optimizer_cls']
                 optimizer_kwargs = m['optimizer_kwargs']
@@ -186,7 +186,7 @@ def __init__(self, multi_model, kwargs=[]):
                     name_prefix = '_'.join(name.split('_')[:2])
                     if m_header == name_prefix:
                         one_model_params.append(param)
-                
+
                 optimizer = getattr(torch.optim, optimizer_cls)(one_model_params, **(optimizer_kwargs or {}))
                 self._optimizers.append(optimizer)
@@ -206,7 +206,7 @@ def _train(self):
                 x, y = self.training_step_before_model(batch, batch_idx, f'cuda:{idx}')
                 xs.append(x)
                 ys.append(y)
-            
+
             y_hats = self.multi_model(*xs)
             if len(ys) != len(xs):
                 raise ValueError('len(ys) should be equal to len(xs)')
@@ -230,13 +230,12 @@ def _train(self):
             if self.max_steps and batch_idx >= self.max_steps:
                 return
 
-
     def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
         x, y = self.training_step_before_model(batch, batch_idx)
         y_hat = self.model(x)
         return self.training_step_after_model(x, y, y_hat)
 
-    def training_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, device = None):
+    def training_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, device=None):
         x, y = batch
         if device:
             x, y = x.cuda(torch.device(device)), y.cuda(torch.device(device))
@@ -259,4 +258,4 @@ def validation_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor],
 
     def validation_step_after_model(self, x, y, y_hat):
         acc = self._accuracy(y_hat, y)
-        return {'val_acc': acc}
\ No newline at end of file
+        return {'val_acc': acc}
diff --git a/nni/retiarii/trainer/pytorch/darts.py b/nni/retiarii/trainer/pytorch/darts.py
index 9ff76698dd..b96fe91c12 100644
--- a/nni/retiarii/trainer/pytorch/darts.py
+++ b/nni/retiarii/trainer/pytorch/darts.py
@@ -6,7 +6,6 @@
 
 import torch
 import torch.nn as nn
-from nni.nas.pytorch.mutables import LayerChoice
 
 from ..interface import BaseOneShotTrainer
 from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice
diff --git a/nni/retiarii/trainer/pytorch/enas.py b/nni/retiarii/trainer/pytorch/enas.py
index 20c593d0f6..7f03c1dd10 100644
--- a/nni/retiarii/trainer/pytorch/enas.py
+++ b/nni/retiarii/trainer/pytorch/enas.py
@@ -86,8 +86,8 @@ def __init__(self, fields, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5,
         self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
         self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
         self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
-        self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]),
-                                         requires_grad=False)  # pylint: disable=not-callable
+        self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]),  # pylint: disable=not-callable
+                                         requires_grad=False)
         assert entropy_reduction in ['sum', 'mean'], 'Entropy reduction must be one of sum and mean.'
         self.entropy_reduction = torch.sum if entropy_reduction == 'sum' else torch.mean
         self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')
diff --git a/nni/retiarii/trainer/pytorch/random.py b/nni/retiarii/trainer/pytorch/random.py
index a1242c9ab0..a82ddada10 100644
--- a/nni/retiarii/trainer/pytorch/random.py
+++ b/nni/retiarii/trainer/pytorch/random.py
@@ -16,7 +16,7 @@ def _get_mask(sampled, total):
     multihot = [i == sampled or (isinstance(sampled, list) and i in sampled) for i in range(total)]
-    return torch.tensor(multihot, dtype=torch.bool)
+    return torch.tensor(multihot, dtype=torch.bool)  # pylint: disable=not-callable
 
 
 class PathSamplingLayerChoice(nn.Module):
@@ -44,9 +44,9 @@ def __init__(self, layer_choice):
     def forward(self, *args, **kwargs):
         assert self.sampled is not None, 'At least one path needs to be sampled before fprop.'
         if isinstance(self.sampled, list):
-            return sum([getattr(self, self.op_names[i])(*args, **kwargs) for i in self.sampled])
+            return sum([getattr(self, self.op_names[i])(*args, **kwargs) for i in self.sampled])  # pylint: disable=not-an-iterable
         else:
-            return getattr(self, self.op_names[self.sampled])(*args, **kwargs)
+            return getattr(self, self.op_names[self.sampled])(*args, **kwargs)  # pylint: disable=invalid-sequence-index
 
     def __len__(self):
         return len(self.op_names)
@@ -76,7 +76,7 @@ def __init__(self, input_choice):
 
     def forward(self, input_tensors):
         if isinstance(self.sampled, list):
-            return sum([input_tensors[t] for t in self.sampled])
+            return sum([input_tensors[t] for t in self.sampled])  # pylint: disable=not-an-iterable
         else:
             return input_tensors[self.sampled]
diff --git a/nni/retiarii/trainer/pytorch/utils.py b/nni/retiarii/trainer/pytorch/utils.py
index c4340e9318..45e8b2f13c 100644
--- a/nni/retiarii/trainer/pytorch/utils.py
+++ b/nni/retiarii/trainer/pytorch/utils.py
@@ -123,13 +123,13 @@ def summary(self):
         return fmtstr.format(**self.__dict__)
 
 
-def _replace_module_with_type(root_module, init_fn, type, modules):
+def _replace_module_with_type(root_module, init_fn, type_name, modules):
     if modules is None:
         modules = []
 
     def apply(m):
         for name, child in m.named_children():
-            if isinstance(child, type):
+            if isinstance(child, type_name):
                 setattr(m, name, init_fn(child))
                 modules.append((child.key, getattr(m, name)))
             else:
diff --git a/nni/retiarii/utils.py b/nni/retiarii/utils.py
index c846921b3c..9488991aef 100644
--- a/nni/retiarii/utils.py
+++ b/nni/retiarii/utils.py
@@ -1,19 +1,24 @@
-from collections import defaultdict
 import inspect
+from collections import defaultdict
+from typing import Any
+
 
-def import_(target: str, allow_none: bool = False) -> 'Any':
+def import_(target: str, allow_none: bool = False) -> Any:
     if target is None:
         return None
     path, identifier = target.rsplit('.', 1)
     module = __import__(path, globals(), locals(), [identifier])
     return getattr(module, identifier)
 
+
 _records = {}
 
+
 def get_records():
     global _records
     return _records
 
+
 def add_record(key, value):
     """
     """
@@ -22,6 +27,7 @@ def add_record(key, value):
         assert key not in _records, '{} already in _records'.format(key)
         _records[key] = value
 
+
 def _register_module(original_class):
     orig_init = original_class.__init__
     argname_list = list(inspect.signature(original_class).parameters.keys())
@@ -31,14 +37,15 @@ def __init__(self, *args, **kws):
         full_args = {}
         full_args.update(kws)
         for i, arg in enumerate(args):
-            full_args[argname_list[i]] = args[i]
+            full_args[argname_list[i]] = arg
         add_record(id(self), full_args)
 
-        orig_init(self, *args, **kws) # Call the original __init__
+        orig_init(self, *args, **kws)  # Call the original __init__
 
-    original_class.__init__ = __init__ # Set the class' __init__ to the new one
+    original_class.__init__ = __init__  # Set the class' __init__ to the new one
     return original_class
 
+
 def register_module():
     """
     Register a module.
@@ -68,14 +75,15 @@ def __init__(self, *args, **kws):
             if isinstance(args[i], Module):
                 # ignore the base model object
                 continue
-            full_args[argname_list[i]] = args[i]
+            full_args[argname_list[i]] = arg
         add_record(id(self), {'modulename': full_class_name, 'args': full_args})
 
-        orig_init(self, *args, **kws) # Call the original __init__
+        orig_init(self, *args, **kws)  # Call the original __init__
 
-    original_class.__init__ = __init__ # Set the class' __init__ to the new one
+    original_class.__init__ = __init__  # Set the class' __init__ to the new one
     return original_class
 
+
 def register_trainer():
     def _register(cls):
         m = _register_trainer(
@@ -84,8 +92,10 @@ def _register(cls):
 
     return _register
 
+
 _last_uid = defaultdict(int)
 
+
 def uid(namespace: str = 'default') -> int:
     _last_uid[namespace] += 1
     return _last_uid[namespace]
diff --git a/pipelines/fast-test.yml b/pipelines/fast-test.yml
index a5808cfa58..586a4d5547 100644
--- a/pipelines/fast-test.yml
+++ b/pipelines/fast-test.yml
@@ -41,7 +41,7 @@ jobs:
       python3 -m pip install --upgrade pygments
       python3 -m pip install --upgrade torch>=1.7.0+cpu torchvision>=0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
       python3 -m pip install --upgrade tensorflow
-      python3 -m pip install --upgrade gym onnx peewee thop
+      python3 -m pip install --upgrade gym onnx peewee thop graphviz
      python3 -m pip install sphinx==1.8.3 sphinx-argparse==0.2.5 sphinx-markdown-tables==0.0.9 sphinx-rtd-theme==0.4.2 sphinxcontrib-websupport==1.1.0 recommonmark==0.5.0 nbsphinx
      sudo apt-get install swig -y
      python3 -m pip install -e .[SMAC,BOHB]
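
Note on the recurring logging fix: most log-related hunks in this patch replace eager string construction ('...{}'.format(x) or f-strings) inside logger calls with lazy %-style arguments, which satisfies pylint's logging-format checks and defers formatting until a record is actually emitted. It also repairs a genuine bug in tpe_strategy.py, where `_logger.info('Strategy says:', model.metric)` supplied an argument without a matching placeholder, so the metric was never rendered. A minimal, standalone sketch of the pattern (not taken from the NNI codebase):

    import logging

    _logger = logging.getLogger(__name__)
    metric = 0.93  # illustrative value

    # Eager: the message string is built even when INFO is disabled,
    # and pylint flags the call.
    _logger.info('Strategy says: {}'.format(metric))

    # Lazy: the argument is interpolated into the format string only
    # if the record actually passes the level and handler checks.
    _logger.info('Strategy says: %s', metric)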