Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
fix uid duplicate and add type hint alias for edge endpoint (#3188)
Browse files Browse the repository at this point in the history
  • Loading branch information
liuzhe-lz authored Dec 14, 2020
1 parent 192a807 commit b3cdee8
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 26 deletions.
41 changes: 17 additions & 24 deletions nni/retiarii/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,19 @@
from typing import (Any, Dict, List, Optional, Tuple, Union, overload)

from .operation import Cell, Operation, _IOPseudoOperation
from .utils import uid

__all__ = ['Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'IllegalGraphError', 'MetricData']


MetricData = Any
"""
Graph metrics like loss, accuracy, etc.
Type hint for graph metrics (loss, accuracy, etc).
"""

# Maybe we can assume this is a single float number for first iteration.
EdgeEndpoint = Tuple['Node', Optional[int]]
"""
Type hint for edge's endpoint. The int indicates nodes' order.
"""


Expand Down Expand Up @@ -88,12 +92,10 @@ class Model:
intermediate_metrics
Intermediate training metrics. If the model is not trained, it's an empty list.
"""
_cur_model_id = 0

def __init__(self, _internal=False):
assert _internal, '`Model()` is private, use `model.fork()` instead'
Model._cur_model_id += 1
self.model_id = Model._cur_model_id
self.model_id: int = uid('model')

self.status: ModelStatus = ModelStatus.Mutating

Expand All @@ -106,8 +108,6 @@ def __init__(self, _internal=False):
self.metric: Optional[MetricData] = None
self.intermediate_metrics: List[MetricData] = []

self._last_uid: int = 0 # FIXME: this should be global, not model-wise

def __repr__(self):
return f'Model(model_id={self.model_id}, status={self.status}, graphs={list(self.graphs.keys())}, ' + \
f'training_config={self.training_config}, metric={self.metric}, intermediate_metrics={self.intermediate_metrics})'
Expand All @@ -130,13 +130,8 @@ def fork(self) -> 'Model':
new_model.graphs = {name: graph._fork_to(new_model) for name, graph in self.graphs.items()}
new_model.training_config = copy.deepcopy(self.training_config)
new_model.history = self.history + [self]
new_model._last_uid = self._last_uid
return new_model

def _uid(self) -> int:
self._last_uid += 1
return self._last_uid

@staticmethod
def _load(ir: Any) -> 'Model':
model = Model(_internal=True)
Expand Down Expand Up @@ -295,7 +290,7 @@ def add_node(self, name, operation_or_type, parameters={}):
op = operation_or_type
else:
op = Operation.new(operation_or_type, parameters, name)
return Node(self, self.model._uid(), name, op, _internal=True)._register()
return Node(self, uid(), name, op, _internal=True)._register()

@overload
def insert_node_on_edge(self, edge: 'Edge', name: str, operation: Operation) -> 'Node': ...
Expand All @@ -307,15 +302,15 @@ def insert_node_on_edge(self, edge, name, operation_or_type, parameters={}) -> '
op = operation_or_type
else:
op = Operation.new(operation_or_type, parameters, name)
new_node = Node(self, self.model._uid(), name, op, _internal=True)._register()
new_node = Node(self, uid(), name, op, _internal=True)._register()
# update edges
self.add_edge((edge.head, edge.head_slot), (new_node, None))
self.add_edge((new_node, None), (edge.tail, edge.tail_slot))
self.del_edge(edge)
return new_node

# mutation
def add_edge(self, head: Tuple['Node', Optional[int]], tail: Tuple['Node', Optional[int]]) -> 'Edge':
def add_edge(self, head: EdgeEndpoint, tail: EdgeEndpoint) -> 'Edge':
assert head[0].graph is self and tail[0].graph is self
return Edge(head, tail, _internal=True)._register()

Expand Down Expand Up @@ -414,7 +409,7 @@ def _fork_to(self, model: Model, name_prefix='') -> 'Graph':
def _copy(self) -> 'Graph':
# Copy this graph inside the model.
# The new graph will have identical topology, but its nodes' name and ID will be different.
new_graph = Graph(self.model, self.model._uid(), _internal=True)._register()
new_graph = Graph(self.model, uid(), _internal=True)._register()
new_graph.input_node.operation.io_names = self.input_node.operation.io_names
new_graph.output_node.operation.io_names = self.output_node.operation.io_names
new_graph.input_node.update_label(self.input_node.label)
Expand All @@ -423,7 +418,7 @@ def _copy(self) -> 'Graph':
id_to_new_node = {} # old node ID -> new node object

for old_node in self.hidden_nodes:
new_node = Node(new_graph, self.model._uid(), None, old_node.operation, _internal=True)._register()
new_node = Node(new_graph, uid(), None, old_node.operation, _internal=True)._register()
new_node.update_label(old_node.label)
id_to_new_node[old_node.id] = new_node

Expand All @@ -440,7 +435,7 @@ def _register(self) -> 'Graph':

@staticmethod
def _load(model: Model, name: str, ir: Any) -> 'Graph':
graph = Graph(model, model._uid(), name, _internal=True)
graph = Graph(model, uid(), name, _internal=True)
graph.input_node.operation.io_names = ir.get('inputs')
graph.output_node.operation.io_names = ir.get('outputs')
for node_name, node_data in ir['nodes'].items():
Expand Down Expand Up @@ -501,6 +496,8 @@ def __init__(self, graph, node_id, name, operation, _internal=False):
self.graph: Graph = graph
self.id: int = node_id
self.name: str = name or f'_generated_{node_id}'
# TODO: the operation is likely to be considered editable by end-user and it will be hard to debug
# maybe we should copy it here or make Operation class immutable, in next release
self.operation: Operation = operation
self.label: str = None

Expand Down Expand Up @@ -577,7 +574,7 @@ def _load(graph: Graph, name: str, ir: Any) -> 'Node':
op = Cell(ir['operation']['cell_name'], ir['operation'].get('parameters', {}))
else:
op = Operation.new(ir['operation']['type'], ir['operation'].get('parameters', {}))
node = Node(graph, graph.model._uid(), name, op)
node = Node(graph, uid(), name, op)
if 'label' in ir:
node.update_label(ir['label'])
return node
Expand Down Expand Up @@ -626,11 +623,7 @@ class Edge:
If the node does not care about order, this can be `-1`.
"""

def __init__(
self,
head: Tuple[Node, Optional[int]],
tail: Tuple[Node, Optional[int]],
_internal: bool = False):
def __init__(self, head: EdgeEndpoint, tail: EdgeEndpoint, _internal: bool = False):
assert _internal, '`Edge()` is private'
self.graph: Graph = head[0].graph
self.head: Node = head[0]
Expand Down
10 changes: 8 additions & 2 deletions nni/retiarii/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import defaultdict
import inspect

def import_(target: str, allow_none: bool = False) -> 'Any':
Expand All @@ -7,7 +8,6 @@ def import_(target: str, allow_none: bool = False) -> 'Any':
module = __import__(path, globals(), locals(), [identifier])
return getattr(module, identifier)


_records = {}

def get_records():
Expand Down Expand Up @@ -82,4 +82,10 @@ def _register(cls):
original_class=cls)
return m

return _register
return _register

_last_uid = defaultdict(int)

def uid(namespace: str = 'default') -> int:
_last_uid[namespace] += 1
return _last_uid[namespace]

0 comments on commit b3cdee8

Please sign in to comment.