diff --git a/docs/en_US/NAS/retiarii/ApiReference.rst b/docs/en_US/NAS/retiarii/ApiReference.rst
index 43a86d8fa8..9d8cd03059 100644
--- a/docs/en_US/NAS/retiarii/ApiReference.rst
+++ b/docs/en_US/NAS/retiarii/ApiReference.rst
@@ -18,6 +18,12 @@ Inline Mutation APIs
 .. autoclass:: nni.retiarii.nn.pytorch.ChosenInputs
    :members:

+.. autoclass:: nni.retiarii.nn.pytorch.Repeat
+   :members:
+
+.. autoclass:: nni.retiarii.nn.pytorch.Cell
+   :members:
+
 Graph Mutation APIs
 -------------------

diff --git a/nni/retiarii/converter/graph_gen.py b/nni/retiarii/converter/graph_gen.py
index 373cd9b69a..f8b06b887a 100644
--- a/nni/retiarii/converter/graph_gen.py
+++ b/nni/retiarii/converter/graph_gen.py
@@ -642,6 +642,16 @@ def convert_module(self, script_module, module, module_name, ir_model):

         ir_graph._register()

+        # add mutation signal for special modules
+        if original_type_name == OpTypeName.Repeat:
+            attrs = {
+                'mutation': 'repeat',
+                'label': module.label,
+                'min_depth': module.min_depth,
+                'max_depth': module.max_depth
+            }
+            return ir_graph, attrs
+
         return ir_graph, {}

diff --git a/nni/retiarii/converter/op_types.py b/nni/retiarii/converter/op_types.py
index 0d59d9ea08..1a4ba5a42d 100644
--- a/nni/retiarii/converter/op_types.py
+++ b/nni/retiarii/converter/op_types.py
@@ -17,3 +17,5 @@ class OpTypeName(str, Enum):
     ValueChoice = 'ValueChoice'
     Placeholder = 'Placeholder'
     MergedSlice = 'MergedSlice'
+    Repeat = 'Repeat'
+    Cell = 'Cell'
diff --git a/nni/retiarii/nn/pytorch/__init__.py b/nni/retiarii/nn/pytorch/__init__.py
index dffb882777..5c392164b1 100644
--- a/nni/retiarii/nn/pytorch/__init__.py
+++ b/nni/retiarii/nn/pytorch/__init__.py
@@ -1,2 +1,3 @@
 from .api import *
+from .component import *
 from .nn import *
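A note on the ``convert_module`` change above: converted ``Repeat`` cells are now tagged with their mutation range. For a hypothetical ``nn.Repeat(block, (3, 5), label='rep')``, the returned attrs would look like this (values illustrative):

    attrs = {
        'mutation': 'repeat',
        'label': 'rep',
        'min_depth': 3,
        'max_depth': 5,
    }

``process_inline_mutation`` in ``mutator.py`` (further down in this diff) matches on ``'mutation': 'repeat'`` to construct a ``RepeatMutator`` for these cells.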
""" - def __new__(cls, candidates: Union[Dict[str, nn.Module], List[nn.Module]], label: str = None, **kwargs): + def __new__(cls, candidates: Union[Dict[str, nn.Module], List[nn.Module]], label: Optional[str] = None, **kwargs): try: - chosen = _get_fixed_value(label) + chosen = get_fixed_value(label) if isinstance(candidates, list): return candidates[int(chosen)] else: @@ -79,7 +65,7 @@ def __new__(cls, candidates: Union[Dict[str, nn.Module], List[nn.Module]], label except AssertionError: return super().__new__(cls) - def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], label: str = None, **kwargs): + def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], label: Optional[str] = None, **kwargs): super(LayerChoice, self).__init__() if 'key' in kwargs: warnings.warn(f'"key" is deprecated. Assuming label.') @@ -89,7 +75,7 @@ def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], lab if 'reduction' in kwargs: warnings.warn(f'"reduction" is deprecated. Ignoring...') self.candidates = candidates - self._label = _generate_new_label(label) + self._label = generate_new_label(label) self.names = [] if isinstance(candidates, OrderedDict): @@ -187,13 +173,13 @@ class InputChoice(nn.Module): Identifier of the input choice. """ - def __new__(cls, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', label: str = None, **kwargs): + def __new__(cls, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', label: Optional[str] = None, **kwargs): try: - return ChosenInputs(_get_fixed_value(label), reduction=reduction) + return ChosenInputs(get_fixed_value(label), reduction=reduction) except AssertionError: return super().__new__(cls) - def __init__(self, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', label: str = None, **kwargs): + def __init__(self, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', label: Optional[str] = None, **kwargs): super(InputChoice, self).__init__() if 'key' in kwargs: warnings.warn(f'"key" is deprecated. Assuming label.') @@ -206,7 +192,7 @@ def __init__(self, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', self.n_chosen = n_chosen self.reduction = reduction assert self.reduction in ['mean', 'concat', 'sum', 'none'] - self._label = _generate_new_label(label) + self._label = generate_new_label(label) @property def key(self): @@ -295,16 +281,16 @@ def forward(self, x): Identifier of the value choice. 
""" - def __new__(cls, candidates: List[Any], label: str = None): + def __new__(cls, candidates: List[Any], label: Optional[str] = None): try: - return _get_fixed_value(label) + return get_fixed_value(label) except AssertionError: return super().__new__(cls) - def __init__(self, candidates: List[Any], label: str = None): + def __init__(self, candidates: List[Any], label: Optional[str] = None): super().__init__() self.candidates = candidates - self._label = _generate_new_label(label) + self._label = generate_new_label(label) self._accessor = [] @property diff --git a/nni/retiarii/nn/pytorch/component.py b/nni/retiarii/nn/pytorch/component.py new file mode 100644 index 0000000000..4ae5dc03bb --- /dev/null +++ b/nni/retiarii/nn/pytorch/component.py @@ -0,0 +1,147 @@ +import copy +from typing import Callable, List, Union, Tuple, Optional + +import torch +import torch.nn as nn + +from .api import LayerChoice, InputChoice +from .nn import ModuleList + +from .utils import generate_new_label, get_fixed_value + + +__all__ = ['Repeat', 'Cell'] + + +class Repeat(nn.Module): + """ + Repeat a block by a variable number of times. + + Parameters + ---------- + blocks : function, list of function, module or list of module + The block to be repeated. If not a list, it will be replicated into a list. + If a list, it should be of length ``max_depth``, the modules will be instantiated in order and a prefix will be taken. + If a function, it will be called to instantiate a module. Otherwise the module will be deep-copied. + depth : int or tuple of int + If one number, the block will be repeated by a fixed number of times. If a tuple, it should be (min, max), + meaning that the block will be repeated at least `min` times and at most `max` times. + """ + + def __new__(cls, blocks: Union[Callable[[], nn.Module], List[Callable[[], nn.Module]], nn.Module, List[nn.Module]], + depth: Union[int, Tuple[int, int]], label: Optional[str] = None): + try: + repeat = get_fixed_value(label) + return nn.Sequential(*cls._replicate_and_instantiate(blocks, repeat)) + except AssertionError: + return super().__new__(cls) + + def __init__(self, + blocks: Union[Callable[[], nn.Module], List[Callable[[], nn.Module]], nn.Module, List[nn.Module]], + depth: Union[int, Tuple[int, int]], label: Optional[str] = None): + super().__init__() + self._label = generate_new_label(label) + self.min_depth = depth if isinstance(depth, int) else depth[0] + self.max_depth = depth if isinstance(depth, int) else depth[1] + assert self.max_depth >= self.min_depth > 0 + self.blocks = nn.ModuleList(self._replicate_and_instantiate(blocks, self.max_depth)) + + @property + def label(self): + return self._label + + def forward(self, x): + for block in self.blocks: + x = block(x) + return x + + @staticmethod + def _replicate_and_instantiate(blocks, repeat): + if not isinstance(blocks, list): + if isinstance(blocks, nn.Module): + blocks = [blocks] + [copy.deepcopy(blocks) for _ in range(repeat - 1)] + else: + blocks = [blocks for _ in range(repeat)] + assert len(blocks) > 0 + assert repeat <= len(blocks), f'Not enough blocks to be used. {repeat} expected, only found {len(blocks)}.' + blocks = blocks[:repeat] + if not isinstance(blocks[0], nn.Module): + blocks = [b() for b in blocks] + return blocks + + +class Cell(nn.Module): + """ + Cell structure [1]_ [2]_ that is popularly used in NAS literature. + + A cell consists of multiple "nodes". Each node is a sum of multiple operators. 
diff --git a/nni/retiarii/nn/pytorch/mutator.py b/nni/retiarii/nn/pytorch/mutator.py
index 8b2f790a69..6ef2ef19af 100644
--- a/nni/retiarii/nn/pytorch/mutator.py
+++ b/nni/retiarii/nn/pytorch/mutator.py
@@ -8,8 +8,9 @@
 from ...mutator import Mutator
 from ...graph import Cell, Graph, Model, ModelStatus, Node
-from ...utils import uid
 from .api import LayerChoice, InputChoice, ValueChoice, Placeholder
+from .component import Repeat
+from ...utils import uid


 class LayerChoiceMutator(Mutator):
@@ -80,6 +81,42 @@ def mutate(self, model):
             target.update_operation(target.operation.type, {**target.operation.parameters, argname: chosen_value})


+class RepeatMutator(Mutator):
+    def __init__(self, nodes: List[Node]):
+        # each node in `nodes` refers to a cell (subgraph) consisting of repeated blocks
+        super().__init__()
+        self.nodes = nodes
+
+    def _retrieve_chain_from_graph(self, graph: Graph) -> List[Node]:
+        u = graph.input_node
+        chain = []
+        while u != graph.output_node:
+            if u != graph.input_node:
+                chain.append(u)
+            assert len(u.successors) == 1, f'This graph is an illegal chain. {u} has outputs {u.successors}.'
+            u = u.successors[0]
+        return chain
+
+    def mutate(self, model):
+        min_depth = self.nodes[0].operation.parameters['min_depth']
+        max_depth = self.nodes[0].operation.parameters['max_depth']
+        if min_depth < max_depth:
+            chosen_depth = self.choice(list(range(min_depth, max_depth + 1)))
+            for node in self.nodes:
+                # the logic here is similar to layer choice: find the cell attached to each node
+                target: Graph = model.graphs[node.operation.cell_name]
+                chain = self._retrieve_chain_from_graph(target)
+                for edge in chain[chosen_depth - 1].outgoing_edges:
+                    edge.remove()
+                target.add_edge((chain[chosen_depth - 1], None), (target.output_node, None))
+                for rm_node in chain[chosen_depth:]:
+                    for edge in rm_node.outgoing_edges:
+                        edge.remove()
+                    rm_node.remove()
+                # update the operation to drop the now-unused depth parameters
+                model.get_node_by_name(node.name).update_operation(Cell(node.operation.cell_name))
+
+
 def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
     applied_mutators = []

@@ -120,6 +157,15 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
         mutator = LayerChoiceMutator(node_list)
         applied_mutators.append(mutator)

+    repeat_nodes = _group_by_label(filter(lambda d: d.operation.parameters.get('mutation') == 'repeat',
+                                          model.get_nodes_by_type('_cell')))
+    for node_list in repeat_nodes:
+        assert _is_all_equal(map(lambda node: node.operation.parameters['max_depth'], node_list)) and \
+            _is_all_equal(map(lambda node: node.operation.parameters['min_depth'], node_list)), \
+            'Repeats with the same label must have the same min_depth and max_depth.'
+        mutator = RepeatMutator(node_list)
+        applied_mutators.append(mutator)
+
     if applied_mutators:
         return applied_mutators
     return None
@@ -190,6 +236,11 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
         if isinstance(module, ValueChoice):
             node = graph.add_node(name, 'ValueChoice', {'candidates': module.candidates})
             node.label = module.label
+        if isinstance(module, Repeat) and module.min_depth <= module.max_depth:
+            node = graph.add_node(name, 'Repeat', {
+                'candidates': list(range(module.min_depth, module.max_depth + 1))
+            })
+            node.label = module.label
         if isinstance(module, Placeholder):
             raise NotImplementedError('Placeholder is not supported in python execution mode.')
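To make the graph surgery in ``RepeatMutator.mutate`` concrete: the converted cell graph is a straight chain of ``max_depth`` blocks, and choosing a depth keeps a prefix of that chain, rewiring the last kept block to the output node. A toy, list-based analogy (not the real ``Graph`` API; node names hypothetical):

    chain = ['blocks.0', 'blocks.1', 'blocks.2', 'blocks.3', 'blocks.4']  # max_depth = 5
    chosen_depth = 3  # sampled from range(min_depth, max_depth + 1)
    kept, removed = chain[:chosen_depth], chain[chosen_depth:]
    assert kept == ['blocks.0', 'blocks.1', 'blocks.2']
    # In the real mutator, kept[-1] is connected to target.output_node, and the
    # nodes in `removed` are deleted along with their edges.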
diff --git a/nni/retiarii/nn/pytorch/utils.py b/nni/retiarii/nn/pytorch/utils.py
new file mode 100644
index 0000000000..352348b997
--- /dev/null
+++ b/nni/retiarii/nn/pytorch/utils.py
@@ -0,0 +1,17 @@
+from typing import Optional
+
+from ...utils import uid, get_current_context
+
+
+def generate_new_label(label: Optional[str]):
+    if label is None:
+        return '_mutation_' + str(uid('mutation'))
+    return label
+
+
+def get_fixed_value(label: Optional[str]):
+    ret = get_current_context('fixed')
+    try:
+        return ret[generate_new_label(label)]
+    except KeyError:
+        raise KeyError(f'Fixed context with {label} not found. Existing values are: {ret}')
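The relocated helpers behave exactly as the removed private versions did; for reference (the auto-generated value depends on how many labels the process has already created):

    from nni.retiarii.nn.pytorch.utils import generate_new_label

    assert generate_new_label('my_cell') == 'my_cell'  # explicit labels pass through unchanged
    print(generate_new_label(None))  # auto-generated, e.g. '_mutation_1'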
diff --git a/test/ut/retiarii/test_highlevel_apis.py b/test/ut/retiarii/test_highlevel_apis.py
index c1b9bb0e3d..a3ff2b2d5d 100644
--- a/test/ut/retiarii/test_highlevel_apis.py
+++ b/test/ut/retiarii/test_highlevel_apis.py
@@ -379,7 +379,7 @@ def test_valuechoice_access_functional(self):
         class Net(nn.Module):
             def __init__(self):
                 super().__init__()
-                self.dropout_rate = nn.ValueChoice([[0.,], [1.,]])
+                self.dropout_rate = nn.ValueChoice([[0., ], [1., ]])

             def forward(self, x):
                 return F.dropout(x, self.dropout_rate()[0])
@@ -398,7 +398,7 @@ def test_valuechoice_access_functional_expression(self):
         class Net(nn.Module):
             def __init__(self):
                 super().__init__()
-                self.dropout_rate = nn.ValueChoice([[1.05,], [1.1,]])
+                self.dropout_rate = nn.ValueChoice([[1.05, ], [1.1, ]])

             def forward(self, x):
                 # if expression failed, the exception would be:
@@ -414,6 +414,67 @@ def forward(self, x):
         self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(), torch.Size([1, 3, 3, 3]))
         self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)

+    def test_repeat(self):
+        class AddOne(nn.Module):
+            def forward(self, x):
+                return x + 1
+
+        @self.get_serializer()
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.block = nn.Repeat(AddOne(), (3, 5))
+
+            def forward(self, x):
+                return self.block(x)
+
+        model, mutators = self._get_model_with_mutators(Net())
+        self.assertEqual(len(mutators), 1)
+        mutator = mutators[0].bind_sampler(EnumerateSampler())
+        model1 = mutator.apply(model)
+        model2 = mutator.apply(model)
+        model3 = mutator.apply(model)
+        self.assertTrue((self._get_converted_pytorch_model(model1)(torch.zeros(1, 16)) == 3).all())
+        self.assertTrue((self._get_converted_pytorch_model(model2)(torch.zeros(1, 16)) == 4).all())
+        self.assertTrue((self._get_converted_pytorch_model(model3)(torch.zeros(1, 16)) == 5).all())
+
+    def test_cell(self):
+        @self.get_serializer()
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)],
+                                    num_nodes=4, num_ops_per_node=2, num_predecessors=2, merge_op='all')
+
+            def forward(self, x, y):
+                return self.cell([x, y])
+
+        raw_model, mutators = self._get_model_with_mutators(Net())
+        for _ in range(10):
+            sampler = EnumerateSampler()
+            model = raw_model
+            for mutator in mutators:
+                model = mutator.bind_sampler(sampler).apply(model)
+            self.assertTrue(self._get_converted_pytorch_model(model)(
+                torch.randn(1, 16), torch.randn(1, 16)).size() == torch.Size([1, 64]))
+
+        @self.get_serializer()
+        class Net2(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)], num_nodes=4)
+
+            def forward(self, x):
+                return self.cell([x])
+
+        raw_model, mutators = self._get_model_with_mutators(Net2())
+        for _ in range(10):
+            sampler = EnumerateSampler()
+            model = raw_model
+            for mutator in mutators:
+                model = mutator.bind_sampler(sampler).apply(model)
+            self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(1, 16)).size() == torch.Size([1, 64]))
+

 class Python(GraphIR):
     def _get_converted_pytorch_model(self, model_ir):
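Besides the mutator-driven paths tested above, the ``__new__`` shortcuts can also be exercised directly through the ``fixed`` context that ``get_fixed_value`` reads. A sketch, under the assumption that ``ContextStack`` from ``nni.retiarii.utils`` is the writer counterpart of the ``get_current_context('fixed')`` call used above:

    import torch
    import nni.retiarii.nn.pytorch as nn
    from nni.retiarii.utils import ContextStack

    class AddOne(nn.Module):
        def forward(self, x):
            return x + 1

    with ContextStack('fixed', {'rep': 4}):
        block = nn.Repeat(AddOne(), (3, 5), label='rep')
    # Under the fixed context, Repeat.__new__ returns a plain nn.Sequential
    # of 4 AddOne blocks instead of a mutable Repeat module.
    assert (block(torch.zeros(1, 16)) == 4).all()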