diff --git a/docs/en_US/NAS/Overview.md b/docs/en_US/NAS/Overview.md
index adf1491928..9a6610bed2 100644
--- a/docs/en_US/NAS/Overview.md
+++ b/docs/en_US/NAS/Overview.md
@@ -54,6 +54,17 @@ Please refer to [here](NasGuide.md) for the usage of one-shot NAS algorithms.
One-shot NAS can be visualized with our visualization tool. Learn more details [here](./Visualization.md).
+
+## Search Space Zoo
+
+NNI provides some predefined search space which can be easily reused. By stacking the extracted cells, user can quickly reproduce those NAS models.
+
+Search Space Zoo contains the following NAS cells:
+
+* [DartsCell](./SearchSpaceZoo.md#DartsCell)
+* [ENAS micro](./SearchSpaceZoo.md#ENASMicroLayer)
+* [ENAS macro](./SearchSpaceZoo.md#ENASMacroLayer)
+
## Using NNI API to Write Your Search Space
The programming interface of designing and searching a model is often demanded in two scenarios.
diff --git a/docs/en_US/NAS/SearchSpaceZoo.md b/docs/en_US/NAS/SearchSpaceZoo.md
new file mode 100644
index 0000000000..5bdb599f72
--- /dev/null
+++ b/docs/en_US/NAS/SearchSpaceZoo.md
@@ -0,0 +1,175 @@
+# Search Space Zoo
+
+## DartsCell
+
+DartsCell is extracted from [CNN model](./DARTS.md) designed [here](https://github.com/microsoft/nni/tree/master/examples/nas/darts). A DartsCell is a directed acyclic graph containing an ordered sequence of N nodes and each node stands for a latent representation (e.g. feature map in a convolutional network). Directed edges from Node 1 to Node 2 are associated with some operations that transform Node 1 and the result is stored on Node 2. The [operations](#darts-predefined-operations) between nodes is predefined and unchangeable. One edge represents an operation that chosen from the predefined ones to be applied to the starting node of the edge. One cell contains two input nodes, a single output node, and other `n_node` nodes. The input nodes are defined as the cell outputs in the previous two layers. The output of the cell is obtained by applying a reduction operation (e.g. concatenation) to all the intermediate nodes. To make the search space continuous, the categorical choice of a particular operation is relaxed to a softmax over all possible operations. By adjusting the weight of softmax on every node, the operation with the highest probability is chosen to be part of the final structure. A CNN model can be formed by stacking several cells together, which builds a search space. Note that, in DARTS paper all cells in the model share the same structure.
+
+One structure in the Darts search space is shown below. Note that, NNI merges the last one of the four intermediate nodes and the output node.
+
+
+
+The predefined operations are shown in [references](#predefined-operations-darts).
+
+```eval_rst
+.. autoclass:: nni.nas.pytorch.search_space_zoo.DartsCell
+ :members:
+```
+
+### Example code
+
+[example code](https://github.com/microsoft/nni/tree/master/examples/nas/search_space_zoo/darts_example.py)
+
+```bash
+git clone https://github.com/Microsoft/nni.git
+cd nni/examples/nas/search_space_zoo
+# search the best structure
+python3 darts_example.py
+```
+
+
+
+### References
+
+All supported operations for Darts are listed below.
+
+* MaxPool / AvgPool
+ * MaxPool: Call `torch.nn.MaxPool2d`. This operation applies a 2D max pooling over all input channels. Its parameters `kernel_size=3` and `padding=1` are fixed. The pooling result will pass through a BatchNorm2d then return as the result.
+ * AvgPool: Call `torch.nn.AvgPool2d`. This operation applies a 2D average pooling over all input channels. Its parameters `kernel_size=3` and `padding=1` are fixed. The pooling result will pass through a BatchNorm2d then return as the result.
+
+ MaxPool / AvgPool with `kernel_size=3` and `padding=1` followed by BatchNorm2d
+ ```eval_rst
+ .. autoclass:: nni.nas.pytorch.search_space_zoo.darts_ops.PoolBN
+ ```
+* SkipConnect
+
+ There is no operation between two nodes. Call `torch.nn.Identity` to forward what it gets to the output.
+* Zero operation
+
+ There is no connection between two nodes.
+* DilConv3x3 / DilConv5x5
+
+ DilConv3x3: (Dilated) depthwise separable Conv. It's a 3x3 depthwise convolution with `C_in` groups, followed by a 1x1 pointwise convolution. It reduces the amount of parameters. Input is first passed through relu, then DilConv and finally batchNorm2d. **Note that the operation is not Dilated Convolution, but we follow the convention in NAS papers to name it DilConv.** 3x3 DilConv has parameters `kernel_size=3`, `padding=1` and 5x5 DilConv has parameters `kernel_size=5`, `padding=4`.
+ ```eval_rst
+ .. autoclass:: nni.nas.pytorch.search_space_zoo.darts_ops.DilConv
+ ```
+* SepConv3x3 / SepConv5x5
+
+ Composed of two DilConvs with fixed `kernel_size=3`, `padding=1` or `kernel_size=5`, `padding=2` sequentially.
+ ```eval_rst
+ .. autoclass:: nni.nas.pytorch.search_space_zoo.darts_ops.SepConv
+ ```
+
+## ENASMicroLayer
+
+This layer is extracted from the model designed [here](https://github.com/microsoft/nni/tree/master/examples/nas/enas). A model contains several blocks that share the same architecture. A block is made up of some normal layers and reduction layers, `ENASMicroLayer` is a unified implementation of the two types of layers. The only difference between the two layers is that reduction layers apply all operations with `stride=2`.
+
+ENAS Micro employs a DAG with N nodes in one cell, where the nodes represent local computations, and the edges represent the flow of information between the N nodes. One cell contains two input nodes and a single output node. The following nodes choose two previous nodes as input and apply two operations from [predefined ones](#predefined-operations-enas) then add them as the output of this node. For example, Node 4 chooses Node 1 and Node 3 as inputs then applies `MaxPool` and `AvgPool` on the inputs respectively, then adds and sums them as the output of Node 4. Nodes that are not served as input for any other node are viewed as the output of the layer. If there are multiple output nodes, the model will calculate the average of these nodes as the layer output.
+
+One structure in the ENAS micro search space is shown below.
+
+
+
+The predefined operations can be seen [here](#predefined-operations-enas).
+
+```eval_rst
+.. autoclass:: nni.nas.pytorch.search_space_zoo.ENASMicroLayer
+ :members:
+```
+
+The Reduction Layer is made up of two Conv operations followed by BatchNorm, each of them will output `C_out//2` channels and concat them in channels as the output. The Convolution has `kernel_size=1` and `stride=2`, and they perform alternate sampling on the input to reduce the resolution without loss of information. This layer is wrapped in `ENASMicroLayer`.
+
+### Example code
+
+[example code](https://github.com/microsoft/nni/tree/master/examples/nas/search_space_zoo/enas_micro_example.py)
+
+```bash
+git clone https://github.com/Microsoft/nni.git
+cd nni/examples/nas/search_space_zoo
+# search the best cell structure
+python3 enas_micro_example.py
+```
+
+
+
+### References
+
+All supported operations for ENAS micro search are listed below.
+
+* MaxPool / AvgPool
+ * MaxPool: Call `torch.nn.MaxPool2d`. This operation applies a 2D max pooling over all input channels followed by BatchNorm2d. Its parameters are fixed to `kernel_size=3`, `stride=1` and `padding=1`.
+ * AvgPool: Call `torch.nn.AvgPool2d`. This operation applies a 2D average pooling over all input channels followed by BatchNorm2d. Its parameters are fixed to `kernel_size=3`, `stride=1` and `padding=1`.
+ ```eval_rst
+ .. autoclass:: nni.nas.pytorch.search_space_zoo.enas_ops.Pool
+ ```
+
+* SepConv
+ * SepConvBN3x3: ReLU followed by a [DilConv](#DilConv) and BatchNorm. Convolution parameters are `kernel_size=3`, `stride=1` and `padding=1`.
+ * SepConvBN5x5: Do the same operation as the previous one but it has different kernel sizes and paddings, which is set to 5 and 2 respectively.
+
+ ```eval_rst
+ .. autoclass:: nni.nas.pytorch.search_space_zoo.enas_ops.SepConvBN
+ ```
+
+* SkipConnect
+
+ Call `torch.nn.Identity` to connect directly to the next cell.
+
+## ENASMacroLayer
+
+In Macro search, the controller makes two decisions for each layer: i) the [operation](#macro-operations) to perform on the result of the previous layer, ii) which the previous layer to connect to for SkipConnects. ENAS uses a controller to design the whole model architecture instead of one of its components. The output of operations is going to concat with the tensor of the chosen layer for SkipConnect. NNI provides [predefined operations](#macro-operations) for macro search, which are listed in [references](#macro-operations).
+
+Part of one structure in the ENAS macro search space is shown below.
+
+
+
+```eval_rst
+.. autoclass:: nni.nas.pytorch.search_space_zoo.ENASMacroLayer
+ :members:
+```
+
+To describe the whole search space, NNI provides a model, which is built by stacking the layers.
+
+```eval_rst
+.. autoclass:: nni.nas.pytorch.search_space_zoo.ENASMacroGeneralModel
+ :members:
+```
+
+### Example code
+
+[example code](https://github.com/microsoft/nni/tree/master/examples/nas/search_space_zoo/enas_macro_example.py)
+
+```bash
+git clone https://github.com/Microsoft/nni.git
+cd nni/examples/nas/search_space_zoo
+# search the best cell structure
+python3 enas_macro_example.py
+```
+
+
+
+### References
+
+All supported operations for ENAS macro search are listed below.
+
+* ConvBranch
+
+ All input first passes into a StdConv, which is made up of a 1x1Conv followed by BatchNorm2d and ReLU. Then the intermediate result goes through one of the operations listed below. The final result is calculated through a BatchNorm2d and ReLU as post-procedure.
+ * Separable Conv3x3: If `separable=True`, the cell will use [SepConv](#DilConv) instead of normal Conv operation. SepConv's `kernel_size=3`, `stride=1` and `padding=1`.
+ * Separable Conv5x5: SepConv's `kernel_size=5`, `stride=1` and `padding=2`.
+ * Normal Conv3x3: If `separable=False`, the cell will use a normal Conv operations with `kernel_size=3`, `stride=1` and `padding=1`.
+ * Normal Conv5x5: Conv's `kernel_size=5`, `stride=1` and `padding=2`.
+
+ ```eval_rst
+ .. autoclass:: nni.nas.pytorch.search_space_zoo.enas_ops.ConvBranch
+ ```
+* PoolBranch
+
+ All input first passes into a StdConv, which is made up of a 1x1Conv followed by BatchNorm2d and ReLU. Then the intermediate goes through pooling operation followed by BatchNorm.
+ * AvgPool: Call `torch.nn.AvgPool2d`. This operation applies a 2D average pooling over all input channels. Its parameters are fixed to `kernel_size=3`, `stride=1` and `padding=1`.
+ * MaxPool: Call `torch.nn.MaxPool2d`. This operation applies a 2D max pooling over all input channels. Its parameters are fixed to `kernel_size=3`, `stride=1` and `padding=1`.
+
+ ```eval_rst
+ .. autoclass:: nni.nas.pytorch.search_space_zoo.enas_ops.PoolBranch
+ ```
+
+
diff --git a/docs/en_US/nas.rst b/docs/en_US/nas.rst
index 280c4cad2c..27a60ff0f3 100644
--- a/docs/en_US/nas.rst
+++ b/docs/en_US/nas.rst
@@ -23,5 +23,6 @@ For details, please refer to the following tutorials:
One-shot NAS
Customize a NAS Algorithm
NAS Visualization
+ Search Space Zoo
NAS Benchmarks
API Reference
diff --git a/docs/img/NAS_Darts_cell.svg b/docs/img/NAS_Darts_cell.svg
new file mode 100644
index 0000000000..9dd61253cd
--- /dev/null
+++ b/docs/img/NAS_Darts_cell.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/img/NAS_ENAS_macro.svg b/docs/img/NAS_ENAS_macro.svg
new file mode 100644
index 0000000000..9897020e98
--- /dev/null
+++ b/docs/img/NAS_ENAS_macro.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/img/NAS_ENAS_micro.svg b/docs/img/NAS_ENAS_micro.svg
new file mode 100644
index 0000000000..afe9c79c43
--- /dev/null
+++ b/docs/img/NAS_ENAS_micro.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/examples/nas/search_space_zoo/darts_example.py b/examples/nas/search_space_zoo/darts_example.py
new file mode 100644
index 0000000000..3106d7038a
--- /dev/null
+++ b/examples/nas/search_space_zoo/darts_example.py
@@ -0,0 +1,53 @@
+# copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import logging
+import time
+from argparse import ArgumentParser
+
+import torch
+import torch.nn as nn
+
+import datasets
+from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback
+from nni.nas.pytorch.darts import DartsTrainer
+from utils import accuracy
+
+from nni.nas.pytorch.search_space_zoo import DartsCell
+from darts_search_space import DartsStackedCells
+
+logger = logging.getLogger('nni')
+
+if __name__ == "__main__":
+ parser = ArgumentParser("darts")
+ parser.add_argument("--layers", default=8, type=int)
+ parser.add_argument("--batch-size", default=64, type=int)
+ parser.add_argument("--log-frequency", default=10, type=int)
+ parser.add_argument("--epochs", default=50, type=int)
+ parser.add_argument("--channels", default=16, type=int)
+ parser.add_argument("--unrolled", default=False, action="store_true")
+ parser.add_argument("--visualization", default=False, action="store_true")
+ args = parser.parse_args()
+
+ dataset_train, dataset_valid = datasets.get_dataset("cifar10")
+
+ model = DartsStackedCells(3, args.channels, 10, args.layers, DartsCell)
+ criterion = nn.CrossEntropyLoss()
+
+ optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
+ lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001)
+
+ trainer = DartsTrainer(model,
+ loss=criterion,
+ metrics=lambda output, target: accuracy(output, target, topk=(1,)),
+ optimizer=optim,
+ num_epochs=args.epochs,
+ dataset_train=dataset_train,
+ dataset_valid=dataset_valid,
+ batch_size=args.batch_size,
+ log_frequency=args.log_frequency,
+ unrolled=args.unrolled,
+ callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")])
+ if args.visualization:
+ trainer.enable_visualization()
+ trainer.train()
diff --git a/examples/nas/search_space_zoo/darts_stack_cells.py b/examples/nas/search_space_zoo/darts_stack_cells.py
new file mode 100644
index 0000000000..7366b5000a
--- /dev/null
+++ b/examples/nas/search_space_zoo/darts_stack_cells.py
@@ -0,0 +1,83 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import torch.nn as nn
+import ops
+
+
+class DartsStackedCells(nn.Module):
+ """
+ builtin Darts Search Space
+ Compared to Darts example, DartsSearchSpace removes Auxiliary Head, which
+ is considered as a trick rather than part of model.
+
+ Attributes
+ ---
+ in_channels: int
+ the number of input channels
+ channels: int
+ the number of initial channels expected
+ n_classes: int
+ classes for final classification
+ n_layers: int
+ the number of cells contained in this network
+ factory_func: function
+ return a callable instance for demand cell structure.
+ user should pass in ``__init__`` of the cell class with required parameters (see nni.nas.DartsCell for detail)
+ n_nodes: int
+ the number of nodes contained in each cell
+ stem_multiplier: int
+ channels multiply coefficient when passing a cell
+ """
+
+ def __init__(self, in_channels, channels, n_classes, n_layers, factory_func, n_nodes=4,
+ stem_multiplier=3):
+ super().__init__()
+ self.in_channels = in_channels
+ self.channels = channels
+ self.n_classes = n_classes
+ self.n_layers = n_layers
+
+ c_cur = stem_multiplier * self.channels
+ self.stem = nn.Sequential(
+ nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False),
+ nn.BatchNorm2d(c_cur)
+ )
+
+ # for the first cell, stem is used for both s0 and s1
+ # [!] channels_pp and channels_p is output channel size, but c_cur is input channel size.
+ channels_pp, channels_p, c_cur = c_cur, c_cur, channels
+
+ self.cells = nn.ModuleList()
+ reduction_p, reduction = False, False
+ for i in range(n_layers):
+ reduction_p, reduction = reduction, False
+ # Reduce featuremap size and double channels in 1/3 and 2/3 layer.
+ if i in [n_layers // 3, 2 * n_layers // 3]:
+ c_cur *= 2
+ reduction = True
+
+ cell = factory_func(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
+ self.cells.append(cell)
+ c_cur_out = c_cur * n_nodes
+ channels_pp, channels_p = channels_p, c_cur_out
+
+ self.gap = nn.AdaptiveAvgPool2d(1)
+ self.linear = nn.Linear(channels_p, n_classes)
+
+ def forward(self, x):
+ s0 = s1 = self.stem(x)
+
+ for cell in self.cells:
+ s0, s1 = s1, cell(s0, s1)
+
+ out = self.gap(s1)
+ out = out.view(out.size(0), -1) # flatten
+ logits = self.linear(out)
+
+ return logits
+
+ def drop_path_prob(self, p):
+ for module in self.modules():
+ if isinstance(module, ops.DropPath):
+ module.p = p
diff --git a/examples/nas/search_space_zoo/datasets.py b/examples/nas/search_space_zoo/datasets.py
new file mode 100644
index 0000000000..f19f5691a1
--- /dev/null
+++ b/examples/nas/search_space_zoo/datasets.py
@@ -0,0 +1,56 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import numpy as np
+import torch
+from torchvision import transforms
+from torchvision.datasets import CIFAR10
+
+
+class Cutout(object):
+ def __init__(self, length):
+ self.length = length
+
+ def __call__(self, img):
+ h, w = img.size(1), img.size(2)
+ mask = np.ones((h, w), np.float32)
+ y = np.random.randint(h)
+ x = np.random.randint(w)
+
+ y1 = np.clip(y - self.length // 2, 0, h)
+ y2 = np.clip(y + self.length // 2, 0, h)
+ x1 = np.clip(x - self.length // 2, 0, w)
+ x2 = np.clip(x + self.length // 2, 0, w)
+
+ mask[y1: y2, x1: x2] = 0.
+ mask = torch.from_numpy(mask)
+ mask = mask.expand_as(img)
+ img *= mask
+
+ return img
+
+
+def get_dataset(cls, cutout_length=0):
+ MEAN = [0.49139968, 0.48215827, 0.44653124]
+ STD = [0.24703233, 0.24348505, 0.26158768]
+ transf = [
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip()
+ ]
+ normalize = [
+ transforms.ToTensor(),
+ transforms.Normalize(MEAN, STD)
+ ]
+ cutout = []
+ if cutout_length > 0:
+ cutout.append(Cutout(cutout_length))
+
+ train_transform = transforms.Compose(transf + normalize + cutout)
+ valid_transform = transforms.Compose(normalize)
+
+ if cls == "cifar10":
+ dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
+ dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform)
+ else:
+ raise NotImplementedError
+ return dataset_train, dataset_valid
diff --git a/examples/nas/search_space_zoo/enas_macro_example.py b/examples/nas/search_space_zoo/enas_macro_example.py
new file mode 100644
index 0000000000..3688a61a16
--- /dev/null
+++ b/examples/nas/search_space_zoo/enas_macro_example.py
@@ -0,0 +1,89 @@
+import torch
+import logging
+import torch.nn as nn
+import torch.nn.functional as F
+
+from argparse import ArgumentParser
+from torchvision import transforms
+from torchvision.datasets import CIFAR10
+
+from nni.nas.pytorch import mutables
+from nni.nas.pytorch import enas
+from utils import accuracy, reward_accuracy
+from nni.nas.pytorch.callbacks import (ArchitectureCheckpoint,
+ LRSchedulerCallback)
+from nni.nas.pytorch.search_space_zoo import ENASMacroLayer
+from nni.nas.pytorch.search_space_zoo import ENASMacroGeneralModel
+
+logger = logging.getLogger('nni')
+
+
+def get_dataset(cls):
+ MEAN = [0.49139968, 0.48215827, 0.44653124]
+ STD = [0.24703233, 0.24348505, 0.26158768]
+ transf = [
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip()
+ ]
+ normalize = [
+ transforms.ToTensor(),
+ transforms.Normalize(MEAN, STD)
+ ]
+
+ train_transform = transforms.Compose(transf + normalize)
+ valid_transform = transforms.Compose(normalize)
+
+ if cls == "cifar10":
+ dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
+ dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform)
+ else:
+ raise NotImplementedError
+ return dataset_train, dataset_valid
+
+
+class FactorizedReduce(nn.Module):
+ def __init__(self, C_in, C_out, affine=False):
+ super().__init__()
+ self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
+ self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
+ self.bn = nn.BatchNorm2d(C_out, affine=affine)
+
+ def forward(self, x):
+ out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
+ out = self.bn(out)
+ return out
+
+
+if __name__ == "__main__":
+ parser = ArgumentParser("enas")
+ parser.add_argument("--batch-size", default=128, type=int)
+ parser.add_argument("--log-frequency", default=10, type=int)
+ # parser.add_argument("--search-for", choices=["macro", "micro"], default="macro")
+ parser.add_argument("--epochs", default=None, type=int, help="Number of epochs (default: macro 310, micro 150)")
+ parser.add_argument("--visualization", default=False, action="store_true")
+ args = parser.parse_args()
+
+ dataset_train, dataset_valid = get_dataset("cifar10")
+ model = ENASMacroGeneralModel()
+ num_epochs = args.epochs or 310
+ mutator = None
+
+ criterion = nn.CrossEntropyLoss()
+ optimizer = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4)
+ lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0.001)
+
+ trainer = enas.EnasTrainer(model,
+ loss=criterion,
+ metrics=accuracy,
+ reward_function=reward_accuracy,
+ optimizer=optimizer,
+ callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")],
+ batch_size=args.batch_size,
+ num_epochs=num_epochs,
+ dataset_train=dataset_train,
+ dataset_valid=dataset_valid,
+ log_frequency=args.log_frequency,
+ mutator=mutator)
+ if args.visualization:
+ trainer.enable_visualization()
+ trainer.train()
diff --git a/examples/nas/search_space_zoo/enas_micro_example.py b/examples/nas/search_space_zoo/enas_micro_example.py
new file mode 100644
index 0000000000..385a19024d
--- /dev/null
+++ b/examples/nas/search_space_zoo/enas_micro_example.py
@@ -0,0 +1,131 @@
+import torch
+import logging
+import torch.nn as nn
+import torch.nn.functional as F
+
+from argparse import ArgumentParser
+from torchvision import transforms
+from torchvision.datasets import CIFAR10
+
+from nni.nas.pytorch import enas
+from utils import accuracy, reward_accuracy
+from nni.nas.pytorch.callbacks import (ArchitectureCheckpoint,
+ LRSchedulerCallback)
+
+from nni.nas.pytorch.search_space_zoo import ENASMicroLayer
+
+logger = logging.getLogger('nni')
+
+
+def get_dataset(cls):
+ MEAN = [0.49139968, 0.48215827, 0.44653124]
+ STD = [0.24703233, 0.24348505, 0.26158768]
+ transf = [
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip()
+ ]
+ normalize = [
+ transforms.ToTensor(),
+ transforms.Normalize(MEAN, STD)
+ ]
+
+ train_transform = transforms.Compose(transf + normalize)
+ valid_transform = transforms.Compose(normalize)
+
+ if cls == "cifar10":
+ dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
+ dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform)
+ else:
+ raise NotImplementedError
+ return dataset_train, dataset_valid
+
+
+class MicroNetwork(nn.Module):
+ def __init__(self, num_layers=2, num_nodes=5, out_channels=24, in_channels=3, num_classes=10,
+ dropout_rate=0.0):
+ super().__init__()
+ self.num_layers = num_layers
+
+ self.stem = nn.Sequential(
+ nn.Conv2d(in_channels, out_channels * 3, 3, 1, 1, bias=False),
+ nn.BatchNorm2d(out_channels * 3)
+ )
+
+ pool_distance = self.num_layers // 3
+ pool_layers = [pool_distance, 2 * pool_distance + 1]
+ self.dropout = nn.Dropout(dropout_rate)
+
+ self.layers = nn.ModuleList()
+ c_pp = c_p = out_channels * 3
+ c_cur = out_channels
+ for layer_id in range(self.num_layers + 2):
+ reduction = False
+ if layer_id in pool_layers:
+ c_cur, reduction = c_p * 2, True
+ self.layers.append(ENASMicroLayer(self.layers, num_nodes, c_pp, c_p, c_cur, reduction))
+ if reduction:
+ c_pp = c_p = c_cur
+ c_pp, c_p = c_p, c_cur
+
+ self.gap = nn.AdaptiveAvgPool2d(1)
+ self.dense = nn.Linear(c_cur, num_classes)
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight)
+
+ def forward(self, x):
+ bs = x.size(0)
+ prev = cur = self.stem(x)
+ # aux_logits = None
+
+ for layer in self.layers:
+ prev, cur = layer(prev, cur)
+
+ cur = self.gap(F.relu(cur)).view(bs, -1)
+ cur = self.dropout(cur)
+ logits = self.dense(cur)
+
+ # if aux_logits is not None:
+ # return logits, aux_logits
+ return logits
+
+
+if __name__ == "__main__":
+ parser = ArgumentParser("enas")
+ parser.add_argument("--batch-size", default=128, type=int)
+ parser.add_argument("--log-frequency", default=10, type=int)
+ # parser.add_argument("--search-for", choices=["macro", "micro"], default="macro")
+ parser.add_argument("--epochs", default=None, type=int, help="Number of epochs (default: macro 310, micro 150)")
+ parser.add_argument("--visualization", default=False, action="store_true")
+ args = parser.parse_args()
+
+ dataset_train, dataset_valid = get_dataset("cifar10")
+
+ model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1)
+ num_epochs = args.epochs or 150
+ mutator = enas.EnasMutator(model, tanh_constant=1.1, cell_exit_extra_step=True)
+
+ criterion = nn.CrossEntropyLoss()
+ optimizer = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4)
+ lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0.001)
+
+ trainer = enas.EnasTrainer(model,
+ loss=criterion,
+ metrics=accuracy,
+ reward_function=reward_accuracy,
+ optimizer=optimizer,
+ callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")],
+ batch_size=args.batch_size,
+ num_epochs=num_epochs,
+ dataset_train=dataset_train,
+ dataset_valid=dataset_valid,
+ log_frequency=args.log_frequency,
+ mutator=mutator)
+ if args.visualization:
+ trainer.enable_visualization()
+ trainer.train()
+
diff --git a/examples/nas/search_space_zoo/utils.py b/examples/nas/search_space_zoo/utils.py
new file mode 100644
index 0000000000..f680db479f
--- /dev/null
+++ b/examples/nas/search_space_zoo/utils.py
@@ -0,0 +1,30 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import torch
+
+
+def accuracy(output, target, topk=(1,)):
+ """ Computes the precision@k for the specified values of k """
+ maxk = max(topk)
+ batch_size = target.size(0)
+
+ _, pred = output.topk(maxk, 1, True, True)
+ pred = pred.t()
+ # one-hot case
+ if target.ndimension() > 1:
+ target = target.max(1)[1]
+
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+ res = dict()
+ for k in topk:
+ correct_k = correct[:k].view(-1).float().sum(0)
+ res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
+ return res
+
+
+def reward_accuracy(output, target, topk=(1,)):
+ batch_size = target.size(0)
+ _, predicted = torch.max(output.data, 1)
+ return (predicted == target).sum().item() / batch_size
diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py b/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py
index 8078425ce1..1a22790fb9 100644
--- a/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py
+++ b/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py
@@ -2,4 +2,4 @@
# Licensed under the MIT license.
from .mutator import DartsMutator
-from .trainer import DartsTrainer
\ No newline at end of file
+from .trainer import DartsTrainer
diff --git a/src/sdk/pynni/nni/nas/pytorch/search_space_zoo/__init__.py b/src/sdk/pynni/nni/nas/pytorch/search_space_zoo/__init__.py
new file mode 100644
index 0000000000..59bb3b78d1
--- /dev/null
+++ b/src/sdk/pynni/nni/nas/pytorch/search_space_zoo/__init__.py
@@ -0,0 +1,4 @@
+from .darts_cell import DartsCell
+from .enas_cell import ENASMicroLayer
+from .enas_cell import ENASMacroLayer
+from .enas_cell import ENASMacroGeneralModel
diff --git a/src/sdk/pynni/nni/nas/pytorch/search_space_zoo/darts_cell.py b/src/sdk/pynni/nni/nas/pytorch/search_space_zoo/darts_cell.py
new file mode 100644
index 0000000000..53fca5940c
--- /dev/null
+++ b/src/sdk/pynni/nni/nas/pytorch/search_space_zoo/darts_cell.py
@@ -0,0 +1,112 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+from nni.nas.pytorch import mutables
+
+from .darts_ops import PoolBN, SepConv, DilConv, FactorizedReduce, DropPath, StdConv
+
+
+class Node(nn.Module):
+ def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
+ """
+ builtin Darts Node structure
+
+ Parameters
+ ---
+ node_id: str
+ num_prev_nodes: int
+ the number of previous nodes in this cell
+ channels: int
+ output channels
+ num_downsample_connect: int
+ downsample the input node if this cell is reduction cell
+ """
+ super().__init__()
+ self.ops = nn.ModuleList()
+ choice_keys = []
+ for i in range(num_prev_nodes):
+ stride = 2 if i < num_downsample_connect else 1
+ choice_keys.append("{}_p{}".format(node_id, i))
+ self.ops.append(
+ mutables.LayerChoice(OrderedDict([
+ ("maxpool", PoolBN('max', channels, 3, stride, 1, affine=False)),
+ ("avgpool", PoolBN('avg', channels, 3, stride, 1, affine=False)),
+ ("skipconnect",
+ nn.Identity() if stride == 1 else FactorizedReduce(channels, channels, affine=False)),
+ ("sepconv3x3", SepConv(channels, channels, 3, stride, 1, affine=False)),
+ ("sepconv5x5", SepConv(channels, channels, 5, stride, 2, affine=False)),
+ ("dilconv3x3", DilConv(channels, channels, 3, stride, 2, 2, affine=False)),
+ ("dilconv5x5", DilConv(channels, channels, 5, stride, 4, 2, affine=False))
+ ]), key=choice_keys[-1]))
+ self.drop_path = DropPath()
+ self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))
+
+ def forward(self, prev_nodes):
+ assert len(self.ops) == len(prev_nodes)
+ out = [op(node) for op, node in zip(self.ops, prev_nodes)]
+ out = [self.drop_path(o) if o is not None else None for o in out]
+ return self.input_switch(out)
+
+
+class DartsCell(nn.Module):
+ """
+ Builtin Darts Cell structure. There are ``n_nodes`` nodes in one cell, in which the first two nodes' values are
+ fixed to the results of previous previous cell and previous cell respectively. One node will connect all
+ the nodes after with predefined operations in a mutable way. The last node accepts five inputs from nodes
+ before and it concats all inputs in channels as the output of the current cell, and the number of output
+ channels is ``n_nodes`` times ``channels``.
+
+ Parameters
+ ---
+ n_nodes: int
+ the number of nodes contained in this cell
+ channels_pp: int
+ the number of previous previous cell's output channels
+ channels_p: int
+ the number of previous cell's output channels
+ channels: int
+ the number of output channels for each node
+ reduction_p: bool
+ Is previous cell a reduction cell
+ reduction: bool
+ is current cell a reduction cell
+ """
+ def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
+ super().__init__()
+ self.reduction = reduction
+ self.n_nodes = n_nodes
+
+ # If previous cell is reduction cell, current input size does not match with
+ # output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing.
+ if reduction_p:
+ self.preproc0 = FactorizedReduce(channels_pp, channels, affine=False)
+ else:
+ self.preproc0 = StdConv(channels_pp, channels, 1, 1, 0, affine=False)
+ self.preproc1 = StdConv(channels_p, channels, 1, 1, 0, affine=False)
+
+ # generate dag
+ self.mutable_ops = nn.ModuleList()
+ for depth in range(2, self.n_nodes + 2):
+ self.mutable_ops.append(Node("{}_n{}".format("reduce" if reduction else "normal", depth),
+ depth, channels, 2 if reduction else 0))
+
+ def forward(self, pprev, prev):
+ """
+ Parameters
+ ---
+ pprev: torch.Tensor
+ the output of the previous previous layer
+ prev: torch.Tensor
+ the output of the previous layer
+ """
+ tensors = [self.preproc0(pprev), self.preproc1(prev)]
+ for node in self.mutable_ops:
+ cur_tensor = node(tensors)
+ tensors.append(cur_tensor)
+
+ output = torch.cat(tensors[2:], dim=1)
+ return output
diff --git a/src/sdk/pynni/nni/nas/pytorch/search_space_zoo/darts_ops.py b/src/sdk/pynni/nni/nas/pytorch/search_space_zoo/darts_ops.py
new file mode 100644
index 0000000000..ce5410cfb4
--- /dev/null
+++ b/src/sdk/pynni/nni/nas/pytorch/search_space_zoo/darts_ops.py
@@ -0,0 +1,196 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import torch
+import torch.nn as nn
+
+
+class DropPath(nn.Module):
+ def __init__(self, p=0.):
+ """
+ Drop path with probability.
+
+ Parameters
+ ----------
+ p : float
+ Probability of an path to be zeroed.
+ """
+ super().__init__()
+ self.p = p
+
+ def forward(self, x):
+ if self.training and self.p > 0.:
+ keep_prob = 1. - self.p
+ # per data point mask
+ mask = torch.zeros((x.size(0), 1, 1, 1), device=x.device).bernoulli_(keep_prob)
+ return x / keep_prob * mask
+
+ return x
+
+
+class PoolBN(nn.Module):
+ """
+ AvgPool or MaxPool with BN. ``pool_type`` must be ``max`` or ``avg``.
+
+ Parameters
+ ---
+ pool_type: str
+ choose operation
+ C: int
+ number of channels
+ kernal_size: int
+ size of the convolving kernel
+ stride: int
+ stride of the convolution
+ padding: int
+ zero-padding added to both sides of the input
+ affine: bool
+ is using affine in BatchNorm
+ """
+
+ def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
+ super().__init__()
+ if pool_type.lower() == 'max':
+ self.pool = nn.MaxPool2d(kernel_size, stride, padding)
+ elif pool_type.lower() == 'avg':
+ self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
+ else:
+ raise ValueError()
+
+ self.bn = nn.BatchNorm2d(C, affine=affine)
+
+ def forward(self, x):
+ out = self.pool(x)
+ out = self.bn(out)
+ return out
+
+
+class StdConv(nn.Sequential):
+ """
+ Standard conv: ReLU - Conv - BN
+
+ Parameters
+ ---
+ C_in: int
+ the number of input channels
+ C_out: int
+ the number of output channels
+ kernel_size: int
+ size of the convolution kernel
+ padding:
+ zero-padding added to both sides of the input
+ affine: bool
+ is using affine in BatchNorm
+ """
+
+ def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
+ super().__init__()
+ self.net = nn.Sequential
+ for idx, ops in enumerate((nn.ReLU(), nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False),
+ nn.BatchNorm2d(C_out, affine=affine))):
+ self.add_module(str(idx), ops)
+
+
+class FacConv(nn.Module):
+ """
+ Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN
+ """
+
+ def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
+ super().__init__()
+ self.net = nn.Sequential(
+ nn.ReLU(),
+ nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False),
+ nn.Conv2d(C_in, C_out, (1, kernel_length), stride, padding, bias=False),
+ nn.BatchNorm2d(C_out, affine=affine)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class DilConv(nn.Module):
+ """
+ (Dilated) depthwise separable conv.
+ ReLU - (Dilated) depthwise separable - Pointwise - BN.
+ If dilation == 2, 3x3 conv => 5x5 receptive field, 5x5 conv => 9x9 receptive field.
+
+ Parameters
+ ---
+ C_in: int
+ the number of input channels
+ C_out: int
+ the number of output channels
+ kernal_size:
+ size of the convolving kernel
+ padding:
+ zero-padding added to both sides of the input
+ dilation: int
+ spacing between kernel elements.
+ affine: bool
+ is using affine in BatchNorm
+ """
+
+ def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
+ super().__init__()
+ self.net = nn.Sequential(
+ nn.ReLU(),
+ nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in,
+ bias=False),
+ nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
+ nn.BatchNorm2d(C_out, affine=affine)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class SepConv(nn.Module):
+ """
+ Depthwise separable conv.
+ DilConv(dilation=1) * 2.
+
+ Parameters
+ ---
+ C_in: int
+ the number of input channels
+ C_out: int
+ the number of output channels
+ kernal_size:
+ size of the convolving kernel
+ padding:
+ zero-padding added to both sides of the input
+ dilation: int
+ spacing between kernel elements.
+ affine: bool
+ is using affine in BatchNorm
+ """
+
+ def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
+ super().__init__()
+ self.net = nn.Sequential(
+ DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine),
+ DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class FactorizedReduce(nn.Module):
+ """
+ Reduce feature map size by factorized pointwise (stride=2).
+ """
+
+ def __init__(self, C_in, C_out, affine=True):
+ super().__init__()
+ self.relu = nn.ReLU()
+ self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
+ self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
+ self.bn = nn.BatchNorm2d(C_out, affine=affine)
+
+ def forward(self, x):
+ x = self.relu(x)
+ out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
+ out = self.bn(out)
+ return out
diff --git a/src/sdk/pynni/nni/nas/pytorch/search_space_zoo/enas_cell.py b/src/sdk/pynni/nni/nas/pytorch/search_space_zoo/enas_cell.py
new file mode 100644
index 0000000000..ef3de84385
--- /dev/null
+++ b/src/sdk/pynni/nni/nas/pytorch/search_space_zoo/enas_cell.py
@@ -0,0 +1,256 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from nni.nas.pytorch import mutables
+from .enas_ops import FactorizedReduce, StdConv, SepConvBN, Pool, ConvBranch, PoolBranch
+
+
+class Cell(nn.Module):
+ def __init__(self, cell_name, prev_labels, channels):
+ super().__init__()
+ self.input_choice = mutables.InputChoice(choose_from=prev_labels, n_chosen=1, return_mask=True,
+ key=cell_name + "_input")
+ self.op_choice = mutables.LayerChoice([
+ SepConvBN(channels, channels, 3, 1),
+ SepConvBN(channels, channels, 5, 2),
+ Pool("avg", 3, 1, 1),
+ Pool("max", 3, 1, 1),
+ nn.Identity()
+ ], key=cell_name + "_op")
+
+ def forward(self, prev_layers):
+ chosen_input, chosen_mask = self.input_choice(prev_layers)
+ cell_out = self.op_choice(chosen_input)
+ return cell_out, chosen_mask
+
+
+class Node(mutables.MutableScope):
+ def __init__(self, node_name, prev_node_names, channels):
+ super().__init__(node_name)
+ self.cell_x = Cell(node_name + "_x", prev_node_names, channels)
+ self.cell_y = Cell(node_name + "_y", prev_node_names, channels)
+
+ def forward(self, prev_layers):
+ out_x, mask_x = self.cell_x(prev_layers)
+ out_y, mask_y = self.cell_y(prev_layers)
+ return out_x + out_y, mask_x | mask_y
+
+
+class Calibration(nn.Module):
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.process = None
+ if in_channels != out_channels:
+ self.process = StdConv(in_channels, out_channels)
+
+ def forward(self, x):
+ if self.process is None:
+ return x
+ return self.process(x)
+
+
+class ENASMicroLayer(nn.Module):
+ """
+ Builtin EnasMicroLayer. Micro search designs only one building block whose architecture is repeated
+ throughout the final architecture. A cell has ``num_nodes`` nodes and searches the topology and
+ operations among them in RL way. The first two nodes in a layer stand for the outputs from previous
+ previous layer and previous layer respectively. For the following nodes, the controller chooses
+ two previous nodes and applies two operations respectively for each node. Nodes that are not served
+ as input for any other node are viewed as the output of the layer. If there are multiple output nodes,
+ the model will calculate the average of these nodes as the layer output. Every node's output has ``out_channels``
+ channels so the result of the layer has the same number of channels as each node.
+
+ Parameters
+ ---
+ num_nodes: int
+ the number of nodes contained in this layer
+ in_channles_pp: int
+ the number of previous previous layer's output channels
+ in_channels_p: int
+ the number of previous layer's output channels
+ out_channels: int
+ output channels of this layer
+ reduction: bool
+ is reduction operation empolyed before this layer
+ """
+ def __init__(self, num_nodes, in_channels_pp, in_channels_p, out_channels, reduction):
+ super().__init__()
+ print(in_channels_pp, in_channels_p, out_channels, reduction)
+ self.reduction = reduction
+ if self.reduction:
+ self.reduce0 = FactorizedReduce(in_channels_pp, out_channels, affine=False)
+ self.reduce1 = FactorizedReduce(in_channels_p, out_channels, affine=False)
+ in_channels_pp = in_channels_p = out_channels
+ self.preproc0 = Calibration(in_channels_pp, out_channels)
+ self.preproc1 = Calibration(in_channels_p, out_channels)
+
+ self.num_nodes = num_nodes
+ name_prefix = "reduce" if reduction else "normal"
+ self.nodes = nn.ModuleList()
+ node_labels = [mutables.InputChoice.NO_KEY, mutables.InputChoice.NO_KEY]
+ for i in range(num_nodes):
+ node_labels.append("{}_node_{}".format(name_prefix, i))
+ self.nodes.append(Node(node_labels[-1], node_labels[:-1], out_channels))
+ self.final_conv_w = nn.Parameter(torch.zeros(out_channels, self.num_nodes + 2, out_channels, 1, 1),
+ requires_grad=True)
+ self.bn = nn.BatchNorm2d(out_channels, affine=False)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.kaiming_normal_(self.final_conv_w)
+
+ def forward(self, pprev, prev):
+ """
+ Parameters
+ ---
+ pprev: torch.Tensor
+ the output of the previous previous layer
+ prev: torch.Tensor
+ the output of the previous previous layer
+ """
+ if self.reduction:
+ pprev, prev = self.reduce0(pprev), self.reduce1(prev)
+ pprev_, prev_ = self.preproc0(pprev), self.preproc1(prev)
+
+ prev_nodes_out = [pprev_, prev_]
+ nodes_used_mask = torch.zeros(self.num_nodes + 2, dtype=torch.bool, device=prev.device)
+ for i in range(self.num_nodes):
+ node_out, mask = self.nodes[i](prev_nodes_out)
+ nodes_used_mask[:mask.size(0)] |= mask.to(node_out.device)
+ prev_nodes_out.append(node_out)
+
+ unused_nodes = torch.cat([out for used, out in zip(nodes_used_mask, prev_nodes_out) if not used], 1)
+ unused_nodes = F.relu(unused_nodes)
+ conv_weight = self.final_conv_w[:, ~nodes_used_mask, :, :, :]
+ conv_weight = conv_weight.view(conv_weight.size(0), -1, 1, 1)
+ out = F.conv2d(unused_nodes, conv_weight)
+ return prev, self.bn(out)
+
+
+class ENASMacroLayer(mutables.MutableScope):
+ """
+ Builtin ENAS Marco Layer. With search space changing to layer level, the controller decides
+ what operation is employed and the previous layer to connect to for skip connections. The model
+ is made up of the same layers but the choice of each layer may be different.
+
+ Parameters
+ ---
+ key: str
+ the name of this layer
+ prev_labels: str
+ names of all previous layers
+ in_filters: int
+ the number of input channels
+ out_filters:
+ the number of output channels
+ """
+ def __init__(self, key, prev_labels, in_filters, out_filters):
+ super().__init__(key)
+ self.in_filters = in_filters
+ self.out_filters = out_filters
+ self.mutable = mutables.LayerChoice([
+ ConvBranch(in_filters, out_filters, 3, 1, 1, separable=False),
+ ConvBranch(in_filters, out_filters, 3, 1, 1, separable=True),
+ ConvBranch(in_filters, out_filters, 5, 1, 2, separable=False),
+ ConvBranch(in_filters, out_filters, 5, 1, 2, separable=True),
+ PoolBranch('avg', in_filters, out_filters, 3, 1, 1),
+ PoolBranch('max', in_filters, out_filters, 3, 1, 1)
+ ])
+ if prev_labels > 0:
+ self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None)
+ else:
+ self.skipconnect = None
+ self.batch_norm = nn.BatchNorm2d(out_filters, affine=False)
+
+ def forward(self, prev_list):
+ """
+ Parameters
+ ---
+ prev_list: list
+ The cell selects the last element of the list as input and applies an operation on it.
+ The cell chooses none/one/multiple tensor(s) as SkipConnect(s) from the list excluding
+ the last element.
+ """
+ out = self.mutable(prev_list[-1])
+ if self.skipconnect is not None:
+ connection = self.skipconnect(prev_list[:-1])
+ if connection is not None:
+ out += connection
+ return self.batch_norm(out)
+
+
+class ENASMacroGeneralModel(nn.Module):
+ """
+ The network is made up by stacking ENASMacroLayer. The Macro search space contains these layers.
+ Each layer chooses an operation from predefined ones and SkipConnect then forms a network.
+
+ Parameters
+ ---
+ num_layers: int
+ The number of layers contained in the network.
+ out_filters: int
+ The number of each layer's output channels.
+ in_channel: int
+ The number of input's channels.
+ num_classes: int
+ The number of classes for classification.
+ dropout_rate: float
+ Dropout layer's dropout rate before the final dense layer.
+ """
+ def __init__(self, num_layers=12, out_filters=24, in_channels=3, num_classes=10,
+ dropout_rate=0.0):
+ super().__init__()
+ self.num_layers = num_layers
+ self.num_classes = num_classes
+ self.out_filters = out_filters
+
+ self.stem = nn.Sequential(
+ nn.Conv2d(in_channels, out_filters, 3, 1, 1, bias=False),
+ nn.BatchNorm2d(out_filters)
+ )
+
+ pool_distance = self.num_layers // 3
+ self.pool_layers_idx = [pool_distance - 1, 2 * pool_distance - 1]
+ self.dropout_rate = dropout_rate
+ self.dropout = nn.Dropout(self.dropout_rate)
+
+ self.layers = nn.ModuleList()
+ self.pool_layers = nn.ModuleList()
+ labels = []
+ for layer_id in range(self.num_layers):
+ labels.append("layer_{}".format(layer_id))
+ if layer_id in self.pool_layers_idx:
+ self.pool_layers.append(FactorizedReduce(self.out_filters, self.out_filters))
+ self.layers.append(ENASMacroLayer(labels[-1], labels[:-1], self.out_filters, self.out_filters))
+
+ self.gap = nn.AdaptiveAvgPool2d(1)
+ self.dense = nn.Linear(self.out_filters, self.num_classes)
+
+ def forward(self, x):
+ """
+ Parameters
+ ---
+ x: torch.Tensor
+ the input of the network
+ """
+ bs = x.size(0)
+ cur = self.stem(x)
+
+ layers = [cur]
+
+ for layer_id in range(self.num_layers):
+ cur = self.layers[layer_id](layers)
+ layers.append(cur)
+ if layer_id in self.pool_layers_idx:
+ for i, layer in enumerate(layers):
+ layers[i] = self.pool_layers[self.pool_layers_idx.index(layer_id)](layer)
+ cur = layers[-1]
+
+ cur = self.gap(cur).view(bs, -1)
+ cur = self.dropout(cur)
+ logits = self.dense(cur)
+ return logits
diff --git a/src/sdk/pynni/nni/nas/pytorch/search_space_zoo/enas_ops.py b/src/sdk/pynni/nni/nas/pytorch/search_space_zoo/enas_ops.py
new file mode 100644
index 0000000000..21ecc2da79
--- /dev/null
+++ b/src/sdk/pynni/nni/nas/pytorch/search_space_zoo/enas_ops.py
@@ -0,0 +1,171 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import torch
+import torch.nn as nn
+
+
+class StdConv(nn.Module):
+ def __init__(self, C_in, C_out):
+ super(StdConv, self).__init__()
+ self.conv = nn.Sequential(
+ nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
+ nn.BatchNorm2d(C_out, affine=False),
+ nn.ReLU()
+ )
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+class PoolBranch(nn.Module):
+ """
+ Pooling structure for Macro search. First pass through a 1x1 Conv, then pooling operation followed by BatchNorm2d.
+
+ Parameters
+ ---
+ pool_type: str
+ only accept ``max`` for MaxPool and ``avg`` for AvgPool
+ C_in: int
+ the number of input channels
+ C_out: int
+ the number of output channels
+ kernal_size: int
+ size of the convolving kernel
+ stride: int
+ stride of the convolution
+ padding: int
+ zero-padding added to both sides of the input
+ """
+ def __init__(self, pool_type, C_in, C_out, kernel_size, stride, padding, affine=False):
+ super().__init__()
+ self.preproc = StdConv(C_in, C_out)
+ self.pool = Pool(pool_type, kernel_size, stride, padding)
+ self.bn = nn.BatchNorm2d(C_out, affine=affine)
+
+ def forward(self, x):
+ out = self.preproc(x)
+ out = self.pool(out)
+ out = self.bn(out)
+ return out
+
+
+class SeparableConv(nn.Module):
+ def __init__(self, C_in, C_out, kernel_size, stride, padding):
+ super(SeparableConv, self).__init__()
+ self.depthwise = nn.Conv2d(C_in, C_in, kernel_size=kernel_size, padding=padding, stride=stride,
+ groups=C_in, bias=False)
+ self.pointwise = nn.Conv2d(C_in, C_out, kernel_size=1, bias=False)
+
+ def forward(self, x):
+ out = self.depthwise(x)
+ out = self.pointwise(out)
+ return out
+
+
+class ConvBranch(nn.Module):
+ """
+ Conv structure for Macro search. First pass through a 1x1 Conv,
+ then Conv operation with kernal_size equals 3 or 5 followed by BatchNorm and ReLU.
+
+ Parameters
+ ---
+ C_in: int
+ the number of input channels
+ C_out: int
+ the number of output channels
+ kernal_size: int
+ size of the convolving kernel
+ stride: int
+ stride of the convolution
+ padding: int
+ zero-padding added to both sides of the input
+ separable: True
+ is separable Conv is used
+ """
+ def __init__(self, C_in, C_out, kernel_size, stride, padding, separable):
+ super(ConvBranch, self).__init__()
+ self.preproc = StdConv(C_in, C_out)
+ if separable:
+ self.conv = SeparableConv(C_out, C_out, kernel_size, stride, padding)
+ else:
+ self.conv = nn.Conv2d(C_out, C_out, kernel_size, stride=stride, padding=padding)
+ self.postproc = nn.Sequential(
+ nn.BatchNorm2d(C_out, affine=False),
+ nn.ReLU()
+ )
+
+ def forward(self, x):
+ out = self.preproc(x)
+ out = self.conv(out)
+ out = self.postproc(out)
+ return out
+
+
+class FactorizedReduce(nn.Module):
+ def __init__(self, C_in, C_out, affine=False):
+ super().__init__()
+ self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
+ self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
+ self.bn = nn.BatchNorm2d(C_out, affine=affine)
+
+ def forward(self, x):
+ out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
+ out = self.bn(out)
+ return out
+
+
+class Pool(nn.Module):
+ """
+ Pooling structure
+
+ Parameters
+ ---
+ pool_type: str
+ only accept ``max`` for MaxPool and ``avg`` for AvgPool
+ kernal_size: int
+ size of the convolving kernel
+ stride: int
+ stride of the convolution
+ padding: int
+ zero-padding added to both sides of the input
+ """
+ def __init__(self, pool_type, kernel_size, stride, padding):
+ super().__init__()
+ if pool_type.lower() == 'max':
+ self.pool = nn.MaxPool2d(kernel_size, stride, padding)
+ elif pool_type.lower() == 'avg':
+ self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
+ else:
+ raise ValueError()
+
+ def forward(self, x):
+ return self.pool(x)
+
+
+class SepConvBN(nn.Module):
+ """
+ Implement SepConv followed by BatchNorm. The structure is ReLU ==> SepConv ==> BN.
+
+ Parameters
+ ---
+ C_in: int
+ the number of imput channels
+ C_out: int
+ the number of output channels
+ kernal_size: int
+ size of the convolving kernel
+ padding: int
+ zero-padding added to both sides of the input
+ """
+ def __init__(self, C_in, C_out, kernel_size, padding):
+ super().__init__()
+ self.relu = nn.ReLU()
+ self.conv = SeparableConv(C_in, C_out, kernel_size, 1, padding)
+ self.bn = nn.BatchNorm2d(C_out, affine=True)
+
+ def forward(self, x):
+ x = self.relu(x)
+ x = self.conv(x)
+ x = self.bn(x)
+ return x