diff --git a/src/sdk/pynni/nni/nas/pytorch/classic_nas/mutator.py b/src/sdk/pynni/nni/nas/pytorch/classic_nas/mutator.py
index b6d0d9b7ba..ee110dd5d1 100644
--- a/src/sdk/pynni/nni/nas/pytorch/classic_nas/mutator.py
+++ b/src/sdk/pynni/nni/nas/pytorch/classic_nas/mutator.py
@@ -92,8 +92,8 @@ def _sample_layer_choice(self, mutable, idx, value, search_space_item):
             The list for corresponding search space.
         """
         # doesn't support multihot for layer choice yet
-        onehot_list = [False] * mutable.length
-        assert 0 <= idx < mutable.length and search_space_item[idx] == value, \
+        onehot_list = [False] * len(mutable)
+        assert 0 <= idx < len(mutable) and search_space_item[idx] == value, \
             "Index '{}' in search space '{}' is not '{}'".format(idx, search_space_item, value)
         onehot_list[idx] = True
         return torch.tensor(onehot_list, dtype=torch.bool)  # pylint: disable=not-callable
diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py b/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py
index 2aba20dd45..a4c3898a9b 100644
--- a/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py
+++ b/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py
@@ -61,7 +61,7 @@ def sample_final(self):
             if isinstance(mutable, LayerChoice):
                 max_val, index = torch.max(F.softmax(self.choices[mutable.key], dim=-1)[:-1], 0)
                 edges_max[mutable.key] = max_val
-                result[mutable.key] = F.one_hot(index, num_classes=mutable.length).view(-1).bool()
+                result[mutable.key] = F.one_hot(index, num_classes=len(mutable)).view(-1).bool()
         for mutable in self.mutables:
             if isinstance(mutable, InputChoice):
                 if mutable.n_chosen is not None:
diff --git a/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py b/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py
index 8cd107ec9d..7763622a58 100644
--- a/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py
+++ b/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py
@@ -86,15 +86,15 @@ def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, ce
         for mutable in self.mutables:
             if isinstance(mutable, LayerChoice):
                 if self.max_layer_choice == 0:
-                    self.max_layer_choice = mutable.length
-                assert self.max_layer_choice == mutable.length, \
+                    self.max_layer_choice = len(mutable)
+                assert self.max_layer_choice == len(mutable), \
                     "ENAS mutator requires all layer choices to have the same number of candidates."
                 # We are judging by keys and module types to add biases to layer choices. Needs refactor.
                 if "reduce" in mutable.key:
                     def is_conv(choice):
                         return "conv" in str(type(choice)).lower()
                     bias = torch.tensor([self.branch_bias if is_conv(choice) else -self.branch_bias  # pylint: disable=not-callable
-                                         for choice in mutable.choices])
+                                         for choice in mutable])
                     self.bias_dict[mutable.key] = nn.Parameter(bias, requires_grad=False)
 
         self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size)
diff --git a/src/sdk/pynni/nni/nas/pytorch/mutables.py b/src/sdk/pynni/nni/nas/pytorch/mutables.py
index 46d08fd756..5dbed524e0 100644
--- a/src/sdk/pynni/nni/nas/pytorch/mutables.py
+++ b/src/sdk/pynni/nni/nas/pytorch/mutables.py
@@ -2,6 +2,7 @@
 # Licensed under the MIT license.
 
 import logging
+import warnings
 from collections import OrderedDict
 
 import torch.nn as nn
@@ -140,9 +141,12 @@ class LayerChoice(Mutable):
     Attributes
     ----------
     length : int
-        Number of ops to choose from.
-    names: list of str
+        Deprecated. Number of ops to choose from. ``len(layer_choice)`` is recommended.
+    names : list of str
         Names of candidates.
+    choices : list of Module
+        Deprecated. A list of all candidate modules in the layer choice module.
+        ``list(layer_choice)`` is recommended, which will serve the same purpose.
 
     Notes
     -----
@@ -156,30 +160,65 @@ class LayerChoice(Mutable):
     ``op_candidates`` can be a list of modules or an ordered dict of named modules, for example,
 
     .. code-block:: python
 
         self.op_choice = LayerChoice(OrderedDict([
             ("conv3x3", nn.Conv2d(3, 16, 128)),
             ("conv5x5", nn.Conv2d(5, 16, 128)),
             ("conv7x7", nn.Conv2d(7, 16, 128))
         ]))
 
+    Elements in a layer choice can be modified or deleted, e.g., ``del self.op_choice["conv5x5"]`` or
+    ``self.op_choice[1] = nn.Conv3d(...)``. Adding new candidates is not supported yet.
     """
 
     def __init__(self, op_candidates, reduction="sum", return_mask=False, key=None):
         super().__init__(key=key)
-        self.length = len(op_candidates)
-        self.choices = []
         self.names = []
         if isinstance(op_candidates, OrderedDict):
             for name, module in op_candidates.items():
-                assert name not in ["length", "reduction", "return_mask", "_key", "key", "names"], \
+                assert name not in ["length", "reduction", "return_mask", "_key", "key", "names", "choices"], \
                     "Please don't use a reserved name '{}' for your module.".format(name)
                 self.add_module(name, module)
-                self.choices.append(module)
                 self.names.append(name)
         elif isinstance(op_candidates, list):
             for i, module in enumerate(op_candidates):
                 self.add_module(str(i), module)
-                self.choices.append(module)
                 self.names.append(str(i))
         else:
             raise TypeError("Unsupported op_candidates type: {}".format(type(op_candidates)))
         self.reduction = reduction
         self.return_mask = return_mask
 
+    def __getitem__(self, idx):
+        if isinstance(idx, str):
+            return self._modules[idx]
+        return list(self)[idx]
+
+    def __setitem__(self, idx, module):
+        key = idx if isinstance(idx, str) else self.names[idx]
+        setattr(self, key, module)
+
+    def __delitem__(self, idx):
+        if isinstance(idx, slice):
+            for key in self.names[idx]:
+                delattr(self, key)
+        else:
+            if isinstance(idx, str):
+                key, idx = idx, self.names.index(idx)
+            else:
+                key = self.names[idx]
+            delattr(self, key)
+        del self.names[idx]
+
+    @property
+    def length(self):
+        warnings.warn("layer_choice.length is deprecated. Use `len(layer_choice)` instead.", DeprecationWarning)
+        return len(self)
+
+    def __len__(self):
+        return len(self.names)
+
+    def __iter__(self):
+        return map(lambda name: self._modules[name], self.names)
+
+    @property
+    def choices(self):
+        warnings.warn("layer_choice.choices is deprecated. Use `list(layer_choice)` instead.", DeprecationWarning)
+        return list(self)
+
     def forward(self, *args, **kwargs):
         """
         Returns
diff --git a/src/sdk/pynni/nni/nas/pytorch/mutator.py b/src/sdk/pynni/nni/nas/pytorch/mutator.py
index e461d50206..160a20de84 100644
--- a/src/sdk/pynni/nni/nas/pytorch/mutator.py
+++ b/src/sdk/pynni/nni/nas/pytorch/mutator.py
@@ -150,16 +150,16 @@ def on_forward_layer_choice(self, mutable, *args, **kwargs):
         """
         if self._connect_all:
             return self._all_connect_tensor_reduction(mutable.reduction,
-                                                      [op(*args, **kwargs) for op in mutable.choices]), \
-                   torch.ones(mutable.length)
+                                                      [op(*args, **kwargs) for op in mutable]), \
+                   torch.ones(len(mutable))
 
         def _map_fn(op, args, kwargs):
             return op(*args, **kwargs)
 
         mask = self._get_decision(mutable)
-        assert len(mask) == len(mutable.choices), \
-            "Invalid mask, expected {} to be of length {}.".format(mask, len(mutable.choices))
-        out = self._select_with_mask(_map_fn, [(choice, args, kwargs) for choice in mutable.choices], mask)
+        assert len(mask) == len(mutable), \
+            "Invalid mask, expected {} to be of length {}.".format(mask, len(mutable))
+        out = self._select_with_mask(_map_fn, [(choice, args, kwargs) for choice in mutable], mask)
         return self._tensor_reduction(mutable.reduction, out), mask
 
     def on_forward_input_choice(self, mutable, tensor_list):
diff --git a/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py b/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py
index 47aedfa1b2..108557f30e 100644
--- a/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py
+++ b/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py
@@ -32,7 +32,7 @@ def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches={}):
 
         for mutable in self.mutables:
             if isinstance(mutable, LayerChoice):
-                switches = self.switches.get(mutable.key, [True for j in range(mutable.length)])
+                switches = self.switches.get(mutable.key, [True for j in range(len(mutable))])
                 choices = self.choices[mutable.key]
 
                 operations_count = np.sum(switches)
@@ -48,12 +48,12 @@ def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches={}):
             if isinstance(module, LayerChoice):
                 switches = self.switches.get(module.key)
                 choices = self.choices[module.key]
-                if len(module.choices) > len(choices):
+                if len(module) > len(choices):
                     # iterate from last to first, so that removing one doesn't affect the indices of earlier candidates
                     for index in range(len(switches)-1, -1, -1):
                         if switches[index] == False:
-                            del(module.choices[index])
-                            module.length -= 1
+                            del module[index]
+                assert len(module) <= len(choices), "Failed to remove dropped choices."
 
     def sample_final(self):
         results = super().sample_final()
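A note on the deletion semantics that PDARTS now relies on: ``del module[index]`` removes both the submodule and its entry in ``names``, and the surviving candidates keep their original auto-generated labels. A minimal sketch (the conv shapes are illustrative only):

    import torch.nn as nn
    from nni.nas.pytorch.mutables import LayerChoice

    layer = LayerChoice([nn.Conv2d(3, 3, 3), nn.Conv2d(3, 5, 3), nn.Conv2d(3, 6, 3)])
    del layer[1]                      # removes the submodule and its name entry
    assert len(layer) == 2
    assert layer.names == ["0", "2"]  # names are not renumbered after deletion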
diff --git a/src/sdk/pynni/nni/nas/pytorch/proxylessnas/mutator.py b/src/sdk/pynni/nni/nas/pytorch/proxylessnas/mutator.py
index eb768e6fff..881a6b4403 100644
--- a/src/sdk/pynni/nni/nas/pytorch/proxylessnas/mutator.py
+++ b/src/sdk/pynni/nni/nas/pytorch/proxylessnas/mutator.py
@@ -53,15 +53,15 @@ def __init__(self, mutable):
             A LayerChoice in user model
         """
         super(MixedOp, self).__init__()
-        self.ap_path_alpha = nn.Parameter(torch.Tensor(mutable.length))
-        self.ap_path_wb = nn.Parameter(torch.Tensor(mutable.length))
+        self.ap_path_alpha = nn.Parameter(torch.Tensor(len(mutable)))
+        self.ap_path_wb = nn.Parameter(torch.Tensor(len(mutable)))
         self.ap_path_alpha.requires_grad = False
         self.ap_path_wb.requires_grad = False
         self.active_index = [0]
         self.inactive_index = None
         self.log_prob = None
         self.current_prob_over_ops = None
-        self.n_choices = mutable.length
+        self.n_choices = len(mutable)
 
     def get_ap_path_alpha(self):
         return self.ap_path_alpha
@@ -120,8 +120,8 @@ def backward(_x, _output, grad_output):
                     return binary_grads
                 return backward
             output = ArchGradientFunction.apply(
-                x, self.ap_path_wb, run_function(mutable.key, mutable.choices, self.active_index[0]),
-                backward_function(mutable.key, mutable.choices, self.active_index[0], self.ap_path_wb))
+                x, self.ap_path_wb, run_function(mutable.key, list(mutable), self.active_index[0]),
+                backward_function(mutable.key, list(mutable), self.active_index[0], self.ap_path_wb))
         else:
             output = self.active_op(mutable)(x)
         return output
@@ -164,7 +164,7 @@ def active_op(self, mutable):
             PyTorch module
             the chosen operation
         """
-        return mutable.choices[self.active_index[0]]
+        return mutable[self.active_index[0]]
 
     @property
     def active_op_index(self):
@@ -222,12 +222,12 @@ def binarize(self, mutable):
             sample = torch.multinomial(probs, 1)[0].item()
             self.active_index = [sample]
             self.inactive_index = [_i for _i in range(0, sample)] + \
-                                  [_i for _i in range(sample + 1, len(mutable.choices))]
+                                  [_i for _i in range(sample + 1, len(mutable))]
             self.log_prob = torch.log(probs[sample])
             self.current_prob_over_ops = probs
             self.ap_path_wb.data[sample] = 1.0
         # avoid over-regularization
-        for choice in mutable.choices:
+        for choice in mutable:
             for _, param in choice.named_parameters():
                 param.grad = None
 
@@ -430,8 +430,8 @@ def unused_modules_off(self):
                 involved_index = mixed_op.active_index
                 for i in range(mixed_op.n_choices):
                     if i not in involved_index:
-                        unused[i] = mutable.choices[i]
-                        mutable.choices[i] = None
+                        unused[i] = mutable[i]
+                        mutable[i] = None
                 self._unused_modules.append(unused)
 
     def unused_modules_back(self):
@@ -442,7 +442,7 @@ def unused_modules_back(self):
             return
         for m, unused in zip(self.mutable_list, self._unused_modules):
             for i in unused:
-                m.choices[i] = unused[i]
+                m[i] = unused[i]
         self._unused_modules = None
 
     def arch_requires_grad(self):
@@ -474,5 +474,5 @@ def sample_final(self):
             assert isinstance(mutable, LayerChoice)
             index, _ = mutable.registered_module.chosen_index
             # pylint: disable=not-callable
-            result[mutable.key] = F.one_hot(torch.tensor(index), num_classes=mutable.length).view(-1).bool()
+            result[mutable.key] = F.one_hot(torch.tensor(index), num_classes=len(mutable)).view(-1).bool()
         return result
diff --git a/src/sdk/pynni/nni/nas/pytorch/random/mutator.py b/src/sdk/pynni/nni/nas/pytorch/random/mutator.py
index 2a8cb25ef2..f302db56c0 100644
--- a/src/sdk/pynni/nni/nas/pytorch/random/mutator.py
+++ b/src/sdk/pynni/nni/nas/pytorch/random/mutator.py
@@ -18,8 +18,8 @@ def sample_search(self):
         result = dict()
         for mutable in self.mutables:
             if isinstance(mutable, LayerChoice):
-                gen_index = torch.randint(high=mutable.length, size=(1, ))
-                result[mutable.key] = F.one_hot(gen_index, num_classes=mutable.length).view(-1).bool()
+                gen_index = torch.randint(high=len(mutable), size=(1, ))
+                result[mutable.key] = F.one_hot(gen_index, num_classes=len(mutable)).view(-1).bool()
             elif isinstance(mutable, InputChoice):
                 if mutable.n_chosen is None:
                     result[mutable.key] = torch.randint(high=2, size=(mutable.n_candidates,)).view(-1).bool()
diff --git a/src/sdk/pynni/tests/models/pytorch_models/__init__.py b/src/sdk/pynni/tests/models/pytorch_models/__init__.py
index 46d4482c86..363c7d3c9c 100644
--- a/src/sdk/pynni/tests/models/pytorch_models/__init__.py
+++ b/src/sdk/pynni/tests/models/pytorch_models/__init__.py
@@ -1,6 +1,7 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
+from .layer_choice_only import LayerChoiceOnlySearchSpace
 from .mutable_scope import SpaceWithMutableScope
 from .naive import NaiveSearchSpace
 from .nested import NestedSpace
diff --git a/src/sdk/pynni/tests/models/pytorch_models/layer_choice_only.py b/src/sdk/pynni/tests/models/pytorch_models/layer_choice_only.py
new file mode 100644
index 0000000000..c500bc9cdc
--- /dev/null
+++ b/src/sdk/pynni/tests/models/pytorch_models/layer_choice_only.py
@@ -0,0 +1,38 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from nni.nas.pytorch.mutables import LayerChoice
+
+
+class LayerChoiceOnlySearchSpace(nn.Module):
+    def __init__(self, test_case):
+        super().__init__()
+        self.test_case = test_case
+        self.conv1 = LayerChoice([nn.Conv2d(3, 6, 3, padding=1), nn.Conv2d(3, 6, 5, padding=2)])
+        self.pool = nn.MaxPool2d(2, 2)
+        self.conv2 = LayerChoice([nn.Conv2d(6, 16, 3, padding=1), nn.Conv2d(6, 16, 5, padding=2)],
+                                 return_mask=True)
+        self.conv3 = nn.Conv2d(16, 16, 1)
+        self.bn = nn.BatchNorm2d(16)
+
+        self.gap = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Linear(16, 10)
+
+    def forward(self, x):
+        bs = x.size(0)
+
+        x = self.pool(F.relu(self.conv1(x)))
+        x0, mask = self.conv2(x)
+        self.test_case.assertEqual(mask.size(), torch.Size([2]))
+        x1 = F.relu(self.conv3(x0))
+
+        x = self.pool(self.bn(x1))
+        self.test_case.assertEqual(mask.size(), torch.Size([2]))
+
+        x = self.gap(x).view(bs, -1)
+        x = self.fc(x)
+        return x
diff --git a/src/sdk/pynni/tests/test_nas.py b/src/sdk/pynni/tests/test_nas.py
index 53b52541ad..5c1799a4a8 100644
--- a/src/sdk/pynni/tests/test_nas.py
+++ b/src/sdk/pynni/tests/test_nas.py
@@ -3,6 +3,7 @@
 import importlib
 import os
 import sys
+from collections import OrderedDict
 from unittest import TestCase, main
 
 import torch
@@ -11,6 +12,7 @@
 from nni.nas.pytorch.darts import DartsMutator
 from nni.nas.pytorch.enas import EnasMutator
 from nni.nas.pytorch.fixed import apply_fixed_architecture
+from nni.nas.pytorch.mutables import LayerChoice
 from nni.nas.pytorch.random import RandomMutator
 from nni.nas.pytorch.utils import _reset_global_mutable_counting
 
@@ -101,6 +103,43 @@ def test_classic_nas(self):
         get_and_apply_next_architecture(model)
         self.iterative_sample_and_forward(model)
 
+    def test_proxylessnas(self):
+        model = self.model_module.LayerChoiceOnlySearchSpace(self)
+        get_and_apply_next_architecture(model)
+        self.iterative_sample_and_forward(model)
+
+    def test_layer_choice(self):
+        for i in range(2):
+            for j in range(2):
+                if j == 0:
+                    # test construction from a list (candidates auto-named by index)
+                    layer_choice = LayerChoice([nn.Conv2d(3, 3, 3), nn.Conv2d(3, 5, 3), nn.Conv2d(3, 6, 3)])
+                else:
+                    # test ordered dict
+                    layer_choice = LayerChoice(OrderedDict([
+                        ("conv1", nn.Conv2d(3, 3, 3)),
+                        ("conv2", nn.Conv2d(3, 5, 3)),
+                        ("conv3", nn.Conv2d(3, 6, 3))
+                    ]))
+                if i == 0:
+                    # test modify
+                    self.assertEqual(len(layer_choice.choices), 3)
+                    layer_choice[1] = nn.Conv2d(3, 4, 3)
+                    self.assertEqual(layer_choice[1].out_channels, 4)
+                    self.assertEqual(len(layer_choice[0:2]), 2)
+                    if j > 0:
+                        layer_choice["conv3"] = nn.Conv2d(3, 7, 3)
+                        self.assertEqual(layer_choice[-1].out_channels, 7)
+                if i == 1:
+                    # test delete
+                    del layer_choice[1]
+                    self.assertEqual(len(layer_choice), 2)
+                    self.assertEqual(len(list(layer_choice)), 2)
+                    self.assertEqual(layer_choice.names, ["conv1", "conv3"] if j > 0 else ["0", "2"])
+                    if j > 0:
+                        del layer_choice["conv1"]
+                        self.assertEqual(len(layer_choice), 1)
+
 
 if __name__ == '__main__':
     main()
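Taken together, a minimal sketch of the list-style ``LayerChoice`` API that this patch introduces (the channel counts and kernel sizes are illustrative only, not taken from the diff):

    from collections import OrderedDict

    import torch.nn as nn

    from nni.nas.pytorch.mutables import LayerChoice

    # Candidates get readable names when constructed from an OrderedDict.
    op_choice = LayerChoice(OrderedDict([
        ("conv3x3", nn.Conv2d(16, 16, 3, padding=1)),
        ("conv5x5", nn.Conv2d(16, 16, 5, padding=2)),
    ]))

    assert len(op_choice) == 2      # replaces the deprecated `op_choice.length`
    candidates = list(op_choice)    # replaces the deprecated `op_choice.choices`
    op_choice["conv5x5"] = nn.Conv2d(16, 16, 5, padding=2)  # replace a candidate by name
    op_choice[0] = nn.Conv2d(16, 16, 3, padding=1)          # ...or by positional index
    del op_choice["conv3x3"]        # `op_choice.names` shrinks to ["conv5x5"]

Both ``length`` and ``choices`` keep working but now emit a ``DeprecationWarning``, which is why the mutators touched above (classic NAS, DARTS, ENAS, PDARTS, ProxylessNAS, random) were migrated to ``len(mutable)`` and plain iteration.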