diff --git a/.gitignore b/.gitignore
index e96b14efc6..83049a476e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -80,6 +80,7 @@ venv.bak/
 
 # VSCode
 .vscode
+.vs
 
 # In case you place source code in ~/nni/
 /experiments
diff --git a/docs/en_US/AdvancedFeature/MultiPhase.md b/docs/en_US/AdvancedFeature/MultiPhase.md
index c9727bcdcc..4cdb3a7a99 100644
--- a/docs/en_US/AdvancedFeature/MultiPhase.md
+++ b/docs/en_US/AdvancedFeature/MultiPhase.md
@@ -79,7 +79,7 @@ With this information, the tuner could know which trial is requesting a configur
 
 ### Tuners support multi-phase experiments:
 
-[TPE](../Tuner/HyperoptTuner.md), [Random](../Tuner/HyperoptTuner.md), [Anneal](../Tuner/HyperoptTuner.md), [Evolution](../Tuner/EvolutionTuner.md), [SMAC](../Tuner/SmacTuner.md), [NetworkMorphism](../Tuner/NetworkmorphismTuner.md), [MetisTuner](../Tuner/MetisTuner.md), [BOHB](../Tuner/BohbAdvisor.md), [Hyperband](../Tuner/HyperbandAdvisor.md), [ENAS tuner](https://github.com/countif/enas_nni/blob/master/nni/examples/tuners/enas/nni_controller_ptb.py).
+[TPE](../Tuner/HyperoptTuner.md), [Random](../Tuner/HyperoptTuner.md), [Anneal](../Tuner/HyperoptTuner.md), [Evolution](../Tuner/EvolutionTuner.md), [SMAC](../Tuner/SmacTuner.md), [NetworkMorphism](../Tuner/NetworkmorphismTuner.md), [MetisTuner](../Tuner/MetisTuner.md), [BOHB](../Tuner/BohbAdvisor.md), [Hyperband](../Tuner/HyperbandAdvisor.md).
 
 ### Training services support multi-phase experiment:
 [Local Machine](../TrainingService/LocalMode.md), [Remote Servers](../TrainingService/RemoteMachineMode.md), [OpenPAI](../TrainingService/PaiMode.md)
diff --git a/docs/en_US/NAS/Overview.md b/docs/en_US/NAS/Overview.md
new file mode 100644
index 0000000000..4e48483df3
--- /dev/null
+++ b/docs/en_US/NAS/Overview.md
@@ -0,0 +1,77 @@
+# Neural Architecture Search (NAS) on NNI
+
+Automatic neural architecture search plays an increasingly important role in finding better models. Recent research has proved the feasibility of automatic NAS and has produced models that beat manually designed and tuned ones. Representative works include [NASNet][2], [ENAS][1], [DARTS][3], [Network Morphism][4], and [Evolution][5]; new innovations keep emerging.
+
+However, it takes great effort to implement a NAS algorithm, and it is hard to reuse the code base of an existing algorithm in a new one. To facilitate NAS innovation (e.g., designing and implementing new NAS models, comparing different NAS models side by side), an easy-to-use and flexible programming interface is crucial.
+
+With this motivation, our ambition is to provide a unified architecture in NNI to accelerate innovation on NAS and to apply state-of-the-art algorithms to real-world problems faster.
+
+## Supported algorithms
+
+NNI currently supports the NAS algorithms listed below, and more are being added. Users can reproduce an algorithm or run it on their own datasets. We also encourage users to implement other algorithms with the [NNI API](#use-nni-api) to benefit more people.
+
+Note that these algorithms run standalone without nnictl and support PyTorch only.
+
+### Dependencies
+
+* Install the latest NNI
+* PyTorch 1.2+
+* git
+
+### DARTS
+
+The main algorithmic contribution of [DARTS: Differentiable Architecture Search][3] is a novel method for differentiable network architecture search based on bilevel optimization.
+
+#### Usage
+
+```bash
+# In case the NNI code is not cloned yet. If it is already cloned, skip this line and enter the code folder.
+git clone https://github.com/Microsoft/nni.git
+
+# search the best architecture
+cd examples/nas/darts
+python3 search.py
+
+# train the best architecture
+python3 retrain.py --arc-checkpoint ./checkpoints/epoch_49.json
+```
+
+### P-DARTS
+
+[Progressive Differentiable Architecture Search: Bridging the Depth Gap between Search and Evaluation](https://arxiv.org/abs/1904.12760) is based on [DARTS](#DARTS). Its algorithmic contribution is an efficient approach that allows the depth of the searched architecture to grow gradually during the training procedure.
+
+#### Usage
+
+```bash
+# In case the NNI code is not cloned yet. If it is already cloned, skip this line and enter the code folder.
+git clone https://github.com/Microsoft/nni.git
+
+# search the best architecture
+cd examples/nas/pdarts
+python3 search.py
+
+# train the best architecture; the process is the same as for DARTS
+cd examples/nas/darts
+python3 retrain.py --arc-checkpoint ./checkpoints/epoch_2.json
+```
+
+## Use NNI API
+
+NOTE: we are trying to support various NAS algorithms with a unified programming interface, and it is still at a very experimental stage, which means the current programming interface may change significantly.
+
+*The previous [NAS annotation](../AdvancedFeature/GeneralNasInterfaces.md) interface will be deprecated soon.*
+
+### Programming interface
+
+A programming interface for designing and searching a model is usually needed in two scenarios.
+
+1. When designing a neural network, there may be multiple candidate operations for a layer, sub-model, or connection, and it is undetermined which one or which combination performs best. An easy way to express the candidate layers or sub-models is therefore needed.
+2. When applying NAS to a neural network, a unified way to express the architecture search space is needed, so that trial code does not have to be updated for different search algorithms.
+
+NNI's proposed API is [here](https://github.com/microsoft/nni/tree/master/src/sdk/pynni/nni/nas/pytorch), and [here](https://github.com/microsoft/nni/tree/master/examples/nas/darts) is an example NAS implementation based on this interface.
+
+[1]: https://arxiv.org/abs/1802.03268
+[2]: https://arxiv.org/abs/1707.07012
+[3]: https://arxiv.org/abs/1806.09055
+[4]: https://arxiv.org/abs/1806.10282
+[5]: https://arxiv.org/abs/1703.01041
diff --git a/docs/en_US/Tutorial/SearchSpaceSpec.md b/docs/en_US/Tutorial/SearchSpaceSpec.md
index fd1781716f..eb5d39315c 100644
--- a/docs/en_US/Tutorial/SearchSpaceSpec.md
+++ b/docs/en_US/Tutorial/SearchSpaceSpec.md
@@ -73,12 +73,6 @@ All types of sampling strategies and their parameter are listed here:
   * Which means the variable value is a value like `round(exp(normal(mu, sigma)) / q) * q`
   * Suitable for a discrete variable with respect to which the objective is smooth and gets smoother with the size of the variable, which is bounded from one side.
 
-* `{"_type": "mutable_layer", "_value": {mutable_layer_infomation}}`
-  * Type for [Neural Architecture Search Space][1]. Value is also a dictionary, which contains key-value pairs representing respectively name and search space of each mutable_layer.
-  * For now, users can only use this type of search space with annotation, which means that there is no need to define a json file for search space since it will be automatically generated according to the annotation in trial code.
- * The following HPO tuners can be adapted to tune this search space: TPE, Random, Anneal, Evolution, Grid Search, - Hyperband and BOHB. - * For detailed usage, please refer to [General NAS Interfaces][1]. ## Search Space Types Supported by Each Tuner @@ -105,5 +99,3 @@ Known Limitations: * Only Random Search/TPE/Anneal/Evolution tuner supports nested search space * We do not support nested search space "Hyper Parameter" in visualization now, the enhancement is being considered in [#1110](https://github.com/microsoft/nni/issues/1110), any suggestions or discussions or contributions are warmly welcomed - -[1]: ../AdvancedFeature/GeneralNasInterfaces.md diff --git a/docs/en_US/advanced.rst b/docs/en_US/advanced.rst index d9192cc869..e38f634969 100644 --- a/docs/en_US/advanced.rst +++ b/docs/en_US/advanced.rst @@ -3,5 +3,3 @@ Advanced Features .. toctree:: MultiPhase<./AdvancedFeature/MultiPhase> - AdvancedNas<./AdvancedFeature/AdvancedNas> - NAS Programming Interface<./AdvancedFeature/GeneralNasInterfaces> \ No newline at end of file diff --git a/examples/nas/.gitignore b/examples/nas/.gitignore new file mode 100644 index 0000000000..8eeb0c2a3f --- /dev/null +++ b/examples/nas/.gitignore @@ -0,0 +1,3 @@ +data +checkpoints +runs diff --git a/examples/nas/darts/datasets.py b/examples/nas/darts/datasets.py new file mode 100644 index 0000000000..c5861f16d3 --- /dev/null +++ b/examples/nas/darts/datasets.py @@ -0,0 +1,53 @@ +import numpy as np +import torch +from torchvision import transforms +from torchvision.datasets import CIFAR10 + + +class Cutout(object): + def __init__(self, length): + self.length = length + + def __call__(self, img): + h, w = img.size(1), img.size(2) + mask = np.ones((h, w), np.float32) + y = np.random.randint(h) + x = np.random.randint(w) + + y1 = np.clip(y - self.length // 2, 0, h) + y2 = np.clip(y + self.length // 2, 0, h) + x1 = np.clip(x - self.length // 2, 0, w) + x2 = np.clip(x + self.length // 2, 0, w) + + mask[y1: y2, x1: x2] = 0. 
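+ # turn the numpy mask into a tensor and broadcast it over all channels to zero out the selected square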
+ mask = torch.from_numpy(mask) + mask = mask.expand_as(img) + img *= mask + + return img + + +def get_dataset(cls, cutout_length=0): + MEAN = [0.49139968, 0.48215827, 0.44653124] + STD = [0.24703233, 0.24348505, 0.26158768] + transf = [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip() + ] + normalize = [ + transforms.ToTensor(), + transforms.Normalize(MEAN, STD) + ] + cutout = [] + if cutout_length > 0: + cutout.append(Cutout(cutout_length)) + + train_transform = transforms.Compose(transf + normalize + cutout) + valid_transform = transforms.Compose(normalize) + + if cls == "cifar10": + dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform) + dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform) + else: + raise NotImplementedError + return dataset_train, dataset_valid diff --git a/examples/nas/darts/model.py b/examples/nas/darts/model.py new file mode 100644 index 0000000000..a6a86f4a72 --- /dev/null +++ b/examples/nas/darts/model.py @@ -0,0 +1,157 @@ +import torch +import torch.nn as nn + +import ops +from nni.nas.pytorch import mutables + + +class AuxiliaryHead(nn.Module): + """ Auxiliary head in 2/3 place of network to let the gradient flow well """ + + def __init__(self, input_size, C, n_classes): + """ assuming input size 7x7 or 8x8 """ + assert input_size in [7, 8] + super().__init__() + self.net = nn.Sequential( + nn.ReLU(inplace=True), + nn.AvgPool2d(5, stride=input_size - 5, padding=0, count_include_pad=False), # 2x2 out + nn.Conv2d(C, 128, kernel_size=1, bias=False), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.Conv2d(128, 768, kernel_size=2, bias=False), # 1x1 out + nn.BatchNorm2d(768), + nn.ReLU(inplace=True) + ) + self.linear = nn.Linear(768, n_classes) + + def forward(self, x): + out = self.net(x) + out = out.view(out.size(0), -1) # flatten + logits = self.linear(out) + return logits + + +class Node(nn.Module): + def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect): + super().__init__() + self.ops = nn.ModuleList() + choice_keys = [] + for i in range(num_prev_nodes): + stride = 2 if i < num_downsample_connect else 1 + choice_keys.append("{}_p{}".format(node_id, i)) + self.ops.append( + mutables.LayerChoice( + [ + ops.PoolBN('max', channels, 3, stride, 1, affine=False), + ops.PoolBN('avg', channels, 3, stride, 1, affine=False), + nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False), + ops.SepConv(channels, channels, 3, stride, 1, affine=False), + ops.SepConv(channels, channels, 5, stride, 2, affine=False), + ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False), + ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False) + ], + key=choice_keys[-1])) + self.drop_path = ops.DropPath_() + self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id)) + + def forward(self, prev_nodes): + assert len(self.ops) == len(prev_nodes) + out = [op(node) for op, node in zip(self.ops, prev_nodes)] + out = [self.drop_path(o) if o is not None else None for o in out] + return self.input_switch(out) + + +class Cell(nn.Module): + + def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction): + super().__init__() + self.reduction = reduction + self.n_nodes = n_nodes + + # If previous cell is reduction cell, current input size does not match with + # output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing. 
+ if reduction_p: + self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False) + else: + self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False) + self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False) + + # generate dag + self.mutable_ops = nn.ModuleList() + for depth in range(2, self.n_nodes + 2): + self.mutable_ops.append(Node("{}_n{}".format("reduce" if reduction else "normal", depth), + depth, channels, 2 if reduction else 0)) + + def forward(self, s0, s1): + # s0, s1 are the outputs of previous previous cell and previous cell, respectively. + tensors = [self.preproc0(s0), self.preproc1(s1)] + for node in self.mutable_ops: + cur_tensor = node(tensors) + tensors.append(cur_tensor) + + output = torch.cat(tensors[2:], dim=1) + return output + + +class CNN(nn.Module): + + def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4, + stem_multiplier=3, auxiliary=False): + super().__init__() + self.in_channels = in_channels + self.channels = channels + self.n_classes = n_classes + self.n_layers = n_layers + self.aux_pos = 2 * n_layers // 3 if auxiliary else -1 + + c_cur = stem_multiplier * self.channels + self.stem = nn.Sequential( + nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False), + nn.BatchNorm2d(c_cur) + ) + + # for the first cell, stem is used for both s0 and s1 + # [!] channels_pp and channels_p is output channel size, but c_cur is input channel size. + channels_pp, channels_p, c_cur = c_cur, c_cur, channels + + self.cells = nn.ModuleList() + reduction_p, reduction = False, False + for i in range(n_layers): + reduction_p, reduction = reduction, False + # Reduce featuremap size and double channels in 1/3 and 2/3 layer. + if i in [n_layers // 3, 2 * n_layers // 3]: + c_cur *= 2 + reduction = True + + cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction) + self.cells.append(cell) + c_cur_out = c_cur * n_nodes + channels_pp, channels_p = channels_p, c_cur_out + + if i == self.aux_pos: + self.aux_head = AuxiliaryHead(input_size // 4, channels_p, n_classes) + + self.gap = nn.AdaptiveAvgPool2d(1) + self.linear = nn.Linear(channels_p, n_classes) + + def forward(self, x): + s0 = s1 = self.stem(x) + + aux_logits = None + for i, cell in enumerate(self.cells): + s0, s1 = s1, cell(s0, s1) + if i == self.aux_pos and self.training: + aux_logits = self.aux_head(s1) + + out = self.gap(s1) + out = out.view(out.size(0), -1) # flatten + logits = self.linear(out) + + if aux_logits is not None: + return logits, aux_logits + return logits + + def drop_path_prob(self, p): + for module in self.modules(): + if isinstance(module, ops.DropPath_): + module.p = p diff --git a/examples/nas/darts/ops.py b/examples/nas/darts/ops.py new file mode 100644 index 0000000000..9b74c346f9 --- /dev/null +++ b/examples/nas/darts/ops.py @@ -0,0 +1,136 @@ +import torch +import torch.nn as nn + + +class DropPath_(nn.Module): + def __init__(self, p=0.): + """ + DropPath is inplace module. + + Parameters + ---------- + p : float + Probability of an path to be zeroed. + """ + super().__init__() + self.p = p + + def extra_repr(self): + return 'p={}, inplace'.format(self.p) + + def forward(self, x): + if self.training and self.p > 0.: + keep_prob = 1. - self.p + # per data point mask + mask = torch.zeros((x.size(0), 1, 1, 1), device=x.device).bernoulli_(keep_prob) + x.div_(keep_prob).mul_(mask) + + return x + + +class PoolBN(nn.Module): + """ + AvgPool or MaxPool with BN. `pool_type` must be `max` or `avg`. 
+ """ + def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True): + super().__init__() + if pool_type.lower() == 'max': + self.pool = nn.MaxPool2d(kernel_size, stride, padding) + elif pool_type.lower() == 'avg': + self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False) + else: + raise ValueError() + + self.bn = nn.BatchNorm2d(C, affine=affine) + + def forward(self, x): + out = self.pool(x) + out = self.bn(out) + return out + + +class StdConv(nn.Module): + """ + Standard conv: ReLU - Conv - BN + """ + def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): + super().__init__() + self.net = nn.Sequential( + nn.ReLU(), + nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False), + nn.BatchNorm2d(C_out, affine=affine) + ) + + def forward(self, x): + return self.net(x) + + +class FacConv(nn.Module): + """ + Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN + """ + def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True): + super().__init__() + self.net = nn.Sequential( + nn.ReLU(), + nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False), + nn.Conv2d(C_in, C_out, (1, kernel_length), stride, padding, bias=False), + nn.BatchNorm2d(C_out, affine=affine) + ) + + def forward(self, x): + return self.net(x) + + +class DilConv(nn.Module): + """ + (Dilated) depthwise separable conv. + ReLU - (Dilated) depthwise separable - Pointwise - BN. + If dilation == 2, 3x3 conv => 5x5 receptive field, 5x5 conv => 9x9 receptive field. + """ + def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True): + super().__init__() + self.net = nn.Sequential( + nn.ReLU(), + nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in, + bias=False), + nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(C_out, affine=affine) + ) + + def forward(self, x): + return self.net(x) + + +class SepConv(nn.Module): + """ + Depthwise separable conv. + DilConv(dilation=1) * 2. + """ + def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): + super().__init__() + self.net = nn.Sequential( + DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine), + DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine) + ) + + def forward(self, x): + return self.net(x) + + +class FactorizedReduce(nn.Module): + """ + Reduce feature map size by factorized pointwise (stride=2). 
+ """ + def __init__(self, C_in, C_out, affine=True): + super().__init__() + self.relu = nn.ReLU() + self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) + self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) + self.bn = nn.BatchNorm2d(C_out, affine=affine) + + def forward(self, x): + x = self.relu(x) + out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1) + out = self.bn(out) + return out diff --git a/examples/nas/darts/retrain.py b/examples/nas/darts/retrain.py new file mode 100644 index 0000000000..904f3248fc --- /dev/null +++ b/examples/nas/darts/retrain.py @@ -0,0 +1,163 @@ +import logging +import time +from argparse import ArgumentParser + +import torch +import torch.nn as nn +from nni.nas.pytorch.fixed import apply_fixed_architecture +from nni.nas.pytorch.utils import AverageMeter +from torch.utils.tensorboard import SummaryWriter + +import datasets +import utils +from model import CNN + +logger = logging.getLogger() + +fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s' +logging.Formatter.converter = time.localtime +formatter = logging.Formatter(fmt, '%m/%d/%Y, %I:%M:%S %p') + +std_out_info = logging.StreamHandler() +std_out_info.setFormatter(formatter) +logger.setLevel(logging.INFO) +logger.addHandler(std_out_info) + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +writer = SummaryWriter() + + +def train(config, train_loader, model, optimizer, criterion, epoch): + top1 = AverageMeter("top1") + top5 = AverageMeter("top5") + losses = AverageMeter("losses") + + cur_step = epoch * len(train_loader) + cur_lr = optimizer.param_groups[0]['lr'] + logger.info("Epoch %d LR %.6f", epoch, cur_lr) + writer.add_scalar("lr", cur_lr, global_step=cur_step) + + model.train() + + for step, (x, y) in enumerate(train_loader): + x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True) + bs = x.size(0) + + optimizer.zero_grad() + logits, aux_logits = model(x) + loss = criterion(logits, y) + if config.aux_weight > 0.: + loss += config.aux_weight * criterion(aux_logits, y) + loss.backward() + # gradient clipping + nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) + optimizer.step() + + accuracy = utils.accuracy(logits, y, topk=(1, 5)) + losses.update(loss.item(), bs) + top1.update(accuracy["acc1"], bs) + top5.update(accuracy["acc5"], bs) + writer.add_scalar("loss/train", loss.item(), global_step=cur_step) + writer.add_scalar("acc1/train", accuracy["acc1"], global_step=cur_step) + writer.add_scalar("acc5/train", accuracy["acc5"], global_step=cur_step) + + if step % config.log_frequency == 0 or step == len(train_loader) - 1: + logger.info( + "Train: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} " + "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format( + epoch + 1, config.epochs, step, len(train_loader) - 1, losses=losses, + top1=top1, top5=top5)) + + cur_step += 1 + + logger.info("Train: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, config.epochs, top1.avg)) + + +def validate(config, valid_loader, model, criterion, epoch, cur_step): + top1 = AverageMeter("top1") + top5 = AverageMeter("top5") + losses = AverageMeter("losses") + + model.eval() + + with torch.no_grad(): + for step, (X, y) in enumerate(valid_loader): + X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True) + bs = X.size(0) + + logits = model(X) + loss = criterion(logits, y) + + accuracy = utils.accuracy(logits, y, topk=(1, 5)) + losses.update(loss.item(), bs) + 
top1.update(accuracy["acc1"], bs) + top5.update(accuracy["acc5"], bs) + + if step % config.log_frequency == 0 or step == len(valid_loader) - 1: + logger.info( + "Valid: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} " + "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format( + epoch + 1, config.epochs, step, len(valid_loader) - 1, losses=losses, + top1=top1, top5=top5)) + + writer.add_scalar("loss/test", losses.avg, global_step=cur_step) + writer.add_scalar("acc1/test", top1.avg, global_step=cur_step) + writer.add_scalar("acc5/test", top5.avg, global_step=cur_step) + + logger.info("Valid: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, config.epochs, top1.avg)) + + return top1.avg + + +if __name__ == "__main__": + parser = ArgumentParser("darts") + parser.add_argument("--layers", default=20, type=int) + parser.add_argument("--batch-size", default=96, type=int) + parser.add_argument("--log-frequency", default=10, type=int) + parser.add_argument("--epochs", default=600, type=int) + parser.add_argument("--aux-weight", default=0.4, type=float) + parser.add_argument("--drop-path-prob", default=0.2, type=float) + parser.add_argument("--workers", default=4) + parser.add_argument("--grad-clip", default=5., type=float) + parser.add_argument("--arc-checkpoint", default="./checkpoints/epoch_0.json") + + args = parser.parse_args() + dataset_train, dataset_valid = datasets.get_dataset("cifar10", cutout_length=16) + + model = CNN(32, 3, 36, 10, args.layers, auxiliary=True) + apply_fixed_architecture(model, args.arc_checkpoint, device=device) + criterion = nn.CrossEntropyLoss() + + model.to(device) + criterion.to(device) + + optimizer = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4) + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=1E-6) + + train_loader = torch.utils.data.DataLoader(dataset_train, + batch_size=args.batch_size, + shuffle=True, + num_workers=args.workers, + pin_memory=True) + valid_loader = torch.utils.data.DataLoader(dataset_valid, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.workers, + pin_memory=True) + + best_top1 = 0. 
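+ # the drop-path probability grows linearly from 0 towards --drop-path-prob over the training epochs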
+ for epoch in range(args.epochs): + drop_prob = args.drop_path_prob * epoch / args.epochs + model.drop_path_prob(drop_prob) + + # training + train(args, train_loader, model, optimizer, criterion, epoch) + + # validation + cur_step = (epoch + 1) * len(train_loader) + top1 = validate(args, valid_loader, model, criterion, epoch, cur_step) + best_top1 = max(best_top1, top1) + + lr_scheduler.step() + + logger.info("Final best Prec@1 = {:.4%}".format(best_top1)) diff --git a/examples/nas/darts/search.py b/examples/nas/darts/search.py new file mode 100644 index 0000000000..f25db7c7e4 --- /dev/null +++ b/examples/nas/darts/search.py @@ -0,0 +1,53 @@ +import logging +import time +from argparse import ArgumentParser + +import torch +import torch.nn as nn + +import datasets +from model import CNN +from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback +from nni.nas.pytorch.darts import DartsTrainer +from utils import accuracy + +logger = logging.getLogger() + +fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s' +logging.Formatter.converter = time.localtime +formatter = logging.Formatter(fmt, '%m/%d/%Y, %I:%M:%S %p') + +std_out_info = logging.StreamHandler() +std_out_info.setFormatter(formatter) +logger.setLevel(logging.INFO) +logger.addHandler(std_out_info) + +if __name__ == "__main__": + parser = ArgumentParser("darts") + parser.add_argument("--layers", default=8, type=int) + parser.add_argument("--batch-size", default=64, type=int) + parser.add_argument("--log-frequency", default=10, type=int) + parser.add_argument("--epochs", default=50, type=int) + parser.add_argument("--unrolled", default=False, action="store_true") + args = parser.parse_args() + + dataset_train, dataset_valid = datasets.get_dataset("cifar10") + + model = CNN(32, 3, 16, 10, args.layers) + criterion = nn.CrossEntropyLoss() + + optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4) + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001) + + trainer = DartsTrainer(model, + loss=criterion, + metrics=lambda output, target: accuracy(output, target, topk=(1,)), + optimizer=optim, + num_epochs=args.epochs, + dataset_train=dataset_train, + dataset_valid=dataset_valid, + batch_size=args.batch_size, + log_frequency=args.log_frequency, + unrolled=args.unrolled, + callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")]) + trainer.train() diff --git a/examples/nas/darts/utils.py b/examples/nas/darts/utils.py new file mode 100644 index 0000000000..2aac457ad1 --- /dev/null +++ b/examples/nas/darts/utils.py @@ -0,0 +1,18 @@ +def accuracy(output, target, topk=(1,)): + """ Computes the precision@k for the specified values of k """ + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + # one-hot case + if target.ndimension() > 1: + target = target.max(1)[1] + + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = dict() + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item() + return res \ No newline at end of file diff --git a/examples/nas/enas/datasets.py b/examples/nas/enas/datasets.py new file mode 100644 index 0000000000..8fe0ab0fbf --- /dev/null +++ b/examples/nas/enas/datasets.py @@ -0,0 +1,25 @@ +from torchvision import transforms +from torchvision.datasets import CIFAR10 + + +def get_dataset(cls): + MEAN = [0.49139968, 0.48215827, 0.44653124] + 
STD = [0.24703233, 0.24348505, 0.26158768] + transf = [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip() + ] + normalize = [ + transforms.ToTensor(), + transforms.Normalize(MEAN, STD) + ] + + train_transform = transforms.Compose(transf + normalize) + valid_transform = transforms.Compose(normalize) + + if cls == "cifar10": + dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform) + dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform) + else: + raise NotImplementedError + return dataset_train, dataset_valid diff --git a/examples/nas/enas/macro.py b/examples/nas/enas/macro.py new file mode 100644 index 0000000000..bfac9b17c9 --- /dev/null +++ b/examples/nas/enas/macro.py @@ -0,0 +1,83 @@ +import torch.nn as nn + +from nni.nas.pytorch import mutables +from ops import FactorizedReduce, ConvBranch, PoolBranch + + +class ENASLayer(mutables.MutableScope): + + def __init__(self, key, prev_labels, in_filters, out_filters): + super().__init__(key) + self.in_filters = in_filters + self.out_filters = out_filters + self.mutable = mutables.LayerChoice([ + ConvBranch(in_filters, out_filters, 3, 1, 1, separable=False), + ConvBranch(in_filters, out_filters, 3, 1, 1, separable=True), + ConvBranch(in_filters, out_filters, 5, 1, 2, separable=False), + ConvBranch(in_filters, out_filters, 5, 1, 2, separable=True), + PoolBranch('avg', in_filters, out_filters, 3, 1, 1), + PoolBranch('max', in_filters, out_filters, 3, 1, 1) + ]) + if len(prev_labels) > 0: + self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None) + else: + self.skipconnect = None + self.batch_norm = nn.BatchNorm2d(out_filters, affine=False) + + def forward(self, prev_layers): + out = self.mutable(prev_layers[-1]) + if self.skipconnect is not None: + connection = self.skipconnect(prev_layers[:-1]) + if connection is not None: + out += connection + return self.batch_norm(out) + + +class GeneralNetwork(nn.Module): + def __init__(self, num_layers=12, out_filters=24, in_channels=3, num_classes=10, + dropout_rate=0.0): + super().__init__() + self.num_layers = num_layers + self.num_classes = num_classes + self.out_filters = out_filters + + self.stem = nn.Sequential( + nn.Conv2d(in_channels, out_filters, 3, 1, 1, bias=False), + nn.BatchNorm2d(out_filters) + ) + + pool_distance = self.num_layers // 3 + self.pool_layers_idx = [pool_distance - 1, 2 * pool_distance - 1] + self.dropout_rate = dropout_rate + self.dropout = nn.Dropout(self.dropout_rate) + + self.layers = nn.ModuleList() + self.pool_layers = nn.ModuleList() + labels = [] + for layer_id in range(self.num_layers): + labels.append("layer_{}".format(layer_id)) + if layer_id in self.pool_layers_idx: + self.pool_layers.append(FactorizedReduce(self.out_filters, self.out_filters)) + self.layers.append(ENASLayer(labels[-1], labels[:-1], self.out_filters, self.out_filters)) + + self.gap = nn.AdaptiveAvgPool2d(1) + self.dense = nn.Linear(self.out_filters, self.num_classes) + + def forward(self, x): + bs = x.size(0) + cur = self.stem(x) + + layers = [cur] + + for layer_id in range(self.num_layers): + cur = self.layers[layer_id](layers) + layers.append(cur) + if layer_id in self.pool_layers_idx: + for i, layer in enumerate(layers): + layers[i] = self.pool_layers[self.pool_layers_idx.index(layer_id)](layer) + cur = layers[-1] + + cur = self.gap(cur).view(bs, -1) + cur = self.dropout(cur) + logits = self.dense(cur) + return logits diff --git a/examples/nas/enas/micro.py 
b/examples/nas/enas/micro.py new file mode 100644 index 0000000000..fabd3919ca --- /dev/null +++ b/examples/nas/enas/micro.py @@ -0,0 +1,184 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from nni.nas.pytorch import mutables +from ops import FactorizedReduce, StdConv, SepConvBN, Pool + + +class AuxiliaryHead(nn.Module): + def __init__(self, in_channels, num_classes): + super().__init__() + self.in_channels = in_channels + self.num_classes = num_classes + self.pooling = nn.Sequential( + nn.ReLU(), + nn.AvgPool2d(5, 3, 2) + ) + self.proj = nn.Sequential( + StdConv(in_channels, 128), + StdConv(128, 768) + ) + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Linear(768, 10, bias=False) + + def forward(self, x): + bs = x.size(0) + x = self.pooling(x) + x = self.proj(x) + x = self.avg_pool(x).view(bs, -1) + x = self.fc(x) + return x + + +class Cell(nn.Module): + def __init__(self, cell_name, prev_labels, channels): + super().__init__() + self.input_choice = mutables.InputChoice(choose_from=prev_labels, n_chosen=1, return_mask=True, + key=cell_name + "_input") + self.op_choice = mutables.LayerChoice([ + SepConvBN(channels, channels, 3, 1), + SepConvBN(channels, channels, 5, 2), + Pool("avg", 3, 1, 1), + Pool("max", 3, 1, 1), + nn.Identity() + ], key=cell_name + "_op") + + def forward(self, prev_layers): + chosen_input, chosen_mask = self.input_choice(prev_layers) + cell_out = self.op_choice(chosen_input) + return cell_out, chosen_mask + + +class Node(mutables.MutableScope): + def __init__(self, node_name, prev_node_names, channels): + super().__init__(node_name) + self.cell_x = Cell(node_name + "_x", prev_node_names, channels) + self.cell_y = Cell(node_name + "_y", prev_node_names, channels) + + def forward(self, prev_layers): + out_x, mask_x = self.cell_x(prev_layers) + out_y, mask_y = self.cell_y(prev_layers) + return out_x + out_y, mask_x | mask_y + + +class Calibration(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.process = None + if in_channels != out_channels: + self.process = StdConv(in_channels, out_channels) + + def forward(self, x): + if self.process is None: + return x + return self.process(x) + + +class ReductionLayer(nn.Module): + def __init__(self, in_channels_pp, in_channels_p, out_channels): + super().__init__() + self.reduce0 = FactorizedReduce(in_channels_pp, out_channels, affine=False) + self.reduce1 = FactorizedReduce(in_channels_p, out_channels, affine=False) + + def forward(self, pprev, prev): + return self.reduce0(pprev), self.reduce1(prev) + + +class ENASLayer(nn.Module): + def __init__(self, num_nodes, in_channels_pp, in_channels_p, out_channels, reduction): + super().__init__() + self.preproc0 = Calibration(in_channels_pp, out_channels) + self.preproc1 = Calibration(in_channels_p, out_channels) + + self.num_nodes = num_nodes + name_prefix = "reduce" if reduction else "normal" + self.nodes = nn.ModuleList() + node_labels = [mutables.InputChoice.NO_KEY, mutables.InputChoice.NO_KEY] + for i in range(num_nodes): + node_labels.append("{}_node_{}".format(name_prefix, i)) + self.nodes.append(Node(node_labels[-1], node_labels[:-1], out_channels)) + self.final_conv_w = nn.Parameter(torch.zeros(out_channels, self.num_nodes + 2, out_channels, 1, 1), requires_grad=True) + self.bn = nn.BatchNorm2d(out_channels, affine=False) + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_normal_(self.final_conv_w) + + def forward(self, pprev, prev): + pprev_, prev_ = self.preproc0(pprev), 
self.preproc1(prev) + + prev_nodes_out = [pprev_, prev_] + nodes_used_mask = torch.zeros(self.num_nodes + 2, dtype=torch.bool, device=prev.device) + for i in range(self.num_nodes): + node_out, mask = self.nodes[i](prev_nodes_out) + nodes_used_mask[:mask.size(0)] |= mask + prev_nodes_out.append(node_out) + + unused_nodes = torch.cat([out for used, out in zip(nodes_used_mask, prev_nodes_out) if not used], 1) + unused_nodes = F.relu(unused_nodes) + conv_weight = self.final_conv_w[:, ~nodes_used_mask, :, :, :] + conv_weight = conv_weight.view(conv_weight.size(0), -1, 1, 1) + out = F.conv2d(unused_nodes, conv_weight) + return prev, self.bn(out) + + +class MicroNetwork(nn.Module): + def __init__(self, num_layers=2, num_nodes=5, out_channels=24, in_channels=3, num_classes=10, + dropout_rate=0.0, use_aux_heads=False): + super().__init__() + self.num_layers = num_layers + self.use_aux_heads = use_aux_heads + + self.stem = nn.Sequential( + nn.Conv2d(in_channels, out_channels * 3, 3, 1, 1, bias=False), + nn.BatchNorm2d(out_channels * 3) + ) + + pool_distance = self.num_layers // 3 + pool_layers = [pool_distance, 2 * pool_distance + 1] + self.dropout = nn.Dropout(dropout_rate) + + self.layers = nn.ModuleList() + c_pp = c_p = out_channels * 3 + c_cur = out_channels + for layer_id in range(self.num_layers + 2): + reduction = False + if layer_id in pool_layers: + c_cur, reduction = c_p * 2, True + self.layers.append(ReductionLayer(c_pp, c_p, c_cur)) + c_pp = c_p = c_cur + self.layers.append(ENASLayer(num_nodes, c_pp, c_p, c_cur, reduction)) + if self.use_aux_heads and layer_id == pool_layers[-1] + 1: + self.layers.append(AuxiliaryHead(c_cur, num_classes)) + c_pp, c_p = c_p, c_cur + + self.gap = nn.AdaptiveAvgPool2d(1) + self.dense = nn.Linear(c_cur, num_classes) + + self.reset_parameters() + + def reset_parameters(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight) + + def forward(self, x): + bs = x.size(0) + prev = cur = self.stem(x) + aux_logits = None + + for layer in self.layers: + if isinstance(layer, AuxiliaryHead): + if self.training: + aux_logits = layer(cur) + else: + prev, cur = layer(prev, cur) + + cur = self.gap(F.relu(cur)).view(bs, -1) + cur = self.dropout(cur) + logits = self.dense(cur) + + if aux_logits is not None: + return logits, aux_logits + return logits diff --git a/examples/nas/enas/ops.py b/examples/nas/enas/ops.py new file mode 100644 index 0000000000..2b9df8069b --- /dev/null +++ b/examples/nas/enas/ops.py @@ -0,0 +1,103 @@ +import torch +import torch.nn as nn + + +class StdConv(nn.Module): + def __init__(self, C_in, C_out): + super(StdConv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(C_out, affine=False), + nn.ReLU() + ) + + def forward(self, x): + return self.conv(x) + + +class PoolBranch(nn.Module): + def __init__(self, pool_type, C_in, C_out, kernel_size, stride, padding, affine=False): + super().__init__() + self.preproc = StdConv(C_in, C_out) + self.pool = Pool(pool_type, kernel_size, stride, padding) + self.bn = nn.BatchNorm2d(C_out, affine=affine) + + def forward(self, x): + out = self.preproc(x) + out = self.pool(out) + out = self.bn(out) + return out + + +class SeparableConv(nn.Module): + def __init__(self, C_in, C_out, kernel_size, stride, padding): + super(SeparableConv, self).__init__() + self.depthwise = nn.Conv2d(C_in, C_in, kernel_size=kernel_size, padding=padding, stride=stride, + groups=C_in, bias=False) + self.pointwise = 
nn.Conv2d(C_in, C_out, kernel_size=1, bias=False) + + def forward(self, x): + out = self.depthwise(x) + out = self.pointwise(out) + return out + + +class ConvBranch(nn.Module): + def __init__(self, C_in, C_out, kernel_size, stride, padding, separable): + super(ConvBranch, self).__init__() + self.preproc = StdConv(C_in, C_out) + if separable: + self.conv = SeparableConv(C_out, C_out, kernel_size, stride, padding) + else: + self.conv = nn.Conv2d(C_out, C_out, kernel_size, stride=stride, padding=padding) + self.postproc = nn.Sequential( + nn.BatchNorm2d(C_out, affine=False), + nn.ReLU() + ) + + def forward(self, x): + out = self.preproc(x) + out = self.conv(out) + out = self.postproc(out) + return out + + +class FactorizedReduce(nn.Module): + def __init__(self, C_in, C_out, affine=False): + super().__init__() + self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) + self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) + self.bn = nn.BatchNorm2d(C_out, affine=affine) + + def forward(self, x): + out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1) + out = self.bn(out) + return out + + +class Pool(nn.Module): + def __init__(self, pool_type, kernel_size, stride, padding): + super().__init__() + if pool_type.lower() == 'max': + self.pool = nn.MaxPool2d(kernel_size, stride, padding) + elif pool_type.lower() == 'avg': + self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False) + else: + raise ValueError() + + def forward(self, x): + return self.pool(x) + + +class SepConvBN(nn.Module): + def __init__(self, C_in, C_out, kernel_size, padding): + super().__init__() + self.relu = nn.ReLU() + self.conv = SeparableConv(C_in, C_out, kernel_size, 1, padding) + self.bn = nn.BatchNorm2d(C_out, affine=True) + + def forward(self, x): + x = self.relu(x) + x = self.conv(x) + x = self.bn(x) + return x diff --git a/examples/nas/enas/search.py b/examples/nas/enas/search.py new file mode 100644 index 0000000000..5188bfeb9f --- /dev/null +++ b/examples/nas/enas/search.py @@ -0,0 +1,61 @@ +import logging +import time +from argparse import ArgumentParser + +import torch +import torch.nn as nn + +import datasets +from macro import GeneralNetwork +from micro import MicroNetwork +from nni.nas.pytorch import enas +from nni.nas.pytorch.callbacks import LRSchedulerCallback, ArchitectureCheckpoint +from utils import accuracy, reward_accuracy + +logger = logging.getLogger() + +fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s' +logging.Formatter.converter = time.localtime +formatter = logging.Formatter(fmt, '%m/%d/%Y, %I:%M:%S %p') + +std_out_info = logging.StreamHandler() +std_out_info.setFormatter(formatter) +logger.setLevel(logging.INFO) +logger.addHandler(std_out_info) + +if __name__ == "__main__": + parser = ArgumentParser("enas") + parser.add_argument("--batch-size", default=128, type=int) + parser.add_argument("--log-frequency", default=10, type=int) + parser.add_argument("--search-for", choices=["macro", "micro"], default="macro") + args = parser.parse_args() + + dataset_train, dataset_valid = datasets.get_dataset("cifar10") + if args.search_for == "macro": + model = GeneralNetwork() + num_epochs = 310 + mutator = None + elif args.search_for == "micro": + model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1, use_aux_heads=True) + num_epochs = 150 + mutator = enas.EnasMutator(model, tanh_constant=1.1, cell_exit_extra_step=True) + else: + raise AssertionError + + criterion = 
nn.CrossEntropyLoss() + optimizer = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4) + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0.001) + + trainer = enas.EnasTrainer(model, + loss=criterion, + metrics=accuracy, + reward_function=reward_accuracy, + optimizer=optimizer, + callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")], + batch_size=args.batch_size, + num_epochs=num_epochs, + dataset_train=dataset_train, + dataset_valid=dataset_valid, + log_frequency=args.log_frequency, + mutator=mutator) + trainer.train() diff --git a/examples/nas/enas/utils.py b/examples/nas/enas/utils.py new file mode 100644 index 0000000000..22bc62819f --- /dev/null +++ b/examples/nas/enas/utils.py @@ -0,0 +1,27 @@ +import torch + + +def accuracy(output, target, topk=(1,)): + """ Computes the precision@k for the specified values of k """ + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + # one-hot case + if target.ndimension() > 1: + target = target.max(1)[1] + + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = dict() + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item() + return res + + +def reward_accuracy(output, target, topk=(1,)): + batch_size = target.size(0) + _, predicted = torch.max(output.data, 1) + return (predicted == target).sum().item() / batch_size diff --git a/examples/nas/pdarts/.gitignore b/examples/nas/pdarts/.gitignore new file mode 100644 index 0000000000..054c274eeb --- /dev/null +++ b/examples/nas/pdarts/.gitignore @@ -0,0 +1,2 @@ +data/* +log diff --git a/examples/nas/pdarts/search.py b/examples/nas/pdarts/search.py new file mode 100644 index 0000000000..5d38fda0db --- /dev/null +++ b/examples/nas/pdarts/search.py @@ -0,0 +1,69 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import logging +import sys +import time +from argparse import ArgumentParser + +import torch +import torch.nn as nn + +from nni.nas.pytorch.callbacks import ArchitectureCheckpoint +from nni.nas.pytorch.pdarts import PdartsTrainer + +# prevent it to be reordered. 
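+# wrapping these imports in `if True:` keeps automatic import sorters from hoisting them above the sys.path change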
+if True: + sys.path.append('../darts') + from utils import accuracy + from model import CNN + import datasets + +logger = logging.getLogger() + +fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s' +logging.Formatter.converter = time.localtime +formatter = logging.Formatter(fmt, '%m/%d/%Y, %I:%M:%S %p') + +std_out_info = logging.StreamHandler() +std_out_info.setFormatter(formatter) +logger.setLevel(logging.INFO) +logger.addHandler(std_out_info) + +if __name__ == "__main__": + parser = ArgumentParser("pdarts") + parser.add_argument('--add_layers', action='append', + default=[0, 6, 12], help='add layers') + parser.add_argument("--nodes", default=4, type=int) + parser.add_argument("--layers", default=5, type=int) + parser.add_argument("--batch-size", default=64, type=int) + parser.add_argument("--log-frequency", default=1, type=int) + parser.add_argument("--epochs", default=50, type=int) + args = parser.parse_args() + + logger.info("loading data") + dataset_train, dataset_valid = datasets.get_dataset("cifar10") + + def model_creator(layers): + model = CNN(32, 3, 16, 10, layers, n_nodes=args.nodes) + criterion = nn.CrossEntropyLoss() + + optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4) + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001) + + return model, criterion, optim, lr_scheduler + + logger.info("initializing trainer") + trainer = PdartsTrainer(model_creator, + layers=args.layers, + metrics=lambda output, target: accuracy(output, target, topk=(1,)), + pdarts_num_layers=[0, 6, 12], + pdarts_num_to_drop=[3, 2, 2], + num_epochs=args.epochs, + dataset_train=dataset_train, + dataset_valid=dataset_valid, + batch_size=args.batch_size, + log_frequency=args.log_frequency, + callbacks=[ArchitectureCheckpoint("./checkpoints")]) + logger.info("training") + trainer.train() diff --git a/src/sdk/__init__.py b/src/sdk/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sdk/pynni/__init__.py b/src/sdk/pynni/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sdk/pynni/nni/nas/__init__.py b/src/sdk/pynni/nni/nas/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sdk/pynni/nni/nas/pytorch/__init__.py b/src/sdk/pynni/nni/nas/pytorch/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sdk/pynni/nni/nas/pytorch/base_mutator.py b/src/sdk/pynni/nni/nas/pytorch/base_mutator.py new file mode 100644 index 0000000000..df45a869af --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/base_mutator.py @@ -0,0 +1,124 @@ +import logging + +import torch.nn as nn +from nni.nas.pytorch.mutables import Mutable, MutableScope, InputChoice +from nni.nas.pytorch.utils import StructuredMutableTreeNode + +logger = logging.getLogger(__name__) + + +class BaseMutator(nn.Module): + """ + A mutator is responsible for mutating a graph by obtaining the search space from the network and implementing + callbacks that are called in ``forward`` in Mutables. 
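+ Subclasses implement ``on_forward_layer_choice`` and ``on_forward_input_choice`` to decide how each mutable behaves in forward, and ``export`` to dump the final architecture decisions.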
+ """ + + def __init__(self, model): + super().__init__() + self.__dict__["model"] = model + self._structured_mutables = self._parse_search_space(self.model) + + def _parse_search_space(self, module, root=None, prefix="", memo=None, nested_detection=None): + if memo is None: + memo = set() + if root is None: + root = StructuredMutableTreeNode(None) + if module not in memo: + memo.add(module) + if isinstance(module, Mutable): + if nested_detection is not None: + raise RuntimeError("Cannot have nested search space. Error at {} in {}" + .format(module, nested_detection)) + module.name = prefix + module.set_mutator(self) + root = root.add_child(module) + if not isinstance(module, MutableScope): + nested_detection = module + if isinstance(module, InputChoice): + for k in module.choose_from: + if k != InputChoice.NO_KEY and k not in [m.key for m in memo if isinstance(m, Mutable)]: + raise RuntimeError("'{}' required by '{}' not found in keys that appeared before, and is not NO_KEY." + .format(k, module.key)) + for name, submodule in module._modules.items(): + if submodule is None: + continue + submodule_prefix = prefix + ("." if prefix else "") + name + self._parse_search_space(submodule, root, submodule_prefix, memo=memo, + nested_detection=nested_detection) + return root + + @property + def mutables(self): + return self._structured_mutables + + def forward(self, *inputs): + raise RuntimeError("Forward is undefined for mutators.") + + def __setattr__(self, name, value): + if name == "model": + raise AttributeError("Attribute `model` can be set at most once, and you shouldn't use `self.model = model` to " + "include you network, as it will include all parameters in model into the mutator.") + return super().__setattr__(name, value) + + def enter_mutable_scope(self, mutable_scope): + """ + Callback when forward of a MutableScope is entered. + + Parameters + ---------- + mutable_scope : MutableScope + """ + pass + + def exit_mutable_scope(self, mutable_scope): + """ + Callback when forward of a MutableScope is exited. + + Parameters + ---------- + mutable_scope : MutableScope + """ + pass + + def on_forward_layer_choice(self, mutable, *inputs): + """ + Callbacks of forward in LayerChoice. + + Parameters + ---------- + mutable : LayerChoice + inputs : list of torch.Tensor + + Returns + ------- + tuple of torch.Tensor and torch.Tensor + output tensor and mask + """ + raise NotImplementedError + + def on_forward_input_choice(self, mutable, tensor_list): + """ + Callbacks of forward in InputChoice. + + Parameters + ---------- + mutable : InputChoice + tensor_list : list of torch.Tensor + + Returns + ------- + tuple of torch.Tensor and torch.Tensor + output tensor and mask + """ + raise NotImplementedError + + def export(self): + """ + Export the data of all decisions. This should output the decisions of all the mutables, so that the whole + network can be fully determined with these decisions for further training from scratch. 
+ + Returns + ------- + dict + """ + raise NotImplementedError diff --git a/src/sdk/pynni/nni/nas/pytorch/base_trainer.py b/src/sdk/pynni/nni/nas/pytorch/base_trainer.py new file mode 100644 index 0000000000..db1b033073 --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/base_trainer.py @@ -0,0 +1,20 @@ +from abc import ABC, abstractmethod + + +class BaseTrainer(ABC): + + @abstractmethod + def train(self): + raise NotImplementedError + + @abstractmethod + def validate(self): + raise NotImplementedError + + @abstractmethod + def export(self, file): + raise NotImplementedError + + @abstractmethod + def checkpoint(self): + raise NotImplementedError diff --git a/src/sdk/pynni/nni/nas/pytorch/callbacks.py b/src/sdk/pynni/nni/nas/pytorch/callbacks.py new file mode 100644 index 0000000000..817c1eb3f4 --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/callbacks.py @@ -0,0 +1,51 @@ +import logging +import os + +_logger = logging.getLogger(__name__) + + +class Callback: + + def __init__(self): + self.model = None + self.mutator = None + self.trainer = None + + def build(self, model, mutator, trainer): + self.model = model + self.mutator = mutator + self.trainer = trainer + + def on_epoch_begin(self, epoch): + pass + + def on_epoch_end(self, epoch): + pass + + def on_batch_begin(self, epoch): + pass + + def on_batch_end(self, epoch): + pass + + +class LRSchedulerCallback(Callback): + def __init__(self, scheduler, mode="epoch"): + super().__init__() + assert mode == "epoch" + self.scheduler = scheduler + self.mode = mode + + def on_epoch_end(self, epoch): + self.scheduler.step() + + +class ArchitectureCheckpoint(Callback): + def __init__(self, checkpoint_dir, every="epoch"): + super().__init__() + assert every == "epoch" + self.checkpoint_dir = checkpoint_dir + os.makedirs(self.checkpoint_dir, exist_ok=True) + + def on_epoch_end(self, epoch): + self.trainer.export(os.path.join(self.checkpoint_dir, "epoch_{}.json".format(epoch))) diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py b/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py new file mode 100644 index 0000000000..3bf08d285c --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py @@ -0,0 +1,2 @@ +from .mutator import DartsMutator +from .trainer import DartsTrainer \ No newline at end of file diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py b/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py new file mode 100644 index 0000000000..9674e2b954 --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py @@ -0,0 +1,48 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from nni.nas.pytorch.mutator import Mutator +from nni.nas.pytorch.mutables import LayerChoice, InputChoice + + +class DartsMutator(Mutator): + def __init__(self, model): + super().__init__(model) + self.choices = nn.ParameterDict() + for mutable in self.mutables: + if isinstance(mutable, LayerChoice): + self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length + 1)) + + def device(self): + for v in self.choices.values(): + return v.device + + def sample_search(self): + result = dict() + for mutable in self.mutables: + if isinstance(mutable, LayerChoice): + result[mutable.key] = F.softmax(self.choices[mutable.key], dim=-1)[:-1] + elif isinstance(mutable, InputChoice): + result[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device()) + return result + + def sample_final(self): + result = dict() + edges_max = dict() + for mutable in self.mutables: + if isinstance(mutable, LayerChoice): + 
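+ # keep only the operation with the largest architecture weight on each edge; the extra trailing logit is excluded by [:-1]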
max_val, index = torch.max(F.softmax(self.choices[mutable.key], dim=-1)[:-1], 0) + edges_max[mutable.key] = max_val + result[mutable.key] = F.one_hot(index, num_classes=mutable.length).view(-1).bool() + for mutable in self.mutables: + if isinstance(mutable, InputChoice): + weights = torch.tensor([edges_max.get(src_key, 0.) for src_key in mutable.choose_from]) # pylint: disable=not-callable + _, topk_edge_indices = torch.topk(weights, mutable.n_chosen or mutable.n_candidates) + selected_multihot = [] + for i, src_key in enumerate(mutable.choose_from): + if i not in topk_edge_indices and src_key in result: + result[src_key] = torch.zeros_like(result[src_key]) # clear this choice to optimize calc graph + selected_multihot.append(i in topk_edge_indices) + result[mutable.key] = torch.tensor(selected_multihot, dtype=torch.bool, device=self.device()) # pylint: disable=not-callable + return result diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py b/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py new file mode 100644 index 0000000000..772d455e72 --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py @@ -0,0 +1,173 @@ +import copy +import logging + +import torch +import torch.nn as nn +from nni.nas.pytorch.trainer import Trainer +from nni.nas.pytorch.utils import AverageMeterGroup + +from .mutator import DartsMutator + +logger = logging.getLogger(__name__) + + +class DartsTrainer(Trainer): + def __init__(self, model, loss, metrics, + optimizer, num_epochs, dataset_train, dataset_valid, + mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, + callbacks=None, arc_learning_rate=3.0E-4, unrolled=True): + super().__init__(model, mutator if mutator is not None else DartsMutator(model), + loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid, + batch_size, workers, device, log_frequency, callbacks) + self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), arc_learning_rate, betas=(0.5, 0.999), + weight_decay=1.0E-3) + self.unrolled = unrolled + + n_train = len(self.dataset_train) + split = n_train // 2 + indices = list(range(n_train)) + train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split]) + valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:]) + self.train_loader = torch.utils.data.DataLoader(self.dataset_train, + batch_size=batch_size, + sampler=train_sampler, + num_workers=workers) + self.valid_loader = torch.utils.data.DataLoader(self.dataset_train, + batch_size=batch_size, + sampler=valid_sampler, + num_workers=workers) + self.test_loader = torch.utils.data.DataLoader(self.dataset_valid, + batch_size=batch_size, + num_workers=workers) + + def train_one_epoch(self, epoch): + self.model.train() + self.mutator.train() + meters = AverageMeterGroup() + for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)): + trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device) + val_X, val_y = val_X.to(self.device), val_y.to(self.device) + + # phase 1. architecture step + self.ctrl_optim.zero_grad() + if self.unrolled: + self._unrolled_backward(trn_X, trn_y, val_X, val_y) + else: + self._backward(val_X, val_y) + self.ctrl_optim.step() + + # phase 2: child network step + self.optimizer.zero_grad() + logits, loss = self._logits_and_loss(trn_X, trn_y) + loss.backward() + nn.utils.clip_grad_norm_(self.model.parameters(), 5.) 
# gradient clipping + self.optimizer.step() + + metrics = self.metrics(logits, trn_y) + metrics["loss"] = loss.item() + meters.update(metrics) + if self.log_frequency is not None and step % self.log_frequency == 0: + logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1, + self.num_epochs, step + 1, len(self.train_loader), meters) + + def validate_one_epoch(self, epoch): + self.model.eval() + self.mutator.eval() + meters = AverageMeterGroup() + with torch.no_grad(): + self.mutator.reset() + for step, (X, y) in enumerate(self.test_loader): + X, y = X.to(self.device), y.to(self.device) + logits = self.model(X) + metrics = self.metrics(logits, y) + meters.update(metrics) + if self.log_frequency is not None and step % self.log_frequency == 0: + logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1, + self.num_epochs, step + 1, len(self.test_loader), meters) + + def _logits_and_loss(self, X, y): + self.mutator.reset() + logits = self.model(X) + loss = self.loss(logits, y) + return logits, loss + + def _backward(self, val_X, val_y): + """ + Simple backward with gradient descent + """ + _, loss = self._logits_and_loss(val_X, val_y) + loss.backward() + + def _unrolled_backward(self, trn_X, trn_y, val_X, val_y): + """ + Compute unrolled loss and backward its gradients + """ + backup_params = copy.deepcopy(tuple(self.model.parameters())) + + # do virtual step on training data + lr = self.optimizer.param_groups[0]["lr"] + momentum = self.optimizer.param_groups[0]["momentum"] + weight_decay = self.optimizer.param_groups[0]["weight_decay"] + self._compute_virtual_model(trn_X, trn_y, lr, momentum, weight_decay) + + # calculate unrolled loss on validation data + # keep gradients for model here for compute hessian + _, loss = self._logits_and_loss(val_X, val_y) + w_model, w_ctrl = tuple(self.model.parameters()), tuple(self.mutator.parameters()) + w_grads = torch.autograd.grad(loss, w_model + w_ctrl) + d_model, d_ctrl = w_grads[:len(w_model)], w_grads[len(w_model):] + + # compute hessian and final gradients + hessian = self._compute_hessian(backup_params, d_model, trn_X, trn_y) + with torch.no_grad(): + for param, d, h in zip(w_ctrl, d_ctrl, hessian): + # gradient = dalpha - lr * hessian + param.grad = d - lr * h + + # restore weights + self._restore_weights(backup_params) + + def _compute_virtual_model(self, X, y, lr, momentum, weight_decay): + """ + Compute unrolled weights w` + """ + # don't need zero_grad, using autograd to calculate gradients + _, loss = self._logits_and_loss(X, y) + gradients = torch.autograd.grad(loss, self.model.parameters()) + with torch.no_grad(): + for w, g in zip(self.model.parameters(), gradients): + m = self.optimizer.state[w].get("momentum_buffer", 0.) + w = w - lr * (momentum * m + g + weight_decay * w) + + def _restore_weights(self, backup_params): + with torch.no_grad(): + for param, backup in zip(self.model.parameters(), backup_params): + param.copy_(backup) + + def _compute_hessian(self, backup_params, dw, trn_X, trn_y): + """ + dw = dw` { L_val(w`, alpha) } + w+ = w + eps * dw + w- = w - eps * dw + hessian = (dalpha { L_trn(w+, alpha) } - dalpha { L_trn(w-, alpha) }) / (2*eps) + eps = 0.01 / ||dw|| + """ + self._restore_weights(backup_params) + norm = torch.cat([w.view(-1) for w in dw]).norm() + eps = 0.01 / norm + if norm < 1E-8: + logger.warning("In computing hessian, norm is smaller than 1E-8, cause eps to be %.6f.", norm.item()) + + dalphas = [] + for e in [eps, -2. 
* eps]:
+            # w+ = w + eps*dw`, w- = w - eps*dw`
+            with torch.no_grad():
+                for p, d in zip(self.model.parameters(), dw):
+                    p += e * d
+
+            _, loss = self._logits_and_loss(trn_X, trn_y)
+            dalphas.append(torch.autograd.grad(loss, self.mutator.parameters()))
+
+        dalpha_pos, dalpha_neg = dalphas  # dalpha { L_trn(w+) }, dalpha { L_trn(w-) }
+        hessian = [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)]
+        return hessian
diff --git a/src/sdk/pynni/nni/nas/pytorch/enas/__init__.py b/src/sdk/pynni/nni/nas/pytorch/enas/__init__.py
new file mode 100644
index 0000000000..78f066ff6c
--- /dev/null
+++ b/src/sdk/pynni/nni/nas/pytorch/enas/__init__.py
@@ -0,0 +1,2 @@
+from .mutator import EnasMutator
+from .trainer import EnasTrainer
diff --git a/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py b/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py
new file mode 100644
index 0000000000..7a1a6f80af
--- /dev/null
+++ b/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py
@@ -0,0 +1,156 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from nni.nas.pytorch.mutator import Mutator
+from nni.nas.pytorch.mutables import LayerChoice, InputChoice, MutableScope
+
+
+class StackedLSTMCell(nn.Module):
+    def __init__(self, layers, size, bias):
+        super().__init__()
+        self.lstm_num_layers = layers
+        self.lstm_modules = nn.ModuleList([nn.LSTMCell(size, size, bias=bias)
+                                           for _ in range(self.lstm_num_layers)])
+
+    def forward(self, inputs, hidden):
+        prev_c, prev_h = hidden
+        next_c, next_h = [], []
+        for i, m in enumerate(self.lstm_modules):
+            curr_c, curr_h = m(inputs, (prev_c[i], prev_h[i]))
+            next_c.append(curr_c)
+            next_h.append(curr_h)
+            inputs = curr_h[-1]
+        return next_c, next_h
+
+
+class EnasMutator(Mutator):
+
+    def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False,
+                 skip_target=0.4, branch_bias=0.25):
+        super().__init__(model)
+        self.lstm_size = lstm_size
+        self.lstm_num_layers = lstm_num_layers
+        self.tanh_constant = tanh_constant
+        self.cell_exit_extra_step = cell_exit_extra_step
+        self.skip_target = skip_target
+        self.branch_bias = branch_bias
+
+        self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
+        self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
+        self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
+        self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
+        self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
+        self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False)  # pylint: disable=not-callable
+        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
+        self.bias_dict = nn.ParameterDict()
+
+        self.max_layer_choice = 0
+        for mutable in self.mutables:
+            if isinstance(mutable, LayerChoice):
+                if self.max_layer_choice == 0:
+                    self.max_layer_choice = mutable.length
+                assert self.max_layer_choice == mutable.length, \
+                    "ENAS mutator requires all layer choices to have the same number of candidates."
+                # We are judging by keys and module types to add biases to layer choices. Needs refactor.
+ if "reduce" in mutable.key: + def is_conv(choice): + return "conv" in str(type(choice)).lower() + bias = torch.tensor([self.branch_bias if is_conv(choice) else -self.branch_bias # pylint: disable=not-callable + for choice in mutable.choices]) + self.bias_dict[mutable.key] = nn.Parameter(bias, requires_grad=False) + + self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size) + self.soft = nn.Linear(self.lstm_size, self.max_layer_choice, bias=False) + + def sample_search(self): + self._initialize() + self._sample(self.mutables) + return self._choices + + def sample_final(self): + return self.sample_search() + + def _sample(self, tree): + mutable = tree.mutable + if isinstance(mutable, LayerChoice) and mutable.key not in self._choices: + self._choices[mutable.key] = self._sample_layer_choice(mutable) + elif isinstance(mutable, InputChoice) and mutable.key not in self._choices: + self._choices[mutable.key] = self._sample_input_choice(mutable) + for child in tree.children: + self._sample(child) + if isinstance(mutable, MutableScope) and mutable.key not in self._anchors_hid: + if self.cell_exit_extra_step: + self._lstm_next_step() + self._mark_anchor(mutable.key) + + def _initialize(self): + self._choices = dict() + self._anchors_hid = dict() + self._inputs = self.g_emb.data + self._c = [torch.zeros((1, self.lstm_size), + dtype=self._inputs.dtype, + device=self._inputs.device) for _ in range(self.lstm_num_layers)] + self._h = [torch.zeros((1, self.lstm_size), + dtype=self._inputs.dtype, + device=self._inputs.device) for _ in range(self.lstm_num_layers)] + self.sample_log_prob = 0 + self.sample_entropy = 0 + self.sample_skip_penalty = 0 + + def _lstm_next_step(self): + self._c, self._h = self.lstm(self._inputs, (self._c, self._h)) + + def _mark_anchor(self, key): + self._anchors_hid[key] = self._h[-1] + + def _sample_layer_choice(self, mutable): + self._lstm_next_step() + logit = self.soft(self._h[-1]) + if self.tanh_constant is not None: + logit = self.tanh_constant * torch.tanh(logit) + if mutable.key in self.bias_dict: + logit += self.bias_dict[mutable.key] + branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1) + log_prob = self.cross_entropy_loss(logit, branch_id) + self.sample_log_prob += torch.sum(log_prob) + entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type + self.sample_entropy += torch.sum(entropy) + self._inputs = self.embedding(branch_id) + return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1) + + def _sample_input_choice(self, mutable): + query, anchors = [], [] + for label in mutable.choose_from: + if label not in self._anchors_hid: + self._lstm_next_step() + self._mark_anchor(label) # empty loop, fill not found + query.append(self.attn_anchor(self._anchors_hid[label])) + anchors.append(self._anchors_hid[label]) + query = torch.cat(query, 0) + query = torch.tanh(query + self.attn_query(self._h[-1])) + query = self.v_attn(query) + if self.tanh_constant is not None: + query = self.tanh_constant * torch.tanh(query) + + if mutable.n_chosen is None: + logit = torch.cat([-query, query], 1) # pylint: disable=invalid-unary-operand-type + + skip = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1) + skip_prob = torch.sigmoid(logit) + kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets)) + self.sample_skip_penalty += kl + log_prob = self.cross_entropy_loss(logit, skip) + self._inputs = (torch.matmul(skip.float(), torch.cat(anchors, 0)) / (1. 
+ torch.sum(skip))).unsqueeze(0) + else: + assert mutable.n_chosen == 1, "Input choice must select exactly one or any in ENAS." + logit = query.view(1, -1) + index = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1) + skip = F.one_hot(index, num_classes=mutable.n_candidates).view(-1) + log_prob = self.cross_entropy_loss(logit, index) + self._inputs = anchors[index.item()] + + self.sample_log_prob += torch.sum(log_prob) + entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type + self.sample_entropy += torch.sum(entropy) + return skip.bool() diff --git a/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py b/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py new file mode 100644 index 0000000000..d2074b94bc --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py @@ -0,0 +1,122 @@ +import logging +import torch +import torch.optim as optim + +from nni.nas.pytorch.trainer import Trainer +from nni.nas.pytorch.utils import AverageMeterGroup +from .mutator import EnasMutator + +logger = logging.getLogger(__name__) + + +class EnasTrainer(Trainer): + def __init__(self, model, loss, metrics, reward_function, + optimizer, num_epochs, dataset_train, dataset_valid, + mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None, + entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999, + mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4): + super().__init__(model, mutator if mutator is not None else EnasMutator(model), + loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid, + batch_size, workers, device, log_frequency, callbacks) + self.reward_function = reward_function + self.mutator_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr) + + self.entropy_weight = entropy_weight + self.skip_weight = skip_weight + self.baseline_decay = baseline_decay + self.baseline = 0. + self.mutator_steps_aggregate = mutator_steps_aggregate + self.mutator_steps = mutator_steps + self.aux_weight = aux_weight + + n_train = len(self.dataset_train) + split = n_train // 10 + indices = list(range(n_train)) + train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split]) + valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:]) + self.train_loader = torch.utils.data.DataLoader(self.dataset_train, + batch_size=batch_size, + sampler=train_sampler, + num_workers=workers) + self.valid_loader = torch.utils.data.DataLoader(self.dataset_train, + batch_size=batch_size, + sampler=valid_sampler, + num_workers=workers) + self.test_loader = torch.utils.data.DataLoader(self.dataset_valid, + batch_size=batch_size, + num_workers=workers) + + def train_one_epoch(self, epoch): + # Sample model and train + self.model.train() + self.mutator.eval() + meters = AverageMeterGroup() + for step, (x, y) in enumerate(self.train_loader): + x, y = x.to(self.device), y.to(self.device) + self.optimizer.zero_grad() + + with torch.no_grad(): + self.mutator.reset() + logits = self.model(x) + + if isinstance(logits, tuple): + logits, aux_logits = logits + aux_loss = self.loss(aux_logits, y) + else: + aux_loss = 0. 
+ metrics = self.metrics(logits, y) + loss = self.loss(logits, y) + loss = loss + self.aux_weight * aux_loss + loss.backward() + self.optimizer.step() + metrics["loss"] = loss.item() + meters.update(metrics) + + if self.log_frequency is not None and step % self.log_frequency == 0: + logger.info("Model Epoch [%s/%s] Step [%s/%s] %s", epoch + 1, + self.num_epochs, step + 1, len(self.train_loader), meters) + + # Train sampler (mutator) + self.model.eval() + self.mutator.train() + meters = AverageMeterGroup() + mutator_step, total_mutator_steps = 0, self.mutator_steps * self.mutator_steps_aggregate + while mutator_step < total_mutator_steps: + for step, (x, y) in enumerate(self.valid_loader): + x, y = x.to(self.device), y.to(self.device) + + self.mutator.reset() + with torch.no_grad(): + logits = self.model(x) + metrics = self.metrics(logits, y) + reward = self.reward_function(logits, y) + if self.entropy_weight is not None: + reward += self.entropy_weight * self.mutator.sample_entropy + self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay) + self.baseline = self.baseline.detach().item() + loss = self.mutator.sample_log_prob * (reward - self.baseline) + if self.skip_weight: + loss += self.skip_weight * self.mutator.sample_skip_penalty + metrics["reward"] = reward + metrics["loss"] = loss.item() + metrics["ent"] = self.mutator.sample_entropy.item() + metrics["baseline"] = self.baseline + metrics["skip"] = self.mutator.sample_skip_penalty + + loss = loss / self.mutator_steps_aggregate + loss.backward() + meters.update(metrics) + + if mutator_step % self.mutator_steps_aggregate == 0: + self.mutator_optim.step() + self.mutator_optim.zero_grad() + + if self.log_frequency is not None and step % self.log_frequency == 0: + logger.info("RL Epoch [%s/%s] Step [%s/%s] %s", epoch + 1, self.num_epochs, + mutator_step // self.mutator_steps_aggregate + 1, self.mutator_steps, meters) + mutator_step += 1 + if mutator_step >= total_mutator_steps: + break + + def validate_one_epoch(self, epoch): + pass diff --git a/src/sdk/pynni/nni/nas/pytorch/fixed.py b/src/sdk/pynni/nni/nas/pytorch/fixed.py new file mode 100644 index 0000000000..d953e1cd5e --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/fixed.py @@ -0,0 +1,79 @@ +import json + +import torch + +from nni.nas.pytorch.mutables import MutableScope +from nni.nas.pytorch.mutator import Mutator + + +class FixedArchitecture(Mutator): + + def __init__(self, model, fixed_arc, strict=True): + """ + Initialize a fixed architecture mutator. + + Parameters + ---------- + model : nn.Module + A mutable network. + fixed_arc : str or dict + Path to the architecture checkpoint (a string), or preloaded architecture object (a dict). + strict : bool + Force everything that appears in `fixed_arc` to be used at least once. 
+ """ + super().__init__(model) + self._fixed_arc = fixed_arc + + mutable_keys = set([mutable.key for mutable in self.mutables if not isinstance(mutable, MutableScope)]) + fixed_arc_keys = set(self._fixed_arc.keys()) + if fixed_arc_keys - mutable_keys: + raise RuntimeError("Unexpected keys found in fixed architecture: {}.".format(fixed_arc_keys - mutable_keys)) + if mutable_keys - fixed_arc_keys: + raise RuntimeError("Missing keys in fixed architecture: {}.".format(mutable_keys - fixed_arc_keys)) + + def sample_search(self): + return self._fixed_arc + + def sample_final(self): + return self._fixed_arc + + +def _encode_tensor(data, device): + if isinstance(data, list): + if all(map(lambda o: isinstance(o, bool), data)): + return torch.tensor(data, dtype=torch.bool, device=device) # pylint: disable=not-callable + else: + return torch.tensor(data, dtype=torch.float, device=device) # pylint: disable=not-callable + if isinstance(data, dict): + return {k: _encode_tensor(v, device) for k, v in data.items()} + return data + + +def apply_fixed_architecture(model, fixed_arc_path, device=None): + """ + Load architecture from `fixed_arc_path` and apply to model. + + Parameters + ---------- + model : torch.nn.Module + Model with mutables. + fixed_arc_path : str + Path to the JSON that stores the architecture. + device : torch.device + Architecture weights will be transfered to `device`. + + Returns + ------- + FixedArchitecture + """ + + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if isinstance(fixed_arc_path, str): + with open(fixed_arc_path, "r") as f: + fixed_arc = json.load(f) + fixed_arc = _encode_tensor(fixed_arc, device) + architecture = FixedArchitecture(model, fixed_arc) + architecture.to(device) + architecture.reset() + return architecture diff --git a/src/sdk/pynni/nni/nas/pytorch/mutables.py b/src/sdk/pynni/nni/nas/pytorch/mutables.py new file mode 100644 index 0000000000..115578a4ae --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/mutables.py @@ -0,0 +1,188 @@ +import logging + +import torch.nn as nn + +from nni.nas.pytorch.utils import global_mutable_counting + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class Mutable(nn.Module): + """ + Mutable is designed to function as a normal layer, with all necessary operators' weights. + States and weights of architectures should be included in mutator, instead of the layer itself. + + Mutable has a key, which marks the identity of the mutable. This key can be used by users to share + decisions among different mutables. In mutator's implementation, mutators should use the key to + distinguish different mutables. Mutables that share the same key should be "similar" to each other. + + Currently the default scope for keys is global. + """ + + def __init__(self, key=None): + super().__init__() + if key is not None: + if not isinstance(key, str): + key = str(key) + logger.warning("Warning: key \"%s\" is not string, converted to string.", key) + self._key = key + else: + self._key = self.__class__.__name__ + str(global_mutable_counting()) + self.init_hook = self.forward_hook = None + + def __deepcopy__(self, memodict=None): + raise NotImplementedError("Deep copy doesn't work for mutables.") + + def __call__(self, *args, **kwargs): + self._check_built() + return super().__call__(*args, **kwargs) + + def set_mutator(self, mutator): + if "mutator" in self.__dict__: + raise RuntimeError("`set_mutator` is called more than once. Did you parse the search space multiple times? 
" + "Or did you apply multiple fixed architectures?") + self.__dict__["mutator"] = mutator + + def forward(self, *inputs): + raise NotImplementedError + + @property + def key(self): + return self._key + + @property + def name(self): + return self._name if hasattr(self, "_name") else "_key" + + @name.setter + def name(self, name): + self._name = name + + def _check_built(self): + if not hasattr(self, "mutator"): + raise ValueError( + "Mutator not set for {}. Did you initialize a mutable on the fly in forward pass? Move to __init__" + "so that trainer can locate all your mutables. See NNI docs for more details.".format(self)) + + def __repr__(self): + return "{} ({})".format(self.name, self.key) + + +class MutableScope(Mutable): + """ + Mutable scope marks a subgraph/submodule to help mutators make better decisions. + Mutators get notified when a mutable scope is entered and exited. Mutators can override ``enter_mutable_scope`` + and ``exit_mutable_scope`` to catch corresponding events, and do status dump or update. + MutableScope are also mutables that are listed in the mutables (search space). + """ + + def __init__(self, key): + super().__init__(key=key) + + def __call__(self, *args, **kwargs): + try: + self._check_built() + self.mutator.enter_mutable_scope(self) + return super().__call__(*args, **kwargs) + finally: + self.mutator.exit_mutable_scope(self) + + +class LayerChoice(Mutable): + def __init__(self, op_candidates, reduction="sum", return_mask=False, key=None): + super().__init__(key=key) + self.length = len(op_candidates) + self.choices = nn.ModuleList(op_candidates) + self.reduction = reduction + self.return_mask = return_mask + + def forward(self, *inputs): + out, mask = self.mutator.on_forward_layer_choice(self, *inputs) + if self.return_mask: + return out, mask + return out + + +class InputChoice(Mutable): + """ + Input choice selects `n_chosen` inputs from `choose_from` (contains `n_candidates` keys). For beginners, + use `n_candidates` instead of `choose_from` is a safe option. To get the most power out of it, you might want to + know about `choose_from`. + + The keys in `choose_from` can be keys that appear in past mutables, or ``NO_KEY`` if there are no suitable ones. + The keys are designed to be the keys of the sources. To help mutators make better decisions, + mutators might be interested in how the tensors to choose from come into place. For example, the tensor is the + output of some operator, some node, some cell, or some module. If this operator happens to be a mutable (e.g., + ``LayerChoice`` or ``InputChoice``), it has a key naturally that can be used as a source key. If it's a + module/submodule, it needs to be annotated with a key: that's where a ``MutableScope`` is needed. + """ + + NO_KEY = "" + + def __init__(self, n_candidates=None, choose_from=None, n_chosen=None, + reduction="sum", return_mask=False, key=None): + """ + Initialization. + + Parameters + ---------- + n_candidates : int + Number of inputs to choose from. + choose_from : list of str + List of source keys to choose from. At least of one of `choose_from` and `n_candidates` must be fulfilled. + If `n_candidates` has a value but `choose_from` is None, it will be automatically treated as `n_candidates` + number of empty string. + n_chosen : int + Recommended inputs to choose. If None, mutator is instructed to select any. + reduction : str + `mean`, `concat`, `sum` or `none`. + return_mask : bool + If `return_mask`, return output tensor and a mask. Otherwise return tensor only. 
+ key : str + Key of the input choice. + """ + super().__init__(key=key) + # precondition check + assert n_candidates is not None or choose_from is not None, "At least one of `n_candidates` and `choose_from`" \ + "must be not None." + if choose_from is not None and n_candidates is None: + n_candidates = len(choose_from) + elif choose_from is None and n_candidates is not None: + choose_from = [self.NO_KEY] * n_candidates + assert n_candidates == len(choose_from), "Number of candidates must be equal to the length of `choose_from`." + assert n_candidates > 0, "Number of candidates must be greater than 0." + assert n_chosen is None or 0 <= n_chosen <= n_candidates, "Expected selected number must be None or no more " \ + "than number of candidates." + + self.n_candidates = n_candidates + self.choose_from = choose_from + self.n_chosen = n_chosen + self.reduction = reduction + self.return_mask = return_mask + + def forward(self, optional_inputs): + """ + Forward method of LayerChoice. + + Parameters + ---------- + optional_inputs : list or dict + Recommended to be a dict. As a dict, inputs will be converted to a list that follows the order of + `choose_from` in initialization. As a list, inputs must follow the semantic order that is the same as + `choose_from`. + + Returns + ------- + tuple of torch.Tensor and torch.Tensor or torch.Tensor + """ + optional_input_list = optional_inputs + if isinstance(optional_inputs, dict): + optional_input_list = [optional_inputs[tag] for tag in self.choose_from] + assert isinstance(optional_input_list, list), "Optional input list must be a list" + assert len(optional_inputs) == self.n_candidates, \ + "Length of the input list must be equal to number of candidates." + out, mask = self.mutator.on_forward_input_choice(self, optional_input_list) + if self.return_mask: + return out, mask + return out diff --git a/src/sdk/pynni/nni/nas/pytorch/mutator.py b/src/sdk/pynni/nni/nas/pytorch/mutator.py new file mode 100644 index 0000000000..a0a22b2649 --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/mutator.py @@ -0,0 +1,146 @@ +import logging + +import torch + +from nni.nas.pytorch.base_mutator import BaseMutator + +logger = logging.getLogger(__name__) + + +class Mutator(BaseMutator): + + def __init__(self, model): + super().__init__(model) + self._cache = dict() + + def sample_search(self): + """ + Override to implement this method to iterate over mutables and make decisions. + + Returns + ------- + dict + A mapping from key of mutables to decisions. + """ + raise NotImplementedError + + def sample_final(self): + """ + Override to implement this method to iterate over mutables and make decisions that is final + for export and retraining. + + Returns + ------- + dict + A mapping from key of mutables to decisions. + """ + raise NotImplementedError + + def reset(self): + """ + Reset the mutator by call the `sample_search` to resample (for search). + + Returns + ------- + None + """ + self._cache = self.sample_search() + + def export(self): + """ + Resample (for final) and return results. + + Returns + ------- + dict + """ + return self.sample_final() + + def on_forward_layer_choice(self, mutable, *inputs): + """ + On default, this method calls :meth:`on_calc_layer_choice_mask` to get a mask on how to choose between layers + (either by switch or by weights), then it will reduce the list of all tensor outputs with the policy specified + in `mutable.reduction`. It will also cache the mask with corresponding `mutable.key`. 
+ + Parameters + ---------- + mutable : LayerChoice + inputs : list of torch.Tensor + + Returns + ------- + tuple of torch.Tensor and torch.Tensor + """ + + def _map_fn(op, *inputs): + return op(*inputs) + + mask = self._get_decision(mutable) + assert len(mask) == len(mutable.choices) + out = self._select_with_mask(_map_fn, [(choice, *inputs) for choice in mutable.choices], mask) + return self._tensor_reduction(mutable.reduction, out), mask + + def on_forward_input_choice(self, mutable, tensor_list): + """ + On default, this method calls :meth:`on_calc_input_choice_mask` with `tags` + to get a mask on how to choose between inputs (either by switch or by weights), then it will reduce + the list of all tensor outputs with the policy specified in `mutable.reduction`. It will also cache the + mask with corresponding `mutable.key`. + + Parameters + ---------- + mutable : InputChoice + tensor_list : list of torch.Tensor + tags : list of string + + Returns + ------- + tuple of torch.Tensor and torch.Tensor + """ + mask = self._get_decision(mutable) + assert len(mask) == mutable.n_candidates + out = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask) + return self._tensor_reduction(mutable.reduction, out), mask + + def _select_with_mask(self, map_fn, candidates, mask): + if "BoolTensor" in mask.type(): + out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m] + elif "FloatTensor" in mask.type(): + out = [map_fn(*cand) * m for cand, m in zip(candidates, mask)] + else: + raise ValueError("Unrecognized mask") + return out + + def _tensor_reduction(self, reduction_type, tensor_list): + if reduction_type == "none": + return tensor_list + if not tensor_list: + return None # empty. return None for now + if len(tensor_list) == 1: + return tensor_list[0] + if reduction_type == "sum": + return sum(tensor_list) + if reduction_type == "mean": + return sum(tensor_list) / len(tensor_list) + if reduction_type == "concat": + return torch.cat(tensor_list, dim=1) + raise ValueError("Unrecognized reduction policy: \"{}\"".format(reduction_type)) + + def _get_decision(self, mutable): + """ + By default, this method checks whether `mutable.key` is already in the decision cache, + and returns the result without double-check. + + Parameters + ---------- + mutable : Mutable + + Returns + ------- + object + """ + if mutable.key not in self._cache: + raise ValueError("\"{}\" not found in decision cache.".format(mutable.key)) + result = self._cache[mutable.key] + logger.debug("Decision %s: %s", mutable.key, result) + return result diff --git a/src/sdk/pynni/nni/nas/pytorch/pdarts/__init__.py b/src/sdk/pynni/nni/nas/pytorch/pdarts/__init__.py new file mode 100644 index 0000000000..d1d17764ba --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/pdarts/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from .trainer import PdartsTrainer diff --git a/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py b/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py new file mode 100644 index 0000000000..8787b7ae40 --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+
+import copy
+
+import numpy as np
+import torch.nn.functional as F
+
+from nni.nas.pytorch.darts import DartsMutator
+from nni.nas.pytorch.mutables import LayerChoice
+
+
+class PdartsMutator(DartsMutator):
+
+    def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches=None):
+        self.pdarts_epoch_index = pdarts_epoch_index
+        self.pdarts_num_to_drop = pdarts_num_to_drop
+        if switches is None:
+            self.switches = {}
+        else:
+            self.switches = switches
+
+        super(PdartsMutator, self).__init__(model)
+
+        for mutable in self.mutables:
+            if isinstance(mutable, LayerChoice):
+
+                switches = self.switches.get(mutable.key, [True for j in range(mutable.length)])
+
+                for index in range(len(switches) - 1, -1, -1):
+                    if switches[index] == False:
+                        del mutable.choices[index]
+                        mutable.length -= 1
+
+                self.switches[mutable.key] = switches
+
+    def drop_paths(self):
+        for key in self.switches:
+            prob = F.softmax(self.choices[key], dim=-1).data.cpu().numpy()
+
+            switches = self.switches[key]
+            idxs = []
+            for j in range(len(switches)):
+                if switches[j]:
+                    idxs.append(j)
+            if self.pdarts_epoch_index == len(self.pdarts_num_to_drop) - 1:
+                # for the last stage, drop all Zero operations
+                drop = self.get_min_k_no_zero(prob, idxs, self.pdarts_num_to_drop[self.pdarts_epoch_index])
+            else:
+                drop = self.get_min_k(prob, self.pdarts_num_to_drop[self.pdarts_epoch_index])
+
+            for idx in drop:
+                switches[idxs[idx]] = False
+        return self.switches
+
+    def get_min_k(self, input_in, k):
+        w = copy.deepcopy(input_in)
+        index = []
+        for _ in range(k):
+            idx = np.argmin(w)
+            index.append(idx)
+            # mask the selected minimum so the next-smallest weight is picked in the following round
+            w[idx] = 1
+        return index
+
+    def get_min_k_no_zero(self, w_in, idxs, k):
+        w = copy.deepcopy(w_in)
+        index = []
+        if 0 in idxs:
+            zf = True
+        else:
+            zf = False
+        if zf:
+            w = w[1:]
+            index.append(0)
+            k = k - 1
+        for _ in range(k):
+            idx = np.argmin(w)
+            w[idx] = 1
+            if zf:
+                idx = idx + 1
+            index.append(idx)
+        return index
diff --git a/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py b/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py
new file mode 100644
index 0000000000..edc25fa360
--- /dev/null
+++ b/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py
@@ -0,0 +1,71 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import logging
+
+from nni.nas.pytorch.callbacks import LRSchedulerCallback
+from nni.nas.pytorch.darts import DartsTrainer
+from nni.nas.pytorch.trainer import BaseTrainer
+
+from .mutator import PdartsMutator
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+class PdartsTrainer(BaseTrainer):
+
+    def __init__(self, model_creator, layers, metrics,
+                 num_epochs, dataset_train, dataset_valid,
+                 pdarts_num_layers=[0, 6, 12], pdarts_num_to_drop=[3, 2, 2],
+                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None):
+        super(PdartsTrainer, self).__init__()
+        self.model_creator = model_creator
+        self.layers = layers
+        self.pdarts_num_layers = pdarts_num_layers
+        self.pdarts_num_to_drop = pdarts_num_to_drop
+        self.pdarts_epoch = len(pdarts_num_to_drop)
+        self.darts_parameters = {
+            "metrics": metrics,
+            "num_epochs": num_epochs,
+            "dataset_train": dataset_train,
+            "dataset_valid": dataset_valid,
+            "batch_size": batch_size,
+            "workers": workers,
+            "device": device,
+            "log_frequency": log_frequency
+        }
+        self.callbacks = callbacks if callbacks is not None else []
+
+    def train(self):
+        layers = self.layers
+        switches = None
+        for epoch in range(self.pdarts_epoch):
+
+            layers = self.layers + self.pdarts_num_layers[epoch]
+            model, criterion, optim, lr_scheduler = self.model_creator(layers)
+            self.mutator = PdartsMutator(model, epoch, self.pdarts_num_to_drop, switches)
+
+            for callback in self.callbacks:
+                callback.build(model, self.mutator, self)
+                callback.on_epoch_begin(epoch)
+
+            darts_callbacks = []
+            if lr_scheduler is not None:
+                darts_callbacks.append(LRSchedulerCallback(lr_scheduler))
+
+            self.trainer = DartsTrainer(model, mutator=self.mutator, loss=criterion, optimizer=optim,
+                                        callbacks=darts_callbacks, **self.darts_parameters)
+            logger.info("start pdarts training %s...", epoch)
+
+            self.trainer.train()
+
+            switches = self.mutator.drop_paths()
+
+            for callback in self.callbacks:
+                callback.on_epoch_end(epoch)
+
+    def validate(self):
+        # delegate to the DARTS trainer built in the latest P-DARTS stage
+        self.trainer.validate()
+
+    def export(self, file):
+        # delegate to the DARTS trainer so the searched architecture can be dumped as JSON
+        self.trainer.export(file)
+
+    def checkpoint(self):
+        raise NotImplementedError("Not implemented yet")
diff --git a/src/sdk/pynni/nni/nas/pytorch/trainer.py b/src/sdk/pynni/nni/nas/pytorch/trainer.py
new file mode 100644
index 0000000000..879699f488
--- /dev/null
+++ b/src/sdk/pynni/nni/nas/pytorch/trainer.py
@@ -0,0 +1,117 @@
+import json
+import logging
+from abc import abstractmethod
+
+import torch
+
+from .base_trainer import BaseTrainer
+
+_logger = logging.getLogger(__name__)
+_logger.setLevel(logging.INFO)
+
+
+class TorchTensorEncoder(json.JSONEncoder):
+    def default(self, o):  # pylint: disable=method-hidden
+        if isinstance(o, torch.Tensor):
+            olist = o.tolist()
+            if "bool" not in o.type().lower() and all(map(lambda d: d == 0 or d == 1, olist)):
+                _logger.warning("Every element in %s is either 0 or 1. "
+                                "You might consider converting it to bool.", olist)
+            return olist
+        return super().default(o)
+
+
+class Trainer(BaseTrainer):
+    def __init__(self, model, mutator, loss, metrics, optimizer, num_epochs,
+                 dataset_train, dataset_valid, batch_size, workers, device, log_frequency, callbacks):
+        """
+        Trainer initialization.
+
+        Parameters
+        ----------
+        model : nn.Module
+            Model with mutables.
+        mutator : BaseMutator
+            A mutator object that has been initialized with the model.
+        loss : callable
+            Called with logits and targets. Returns a loss tensor.
+        metrics : callable
+            Returns a dict that maps metrics keys to metrics data.
+        optimizer : Optimizer
+            Optimizer that optimizes the model.
+ num_epochs : int + Number of epochs of training. + dataset_train : torch.utils.data.Dataset + Dataset of training. + dataset_valid : torch.utils.data.Dataset + Dataset of validation/testing. + batch_size : int + Batch size. + workers : int + Number of workers used in data preprocessing. + device : torch.device + Device object. Either `torch.device("cuda")` or torch.device("cpu")`. When `None`, trainer will + automatic detects GPU and selects GPU first. + log_frequency : int + Number of mini-batches to log metrics. + callbacks : list of Callback + Callbacks to plug into the trainer. See Callbacks. + """ + + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device + self.model = model + self.mutator = mutator + self.loss = loss + + self.metrics = metrics + self.optimizer = optimizer + + self.model.to(self.device) + self.mutator.to(self.device) + self.loss.to(self.device) + + self.num_epochs = num_epochs + self.dataset_train = dataset_train + self.dataset_valid = dataset_valid + self.batch_size = batch_size + self.workers = workers + self.log_frequency = log_frequency + self.callbacks = callbacks if callbacks is not None else [] + for callback in self.callbacks: + callback.build(self.model, self.mutator, self) + + @abstractmethod + def train_one_epoch(self, epoch): + pass + + @abstractmethod + def validate_one_epoch(self, epoch): + pass + + def train(self, validate=True): + for epoch in range(self.num_epochs): + for callback in self.callbacks: + callback.on_epoch_begin(epoch) + + # training + _logger.info("Epoch %d Training", epoch) + self.train_one_epoch(epoch) + + if validate: + # validation + _logger.info("Epoch %d Validating", epoch) + self.validate_one_epoch(epoch) + + for callback in self.callbacks: + callback.on_epoch_end(epoch) + + def validate(self): + self.validate_one_epoch(-1) + + def export(self, file): + mutator_export = self.mutator.export() + with open(file, "w") as f: + json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder) + + def checkpoint(self): + raise NotImplementedError("Not implemented yet") diff --git a/src/sdk/pynni/nni/nas/pytorch/utils.py b/src/sdk/pynni/nni/nas/pytorch/utils.py new file mode 100644 index 0000000000..b6766b07cf --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/utils.py @@ -0,0 +1,117 @@ +from collections import OrderedDict + +_counter = 0 + + +def global_mutable_counting(): + global _counter + _counter += 1 + return _counter + + +class AverageMeterGroup: + + def __init__(self): + self.meters = OrderedDict() + + def update(self, data): + for k, v in data.items(): + if k not in self.meters: + self.meters[k] = AverageMeter(k, ":4f") + self.meters[k].update(v) + + def __str__(self): + return " ".join(str(v) for _, v in self.meters.items()) + + +class AverageMeter: + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=':f'): + """ + Initialization of AverageMeter + + Parameters + ---------- + name : str + Name to display. + fmt : str + Format string to print the values. 
+ """ + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) + + +class StructuredMutableTreeNode: + """ + A structured representation of a search space. + A search space comes with a root (with `None` stored in its `mutable`), and a bunch of children in its `children`. + This tree can be seen as a "flattened" version of the module tree. Since nested mutable entity is not supported yet, + the following must be true: each subtree corresponds to a ``MutableScope`` and each leaf corresponds to a + ``Mutable`` (other than ``MutableScope``). + """ + + def __init__(self, mutable): + self.mutable = mutable + self.children = [] + + def add_child(self, mutable): + self.children.append(StructuredMutableTreeNode(mutable)) + return self.children[-1] + + def type(self): + return type(self.mutable) + + def __iter__(self): + return self.traverse() + + def traverse(self, order="pre", deduplicate=True, memo=None): + """ + Return a generator that generates a list of mutables in this tree. + + Parameters + ---------- + order : str + pre or post. If pre, current mutable is yield before children. Otherwise after. + deduplicate : bool + If true, mutables with the same key will not appear after the first appearance. + memo : dict + An auxiliary dict that memorize keys seen before, so that deduplication is possible. + + Returns + ------- + generator of Mutable + """ + if memo is None: + memo = set() + assert order in ["pre", "post"] + if order == "pre": + if self.mutable is not None: + if not deduplicate or self.mutable.key not in memo: + memo.add(self.mutable.key) + yield self.mutable + for child in self.children: + for m in child.traverse(order=order, deduplicate=deduplicate, memo=memo): + yield m + if order == "post": + if self.mutable is not None: + if not deduplicate or self.mutable.key not in memo: + memo.add(self.mutable.key) + yield self.mutable diff --git a/src/sdk/pynni/nni/nas/tensorflow/__init__.py b/src/sdk/pynni/nni/nas/tensorflow/__init__.py new file mode 100644 index 0000000000..e69de29bb2