This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[Retiarii] Change op params from kwargs to dict, and fix ut/lint #3061

Merged (4 commits) on Nov 4, 2020
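
For context, a minimal before/after sketch of the call sites affected by this change, where operation parameters move from `**kwargs` to a single dict on `Graph.add_node` and `Node.update_operation` (the 'Conv2d' type name and its parameters are illustrative, not taken from this PR):

from nni.retiarii.graph import Graph, Node

def add_conv(graph: Graph) -> Node:
    # Before this PR, parameters were keyword arguments:
    #     graph.add_node('Conv2d', kernel_size=3, out_channels=64)
    # After this PR, parameters travel as one dict; an Operation object
    # can still be passed directly instead of a type name.
    node = graph.add_node('Conv2d', {'kernel_size': 3, 'out_channels': 64})
    # Node.update_operation follows the same dict-based convention.
    node.update_operation('Conv2d', {'kernel_size': 5, 'out_channels': 64})
    return node
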
123 changes: 66 additions & 57 deletions nni/retiarii/graph.py
@@ -1,24 +1,22 @@
"""
Classes related to Graph IR, except `Operation`.
Model representation.
"""

from __future__ import annotations
import copy
import json
from enum import Enum
from typing import *
import json
from typing import (Any, Dict, List, Optional, Tuple, overload)

from .operation import Cell, Operation, _PseudoOperation


__all__ = ['Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'IllegalGraphError', 'MetricData']


MetricData = NewType('MetricData', Any)
MetricData = Any
"""
Graph metrics like loss, accuracy, etc.

Maybe we can assume this is a single float number for first iteration.
# Maybe we can assume this is a single float number for first iteration.
"""


@@ -36,15 +34,15 @@ class TrainingConfig:
Trainer keyword arguments
"""

def __init__(self, module: str, kwargs: Dict[str, any]):
def __init__(self, module: str, kwargs: Dict[str, Any]):
self.module = module
self.kwargs = kwargs

def __repr__(self):
return f'TrainingConfig(module={self.module}, kwargs={self.kwargs})'

@staticmethod
def _load(ir: Any) -> TrainingConfig:
def _load(ir: Any) -> 'TrainingConfig':
return TrainingConfig(ir['module'], ir.get('kwargs', {}))

def _dump(self) -> Any:
@@ -56,15 +54,14 @@ def _dump(self) -> Any:

class Model:
"""
Top-level structure of graph IR.

In execution engine's perspective, this is a trainable neural network model.
In mutator's perspective, this is a sandbox for a round of mutation.
Represents a neural network model.

Once a round of mutation starts, a sandbox is created and all mutating operations will happen inside.
When mutation is complete, the sandbox will be frozen to a trainable model.
Then the strategy will submit model to execution engine for training.
The model will record its metrics once trained.
During mutation, one `Model` object is created for each trainable snapshot.
For example, consider a mutator that inserts a node at an edge in each iteration.
In one iteration, the mutator invokes 4 primitives: add node, remove edge, add edge to head, add edge to tail.
These 4 primitives operate on one `Model` object.
When they are all done, the model will be set to "frozen" (trainable) status and submitted to the execution engine.
Then a new iteration starts, and a new `Model` object is created by forking the last model.
Contributor comment: not necessarily "last model"


Attributes
----------
@@ -104,17 +101,17 @@ def __init__(self, _internal=False):
self.metric: Optional[MetricData] = None
self.intermediate_metrics: List[MetricData] = []

self._last_uid: int = 0
self._last_uid: int = 0 # FIXME: this should be global, not model-wise

def __repr__(self):
return f'Model(model_id={self.model_id}, status={self.status}, graphs={list(self.graphs.keys())}, ' + \
f'training_config={self.training_config}, metric={self.metric}, intermediate_metrics={self.intermediate_metrics})'

@property
def root_graph(self) -> Graph:
def root_graph(self) -> 'Graph':
return self.graphs[self._root_graph_name]

def fork(self) -> Model:
def fork(self) -> 'Model':
"""
Create a new model which has same topology, names, and IDs to current one.

@@ -136,17 +133,17 @@ def _uid(self) -> int:
return self._last_uid

@staticmethod
def _load(ir: Any) -> Model:
def _load(ir: Any) -> 'Model':
model = Model(_internal=True)
for graph_name, graph_data in ir.items():
if graph_name != '_training_config':
Graph._load(model, graph_name, graph_data)._register()
#model.training_config = TrainingConfig._load(ir['_training_config'])
model.training_config = TrainingConfig._load(ir['_training_config'])
return model

def _dump(self) -> Any:
ret = {name: graph._dump() for name, graph in self.graphs.items()}
#ret['_training_config'] = self.training_config._dump()
ret['_training_config'] = self.training_config._dump()
return ret


@@ -227,41 +224,45 @@ def __repr__(self):
f'output_names={self.output_names}, num_hidden_nodes={len(self.hidden_nodes)}, num_edges={len(self.edges)})'

@property
def nodes(self) -> List[Node]:
def nodes(self) -> List['Node']:
return [self.input_node, self.output_node] + self.hidden_nodes

# mutation

def add_node(self, type: Union[Operation, str], **parameters) -> Node:
if isinstance(type, Operation):
assert not parameters
op = type
@overload
def add_node(self, operation: Operation) -> 'Node': ...
@overload
def add_node(self, type_name: str, parameters: Dict[str, Any] = {}) -> 'Node': ...

def add_node(self, operation_or_type, parameters={}):
if isinstance(operation_or_type, Operation):
op = operation_or_type
else:
op = Operation.new(type, **parameters)
op = Operation.new(operation_or_type, parameters)
return Node(self, self.model._uid(), None, op, _internal=True)._register()

# mutation
def add_edge(self, head: Tuple[Node, Optional[int]], tail: Tuple[Node, Optional[int]]) -> Edge:
def add_edge(self, head: Tuple['Node', Optional[int]], tail: Tuple['Node', Optional[int]]) -> 'Edge':
assert head[0].graph is self and tail[0].graph is self
return Edge(head, tail)._register()

def get_node_by_name(self, name: str) -> Optional[Node]:
def get_node_by_name(self, name: str) -> Optional['Node']:
"""
Returns the node which has specified name; or returns `None` if no node has this name.
"""
found = [node for node in self.nodes if node.name == name]
return found[0] if found else None

def get_nodes_by_type(self, operation_type: str) -> List[Node]:
def get_nodes_by_type(self, operation_type: str) -> List['Node']:
"""
Returns nodes whose operation is specified typed.
"""
return [node for node in self.hidden_nodes if node.operation.type == operation_type]

def topo_sort(self) -> List[Node]: # TODO
def topo_sort(self) -> List['Node']: # TODO
...

def fork(self) -> Graph:
def fork(self) -> 'Graph':
"""
Fork the model and returns corresponding graph in new model.
This shortcut might be helpful because many algorithms only cares about "stem" subgraph instead of whole model.
@@ -271,7 +272,7 @@ def fork(self) -> Graph:
def __eq__(self, other: object) -> bool:
return self is other

def _fork_to(self, model: Model) -> Graph:
def _fork_to(self, model: Model) -> 'Graph':
new_graph = Graph(model, self.id, self.name, _internal=True)._register()
new_graph.input_names = self.input_names
new_graph.output_names = self.output_names
@@ -288,7 +289,7 @@ def _fork_to(self, model: Model) -> Graph:

return new_graph

def _copy(self) -> Graph:
def _copy(self) -> 'Graph':
# Copy this graph inside the model.
# The new graph will have identical topology, but its nodes' name and ID will be different.
new_graph = Graph(self.model, self.model._uid(), _internal=True)._register()
@@ -308,12 +309,12 @@ def _copy(self) -> Graph:

return new_graph

def _register(self) -> Graph:
def _register(self) -> 'Graph':
self.model.graphs[self.name] = self
return self

@staticmethod
def _load(model: Model, name: str, ir: Any) -> Graph:
def _load(model: Model, name: str, ir: Any) -> 'Graph':
graph = Graph(model, model._uid(), name, _internal=True)
graph.input_names = ir.get('inputs')
graph.output_names = ir.get('outputs')
@@ -381,19 +382,19 @@ def __repr__(self):
return f'Node(id={self.id}, name={self.name}, operation={self.operation})'

@property
def predecessors(self) -> List[Node]:
def predecessors(self) -> List['Node']:
return sorted(set(edge.head for edge in self.incoming_edges), key=(lambda node: node.id))

@property
def successors(self) -> List[Node]:
def successors(self) -> List['Node']:
return sorted(set(edge.tail for edge in self.outgoing_edges), key=(lambda node: node.id))

@property
def incoming_edges(self) -> List[Edge]:
def incoming_edges(self) -> List['Edge']:
return [edge for edge in self.graph.edges if edge.tail is self]

@property
def outgoing_edges(self) -> List[Edge]:
def outgoing_edges(self) -> List['Edge']:
return [edge for edge in self.graph.edges if edge.head is self]

@property
@@ -403,12 +404,16 @@ def cell(self) -> Graph:

# mutation

def update_operation(self, type: Union[Operation, str], **parameters) -> None:
if isinstance(type, Operation):
assert not parameters
self.operation = type
@overload
def update_operation(self, operation: Operation) -> None: ...
@overload
def update_operation(self, type_name: str, parameters: Dict[str, Any] = {}) -> None: ...

def update_operation(self, operation_or_type, parameters={}):
if isinstance(operation_or_type, Operation):
self.operation = operation_or_type
else:
self.operation = Operation.new(type, **parameters)
self.operation = Operation.new(operation_or_type, parameters)

# mutation
def remove(self) -> None:
@@ -422,26 +427,29 @@ def specialize_cell(self) -> Graph:
Duplicate the cell template and let this node reference to newly created copy.
"""
new_cell = self.cell._copy()._register()
self.operation = Operation.new('_cell', cell=new_cell.name)
self.operation = Cell(new_cell.name)
return new_cell

def __eq__(self, other: object) -> bool:
return self is other

def _register(self) -> Node:
def _register(self) -> 'Node':
self.graph.hidden_nodes.append(self)
return self

@staticmethod
def _load(graph: Graph, name: str, ir: Any) -> Node:
ir = dict(ir)
if 'type' not in ir and 'cell' in ir:
ir['type'] = '_cell'
op = Operation.new(**ir)
def _load(graph: Graph, name: str, ir: Any) -> 'Node':
if ir['type'] == '_cell':
op = Cell(ir['cell'], ir.get('parameters', {}))
else:
op = Operation.new(ir['type'], ir.get('parameters', {}))
return Node(graph, graph.model._uid(), name, op)

def _dump(self) -> Any:
return {'type': self.operation.type, **self.operation.parameters}
ret = {'type': self.operation.type, 'parameters': self.operation.parameters}
if isinstance(self.operation, Cell):
ret['cell'] = self.operation.cell_name
return ret


class Edge:
@@ -499,14 +507,15 @@ def __repr__(self):
def remove(self) -> None:
self.graph.edges.remove(self)

def _register(self) -> Edge:
def _register(self) -> 'Edge':
self.graph.edges.append(self)
return self

@staticmethod
def _load(graph: Graph, ir: Any) -> Edge:
def _load(graph: Graph, ir: Any) -> 'Edge':
head = graph.get_node_by_name(ir['head'][0])
tail = graph.get_node_by_name(ir['tail'][0])
assert head is not None and tail is not None
return Edge((head, ir['head'][1]), (tail, ir['tail'][1]), _internal=True)

def _dump(self) -> Any:
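The `Model` docstring above describes one mutation round as a fork followed by a handful of graph primitives on the forked copy. A rough sketch of such a round against this API, assuming the graph already contains nodes named 'stem' and 'head' (both names and the 'ReLU' operation are hypothetical; freezing and submission to the execution engine happen outside this snippet):

from nni.retiarii.graph import Model

def insert_relu(model: Model) -> Model:
    new_model = model.fork()            # sandbox for this mutation round
    graph = new_model.root_graph
    head = graph.get_node_by_name('stem')
    tail = graph.get_node_by_name('head')
    assert head is not None and tail is not None
    # Remove the existing stem -> head edge, if present.
    for edge in head.outgoing_edges:
        if edge.tail is tail:
            edge.remove()
    # Add a node with the new dict-based parameter passing (empty dict here),
    # then rewire: stem -> relu -> head.
    relu = graph.add_node('ReLU', {})
    graph.add_edge((head, None), (relu, None))
    graph.add_edge((relu, None), (tail, None))
    return new_model
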
19 changes: 11 additions & 8 deletions nni/retiarii/mutator.py
@@ -1,25 +1,25 @@
from __future__ import annotations
from typing import *
from .graph import *
from typing import (Any, Iterable, List, Optional)

from .graph import Model


__all__ = ['Sampler', 'Mutator']


Choice = NewType('Choice', Any)
Choice = Any


class Sampler:
"""
Handles `Mutator.choice()` calls.
"""
def choice(self, candidates: List[Choice], mutator: Mutator, model: Model, index: int) -> Choice:
def choice(self, candidates: List[Choice], mutator: 'Mutator', model: Model, index: int) -> Choice:
raise NotImplementedError()

def mutation_start(self, mutator: Mutator, model: Model) -> None:
def mutation_start(self, mutator: 'Mutator', model: Model) -> None:
pass

def mutation_end(self, mutator: Mutator, model: Model) -> None:
def mutation_end(self, mutator: 'Mutator', model: Model) -> None:
pass


@@ -44,11 +44,12 @@ def __init__(self, sampler: Optional[Sampler] = None):
self._cur_model: Optional[Model] = None
self._cur_choice_idx: Optional[int] = None

def bind_sampler(self, sampler: Sampler) -> Mutator:
def bind_sampler(self, sampler: Sampler) -> 'Mutator':
"""
Set the sampler which will handle `Mutator.choice` calls.
"""
self.sampler = sampler
return self

def apply(self, model: Model) -> Model:
"""
@@ -57,6 +58,7 @@ def apply(self, model: Model) -> Model:

The model will be copied before mutation and the original model will not be modified.
"""
assert self.sampler is not None
copy = model.fork()
self._cur_model = copy
self._cur_choice_idx = 0
@@ -93,6 +95,7 @@ def choice(self, candidates: Iterable[Choice]) -> Choice:
"""
Ask sampler to make a choice.
"""
assert self.sampler is not None and self._cur_model is not None and self._cur_choice_idx is not None
ret = self.sampler.choice(list(candidates), self, self._cur_model, self._cur_choice_idx)
self._cur_choice_idx += 1
return ret
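To tie the two files together, here is a rough sketch of a sampler and a mutator built on this API. It assumes `Mutator` subclasses override a `mutate(model)` hook, which is not visible in the excerpt above, and the 'Conv2d' type and 'kernel_size' parameter are again illustrative:

import random

from nni.retiarii.graph import Model
from nni.retiarii.mutator import Mutator, Sampler

class RandomSampler(Sampler):
    # Answers every Mutator.choice() call by picking a candidate at random.
    def choice(self, candidates, mutator, model, index):
        return random.choice(candidates)

class KernelSizeMutator(Mutator):
    # Hypothetical mutator: re-parameterizes every 'Conv2d' node with a
    # sampled kernel size, using the dict-based update_operation from this PR.
    def mutate(self, model: Model) -> None:
        for node in model.root_graph.get_nodes_by_type('Conv2d'):
            params = dict(node.operation.parameters)   # keep existing parameters
            params['kernel_size'] = self.choice([3, 5, 7])
            node.update_operation('Conv2d', params)

# Usage sketch: bind_sampler() now returns self (added in this PR), so the
# calls chain; apply() forks `base_model` and mutates only the copy.
#     mutated = KernelSizeMutator().bind_sampler(RandomSampler()).apply(base_model)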