From 2116189f0e144a605b1f688c8b6bba0a3ff92625 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Fri, 1 Nov 2019 09:36:40 +0800 Subject: [PATCH 01/10] NAS refactor initialization (#1676) --- examples/nas/darts/datasets.py | 25 +++ examples/nas/darts/image_ops.py | 186 ++++++++++++++++ examples/nas/darts/main.py | 198 ++++++++++++++++++ src/__init__.py | 0 src/sdk/__init__.py | 0 src/sdk/pynni/__init__.py | 0 src/sdk/pynni/nni/nas/__init__.py | 0 src/sdk/pynni/nni/nas/pytorch/__init__.py | 0 .../pynni/nni/nas/pytorch/darts/__init__.py | 2 + .../pynni/nni/nas/pytorch/darts/mutator.py | 19 ++ .../pynni/nni/nas/pytorch/darts/trainer.py | 161 ++++++++++++++ src/sdk/pynni/nni/nas/pytorch/mutables.py | 69 ++++++ src/sdk/pynni/nni/nas/pytorch/mutator.py | 72 +++++++ src/sdk/pynni/nni/nas/pytorch/trainer.py | 12 ++ src/sdk/pynni/nni/nas/tensorflow/__init__.py | 0 src/sdk/pynni/nni/nas/utils.py | 60 ++++++ 16 files changed, 804 insertions(+) create mode 100644 examples/nas/darts/datasets.py create mode 100644 examples/nas/darts/image_ops.py create mode 100644 examples/nas/darts/main.py create mode 100644 src/__init__.py create mode 100644 src/sdk/__init__.py create mode 100644 src/sdk/pynni/__init__.py create mode 100644 src/sdk/pynni/nni/nas/__init__.py create mode 100644 src/sdk/pynni/nni/nas/pytorch/__init__.py create mode 100644 src/sdk/pynni/nni/nas/pytorch/darts/__init__.py create mode 100644 src/sdk/pynni/nni/nas/pytorch/darts/mutator.py create mode 100644 src/sdk/pynni/nni/nas/pytorch/darts/trainer.py create mode 100644 src/sdk/pynni/nni/nas/pytorch/mutables.py create mode 100644 src/sdk/pynni/nni/nas/pytorch/mutator.py create mode 100644 src/sdk/pynni/nni/nas/pytorch/trainer.py create mode 100644 src/sdk/pynni/nni/nas/tensorflow/__init__.py create mode 100644 src/sdk/pynni/nni/nas/utils.py diff --git a/examples/nas/darts/datasets.py b/examples/nas/darts/datasets.py new file mode 100644 index 0000000000..8fe0ab0fbf --- /dev/null +++ b/examples/nas/darts/datasets.py @@ -0,0 +1,25 @@ +from torchvision import transforms +from torchvision.datasets import CIFAR10 + + +def get_dataset(cls): + MEAN = [0.49139968, 0.48215827, 0.44653124] + STD = [0.24703233, 0.24348505, 0.26158768] + transf = [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip() + ] + normalize = [ + transforms.ToTensor(), + transforms.Normalize(MEAN, STD) + ] + + train_transform = transforms.Compose(transf + normalize) + valid_transform = transforms.Compose(normalize) + + if cls == "cifar10": + dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform) + dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform) + else: + raise NotImplementedError + return dataset_train, dataset_valid diff --git a/examples/nas/darts/image_ops.py b/examples/nas/darts/image_ops.py new file mode 100644 index 0000000000..ef25a6e830 --- /dev/null +++ b/examples/nas/darts/image_ops.py @@ -0,0 +1,186 @@ +import torch +import torch.nn as nn + + +PRIMITIVES = [ + 'max_pool_3x3', + 'avg_pool_3x3', + 'skip_connect', # identity + 'sep_conv_3x3', + 'sep_conv_5x5', + 'dil_conv_3x3', + 'dil_conv_5x5', + 'none' +] + +OPS = { + 'none': lambda C, stride, affine: Zero(stride), + 'avg_pool_3x3': lambda C, stride, affine: PoolBN('avg', C, 3, stride, 1, affine=affine), + 'max_pool_3x3': lambda C, stride, affine: PoolBN('max', C, 3, stride, 1, affine=affine), + 'skip_connect': lambda C, stride, affine: \ + Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine), + 'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine), + 'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine), + 'sep_conv_7x7': lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine), + 'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine), # 5x5 + 'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine), # 9x9 + 'conv_7x1_1x7': lambda C, stride, affine: FacConv(C, C, 7, stride, 3, affine=affine) +} + + +def drop_path_(x, drop_prob, training): + if training and drop_prob > 0.: + keep_prob = 1. - drop_prob + # per data point mask; assuming x in cuda. + mask = torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob) + x.div_(keep_prob).mul_(mask) + + return x + + +class DropPath_(nn.Module): + def __init__(self, p=0.): + """ [!] DropPath is inplace module + Args: + p: probability of an path to be zeroed. + """ + super().__init__() + self.p = p + + def extra_repr(self): + return 'p={}, inplace'.format(self.p) + + def forward(self, x): + drop_path_(x, self.p, self.training) + + return x + + +class PoolBN(nn.Module): + """ + AvgPool or MaxPool - BN + """ + def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True): + """ + Args: + pool_type: 'max' or 'avg' + """ + super().__init__() + if pool_type.lower() == 'max': + self.pool = nn.MaxPool2d(kernel_size, stride, padding) + elif pool_type.lower() == 'avg': + self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False) + else: + raise ValueError() + + self.bn = nn.BatchNorm2d(C, affine=affine) + + def forward(self, x): + out = self.pool(x) + out = self.bn(out) + return out + + +class StdConv(nn.Module): + """ Standard conv + ReLU - Conv - BN + """ + def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): + super().__init__() + self.net = nn.Sequential( + nn.ReLU(), + nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False), + nn.BatchNorm2d(C_out, affine=affine) + ) + + def forward(self, x): + return self.net(x) + + +class FacConv(nn.Module): + """ Factorized conv + ReLU - Conv(Kx1) - Conv(1xK) - BN + """ + def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True): + super().__init__() + self.net = nn.Sequential( + nn.ReLU(), + nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False), + nn.Conv2d(C_in, C_out, (1, kernel_length), stride, padding, bias=False), + nn.BatchNorm2d(C_out, affine=affine) + ) + + def forward(self, x): + return self.net(x) + + +class DilConv(nn.Module): + """ (Dilated) depthwise separable conv + ReLU - (Dilated) depthwise separable - Pointwise - BN + If dilation == 2, 3x3 conv => 5x5 receptive field + 5x5 conv => 9x9 receptive field + """ + def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True): + super().__init__() + self.net = nn.Sequential( + nn.ReLU(), + nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in, + bias=False), + nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(C_out, affine=affine) + ) + + def forward(self, x): + return self.net(x) + + +class SepConv(nn.Module): + """ Depthwise separable conv + DilConv(dilation=1) * 2 + """ + def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): + super().__init__() + self.net = nn.Sequential( + DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine), + DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine) + ) + + def forward(self, x): + return self.net(x) + + +class Identity(nn.Module): + + def forward(self, x): + return x + + +class Zero(nn.Module): + def __init__(self, stride): + super().__init__() + self.stride = stride + + def forward(self, x): + if self.stride == 1: + return x * 0. + + # re-sizing by stride + return x[:, :, ::self.stride, ::self.stride] * 0. + + +class FactorizedReduce(nn.Module): + """ + Reduce feature map size by factorized pointwise(stride=2). + """ + def __init__(self, C_in, C_out, affine=True): + super().__init__() + self.relu = nn.ReLU() + self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) + self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) + self.bn = nn.BatchNorm2d(C_out, affine=affine) + + def forward(self, x): + x = self.relu(x) + out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1) + out = self.bn(out) + return out diff --git a/examples/nas/darts/main.py b/examples/nas/darts/main.py new file mode 100644 index 0000000000..36082555da --- /dev/null +++ b/examples/nas/darts/main.py @@ -0,0 +1,198 @@ +from argparse import ArgumentParser + +import datasets +import image_ops as ops +import nni.nas.pytorch as nas +import torch +import torch.nn as nn +from nni.nas.pytorch.darts import DartsTrainer + + +class SearchCell(nn.Module): + """ + Cell for search. + """ + + def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction): + """ + Initialization a search cell. + + Parameters + ---------- + n_nodes: int + Number of nodes in current DAG. + channels_pp: int + Number of output channels from previous previous cell. + channels_p: int + Number of output channels from previous cell. + channels: int + Number of channels that will be used in the current DAG. + reduction_p: bool + Flag for whether the previous cell is reduction cell or not. + reduction: bool + Flag for whether the current cell is reduction cell or not. + """ + super().__init__() + self.reduction = reduction + self.n_nodes = n_nodes + + # If previous cell is reduction cell, current input size does not match with + # output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing. + if reduction_p: + self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False) + else: + self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False) + self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False) + + # generate dag + self.mutable_ops = nn.ModuleList() + for depth in range(self.n_nodes): + self.mutable_ops.append(nn.ModuleList()) + for i in range(2 + depth): # include 2 input nodes + # reduction should be used only for input node + stride = 2 if reduction and i < 2 else 1 + op = nas.mutables.LayerChoice([ops.PoolBN('max', channels, 3, stride, 1, affine=False), + ops.PoolBN('avg', channels, 3, stride, 1, affine=False), + ops.Identity() if stride == 1 else + ops.FactorizedReduce(channels, channels, affine=False), + ops.SepConv(channels, channels, 3, stride, 1, affine=False), + ops.SepConv(channels, channels, 5, stride, 2, affine=False), + ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False), + ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False), + ops.Zero(stride)], + key="r{}_d{}_i{}".format(reduction, depth, i)) + self.mutable_ops[depth].append(op) + + def forward(self, s0, s1): + # s0, s1 are the outputs of previous previous cell and previous cell, respectively. + tensors = [self.preproc0(s0), self.preproc1(s1)] + for ops in self.mutable_ops: + assert len(ops) == len(tensors) + cur_tensor = sum(op(tensor) for op, tensor in zip(ops, tensors)) + tensors.append(cur_tensor) + + output = torch.cat(tensors[2:], dim=1) + return output + + +class SearchCNN(nn.Module): + """ + Search CNN model + """ + + def __init__(self, in_channels, channels, n_classes, n_layers, n_nodes=4, stem_multiplier=3): + """ + Initializing a search channelsNN. + + Parameters + ---------- + in_channels: int + Number of channels in images. + channels: int + Number of channels used in the network. + n_classes: int + Number of classes. + n_layers: int + Number of cells in the whole network. + n_nodes: int + Number of nodes in a cell. + stem_multiplier: int + Multiplier of channels in STEM. + """ + super().__init__() + self.in_channels = in_channels + self.channels = channels + self.n_classes = n_classes + self.n_layers = n_layers + + c_cur = stem_multiplier * self.channels + self.stem = nn.Sequential( + nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False), + nn.BatchNorm2d(c_cur) + ) + + # for the first cell, stem is used for both s0 and s1 + # [!] channels_pp and channels_p is output channel size, but c_cur is input channel size. + channels_pp, channels_p, c_cur = c_cur, c_cur, channels + + self.cells = nn.ModuleList() + reduction_p, reduction = False, False + for i in range(n_layers): + reduction_p, reduction = reduction, False + # Reduce featuremap size and double channels in 1/3 and 2/3 layer. + if i in [n_layers // 3, 2 * n_layers // 3]: + c_cur *= 2 + reduction = True + + cell = SearchCell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction) + self.cells.append(cell) + c_cur_out = c_cur * n_nodes + channels_pp, channels_p = channels_p, c_cur_out + + self.gap = nn.AdaptiveAvgPool2d(1) + self.linear = nn.Linear(channels_p, n_classes) + + def forward(self, x): + s0 = s1 = self.stem(x) + + for cell in self.cells: + s0, s1 = s1, cell(s0, s1) + + out = self.gap(s1) + out = out.view(out.size(0), -1) # flatten + logits = self.linear(out) + return logits + + +def accuracy(output, target, topk=(1,)): + """ Computes the precision@k for the specified values of k """ + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + # one-hot case + if target.ndimension() > 1: + target = target.max(1)[1] + + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = dict() + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item() + return res + + +if __name__ == "__main__": + parser = ArgumentParser("darts") + parser.add_argument("--layers", default=4, type=int) + parser.add_argument("--nodes", default=2, type=int) + parser.add_argument("--batch-size", default=3, type=int) + parser.add_argument("--log-frequency", default=1, type=int) + args = parser.parse_args() + + dataset_train, dataset_valid = datasets.get_dataset("cifar10") + + model = SearchCNN(3, 16, 10, args.layers, n_nodes=args.nodes) + criterion = nn.CrossEntropyLoss() + + optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4) + n_epochs = 50 + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, n_epochs, eta_min=0.001) + + trainer = DartsTrainer(model, + loss=criterion, + metrics=lambda output, target: accuracy(output, target, topk=(1,)), + model_optim=optim, + lr_scheduler=lr_scheduler, + num_epochs=50, + dataset_train=dataset_train, + dataset_valid=dataset_valid, + batch_size=args.batch_size, + log_frequency=args.log_frequency) + trainer.train() + trainer.finalize() + +# augment step +# ... diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sdk/__init__.py b/src/sdk/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sdk/pynni/__init__.py b/src/sdk/pynni/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sdk/pynni/nni/nas/__init__.py b/src/sdk/pynni/nni/nas/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sdk/pynni/nni/nas/pytorch/__init__.py b/src/sdk/pynni/nni/nas/pytorch/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py b/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py new file mode 100644 index 0000000000..34e5f6e81c --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py @@ -0,0 +1,2 @@ +from .mutator import DartsMutator +from .trainer import DartsTrainer diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py b/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py new file mode 100644 index 0000000000..c086570602 --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py @@ -0,0 +1,19 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F + +from nni.nas.pytorch.mutables import LayerChoice +from nni.nas.pytorch.mutator import PyTorchMutator + + +class DartsMutator(PyTorchMutator): + + def before_build(self, model): + self.choices = nn.ParameterDict() + + def on_init_layer_choice(self, mutable: LayerChoice): + self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length)) + + def on_forward_layer_choice(self, mutable: LayerChoice, ops, *inputs): + weights = F.softmax(self.choices[mutable.key], dim=-1) + return sum(w * op(*inputs) for w, op in zip(weights, ops)) diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py b/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py new file mode 100644 index 0000000000..244d36bdf4 --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py @@ -0,0 +1,161 @@ +import copy + +import torch +from torch import nn as nn + +from nni.nas.pytorch.trainer import Trainer +from nni.nas.utils import AverageMeterGroup, auto_device +from .mutator import DartsMutator + + +class DartsTrainer(Trainer): + def __init__(self, model, loss, metrics, + model_optim, lr_scheduler, num_epochs, dataset_train, dataset_valid, + mutator=None, batch_size=64, workers=4, device=None, log_frequency=None): + self.model = model + self.loss = loss + self.metrics = metrics + self.mutator = mutator + if self.mutator is None: + self.mutator = DartsMutator(model) + self.model_optim = model_optim + self.lr_scheduler = lr_scheduler + self.num_epochs = num_epochs + self.dataset_train = dataset_train + self.dataset_valid = dataset_valid + self.device = auto_device() if device is None else device + self.log_frequency = log_frequency + + self.model.to(self.device) + self.loss.to(self.device) + self.mutator.to(self.device) + + self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), 3.0E-4, betas=(0.5, 0.999), + weight_decay=1.0E-3) + n_train = len(self.dataset_train) + split = n_train // 2 + indices = list(range(n_train)) + train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split]) + valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:]) + self.train_loader = torch.utils.data.DataLoader(self.dataset_train, + batch_size=batch_size, + sampler=train_sampler, + num_workers=workers) + self.valid_loader = torch.utils.data.DataLoader(self.dataset_train, + batch_size=batch_size, + sampler=valid_sampler, + num_workers=workers) + + def train_epoch(self, epoch): + self.model.train() + self.mutator.train() + lr = self.lr_scheduler.get_lr()[0] + meters = AverageMeterGroup() + for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)): + trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device) + val_X, val_y = val_X.to(self.device), val_y.to(self.device) + + # backup model for hessian + backup_model = copy.deepcopy(self.model.state_dict()) + # cannot deepcopy model because it will break the reference + + # phase 1. child network step + self.model_optim.zero_grad() + logits = self.model(trn_X) + loss = self.loss(logits, trn_y) + loss.backward() + # gradient clipping + nn.utils.clip_grad_norm_(self.model.parameters(), 5.) + self.model_optim.step() + + new_model = copy.deepcopy(self.model.state_dict()) + + # phase 2. architect step (alpha) + self.ctrl_optim.zero_grad() + # compute unrolled loss + self._unrolled_backward(trn_X, trn_y, val_X, val_y, backup_model, lr) + self.ctrl_optim.step() + + self.model.load_state_dict(new_model) + + metrics = self.metrics(logits, trn_y) + metrics["loss"] = loss.item() + meters.update(metrics) + if self.log_frequency is not None and step % self.log_frequency == 0: + print("Epoch {} Step [{}/{}] {}".format(epoch, step, len(self.train_loader), meters)) + + self.lr_scheduler.step() + + def validate_epoch(self, epoch): + self.model.eval() + self.mutator.eval() + meters = AverageMeterGroup() + with torch.no_grad(): + for step, (X, y) in enumerate(self.valid_loader): + X, y = X.to(self.device), y.to(self.device) + logits = self.model(X) + metrics = self.metrics(logits, y) + meters.update(metrics) + if self.log_frequency is not None and step % self.log_frequency == 0: + print("Epoch {} Step [{}/{}] {}".format(epoch, step, len(self.valid_loader), meters)) + + def train(self): + for epoch in range(self.num_epochs): + # training + print("Epoch {} Training".format(epoch)) + self.train_epoch(epoch) + + # validation + print("Epoch {} Validating".format(epoch)) + self.validate_epoch(epoch) + + def _unrolled_backward(self, trn_X, trn_y, val_X, val_y, backup_model, lr): + """ + Compute unrolled loss and backward its gradients + Parameters + ---------- + v_model: backup model before this step + lr: learning rate for virtual gradient step (same as net lr) + """ + loss = self.loss(self.model(val_X), val_y) + w_model = tuple(self.model.parameters()) + w_ctrl = tuple(self.mutator.parameters()) + w_grads = torch.autograd.grad(loss, w_model + w_ctrl) + d_model = w_grads[:len(w_model)] + d_ctrl = w_grads[len(w_model):] + + hessian = self._compute_hessian(backup_model, d_model, trn_X, trn_y) + with torch.no_grad(): + for param, d, h in zip(w_ctrl, d_ctrl, hessian): + param.grad = d - lr * h + + def _compute_hessian(self, model, dw, trn_X, trn_y): + """ + dw = dw` { L_val(w`, alpha) } + w+ = w + eps * dw + w- = w - eps * dw + hessian = (dalpha { L_trn(w+, alpha) } - dalpha { L_trn(w-, alpha) }) / (2*eps) + eps = 0.01 / ||dw|| + """ + self.model.load_state_dict(model) + + norm = torch.cat([w.view(-1) for w in dw]).norm() + eps = 0.01 / norm + + for e in [eps, -2. * eps]: + # w+ = w + eps*dw`, w- = w - eps*dw` + with torch.no_grad(): + for p, d in zip(self.model.parameters(), dw): + p += eps * d + + loss = self.loss(self.model(trn_X), trn_y) # TODO: should use model instead of self.model + if e > 0: + dalpha_pos = torch.autograd.grad(loss, self.mutator.parameters()) # dalpha { L_trn(w+) } + elif e < 0: + dalpha_neg = torch.autograd.grad(loss, self.mutator.parameters()) # dalpha { L_trn(w-) } + + hessian = [(p - n) / 2. * eps for p, n in zip(dalpha_pos, dalpha_neg)] + return hessian + + def finalize(self): + pass diff --git a/src/sdk/pynni/nni/nas/pytorch/mutables.py b/src/sdk/pynni/nni/nas/pytorch/mutables.py new file mode 100644 index 0000000000..4d2ecc1cce --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/mutables.py @@ -0,0 +1,69 @@ +import torch.nn as nn + +from nni.nas.utils import global_mutable_counting + + +class PyTorchMutable(nn.Module): + """ + Mutable is designed to function as a normal layer, with all necessary operators' weights. + States and weights of architectures should be included in mutator, instead of the layer itself. + + Mutable has a key, which marks the identity of the mutable. This key can be used by users to share + decisions among different mutables. In mutator's implementation, mutators should use the key to + distinguish different mutables. Mutables that share the same key should be "similar" to each other. + + Currently the default scope for keys is global. + """ + + def __init__(self, key=None): + super().__init__() + if key is not None: + self.key = key + else: + self.key = self.__class__.__name__ + str(global_mutable_counting()) + self.name = self.key + + def __deepcopy__(self, memodict=None): + raise NotImplementedError + + def set_mutator(self, mutator): + self.__dict__["mutator"] = mutator + + def forward(self, *inputs): + raise NotImplementedError("Mutable forward must be implemented") + + def __repr__(self): + return "{} ({})".format(self.name, self.key) + + def similar(self, other): + return self == other + + +class LayerChoice(PyTorchMutable): + def __init__(self, ops, key=None): + super().__init__(key=key) + self.length = len(ops) + self.choices = nn.ModuleList(ops) + + def forward(self, *inputs): + return self.mutator.on_forward(self, self.choices, *inputs) + + def similar(self, other): + return type(self) == type(other) and self.length == other.length + + +class InputChoice(PyTorchMutable): + def __init__(self, n_candidates, n_selected=None, reduction="mean", return_index=False, key=None): + super().__init__(key=key) + self.n_candidates = n_candidates + self.n_selected = n_selected + self.reduction = reduction + self.return_index = return_index + + def forward(self, *inputs): + assert len(inputs) == self.n_candidates, "Length of the input list must be equal to number of candidates." + return self.mutator.on_forward(self, *inputs) + + def similar(self, other): + return type(self) == type(other) and \ + self.n_candidates == other.n_candidates and self.n_selected and other.n_selected diff --git a/src/sdk/pynni/nni/nas/pytorch/mutator.py b/src/sdk/pynni/nni/nas/pytorch/mutator.py new file mode 100644 index 0000000000..331bdb42f8 --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/mutator.py @@ -0,0 +1,72 @@ +import logging + +from torch import nn as nn + +from nni.nas.pytorch.mutables import PyTorchMutable +from nni.nas.utils import to_snake_case + +logger = logging.getLogger(__name__) + + +class PyTorchMutator(nn.Module): + def __init__(self, model): + super().__init__() + self.before_build(model) + self.parse_search_space(model) + self.after_build(model) + + def before_build(self, model): + pass + + def after_build(self, model): + pass + + def named_mutables(self, model): + # if distinct is true, the method will filter out those with duplicated keys + key2module = dict() + for name, module in model.named_modules(): + if isinstance(module, PyTorchMutable): + distinct = False + if module.key in key2module: + assert key2module[module.key].similar(module), "Mutable that share the same key must be similar " \ + "to each other" + else: + distinct = True + key2module[module.key] = module + yield name, module, distinct + + def __setattr__(self, key, value): + if key in ["model", "net", "network"]: + logger.warning("Think twice if you are including the network into mutator.") + return super().__setattr__(key, value) + + def parse_search_space(self, model): + for name, mutable, distinct in self.named_mutables(model): + mutable.name = name + mutable.set_mutator(self) + if not distinct: + continue + init_method_name = "on_init_{}".format(to_snake_case(mutable.__class__.__name__)) + if hasattr(self, init_method_name) and callable(getattr(self, init_method_name)): + getattr(self, init_method_name)(mutable) + else: + # fallback to general init + self.on_init_general(mutable) + + def on_init_general(self, mutable): + pass + + def on_forward_general(self, mutable, *inputs): + raise NotImplementedError("Forward has to be implemented") + + def on_forward(self, mutable, *inputs): + """Callback on forwarding a mutable""" + forward_method_name = "on_forward_{}".format(to_snake_case(mutable.__class__.__name__)) + if hasattr(self, forward_method_name) and callable(getattr(self, forward_method_name)): + return getattr(self, forward_method_name)(mutable, *inputs) + else: + # fallback to general forward + return self.on_forward_general(mutable, *inputs) + + def forward(self, *inputs): + raise NotImplementedError("Mutator is not forward-able") diff --git a/src/sdk/pynni/nni/nas/pytorch/trainer.py b/src/sdk/pynni/nni/nas/pytorch/trainer.py new file mode 100644 index 0000000000..27c20efc51 --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/trainer.py @@ -0,0 +1,12 @@ +from abc import ABC, abstractmethod + + +class Trainer(ABC): + + @abstractmethod + def train(self): + raise NotImplementedError + + @abstractmethod + def finalize(self): + raise NotImplementedError diff --git a/src/sdk/pynni/nni/nas/tensorflow/__init__.py b/src/sdk/pynni/nni/nas/tensorflow/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sdk/pynni/nni/nas/utils.py b/src/sdk/pynni/nni/nas/utils.py new file mode 100644 index 0000000000..f6d5dfef65 --- /dev/null +++ b/src/sdk/pynni/nni/nas/utils.py @@ -0,0 +1,60 @@ +import re +from collections import OrderedDict + +import torch + +_counter = 0 + + +def global_mutable_counting(): + global _counter + _counter += 1 + return _counter + + +def to_snake_case(camel_case): + return re.sub('(?!^)([A-Z]+)', r'_\1', camel_case).lower() + + +def auto_device(): + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class AverageMeterGroup(object): + + def __init__(self): + self.meters = OrderedDict() + + def update(self, data): + for k, v in data.items(): + if k not in self.meters: + self.meters[k] = AverageMeter(k, ":4f") + self.meters[k].update(v) + + def __str__(self): + return " ".join(str(v) for _, v in self.meters.items()) + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) From e238d34a9042b0b05aa87033e602dd98af2536f8 Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Tue, 5 Nov 2019 10:36:13 +0800 Subject: [PATCH 02/10] add overview doc for NAS (#1703) --- docs/en_US/NAS/Overview.md | 66 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 docs/en_US/NAS/Overview.md diff --git a/docs/en_US/NAS/Overview.md b/docs/en_US/NAS/Overview.md new file mode 100644 index 0000000000..bedf503b79 --- /dev/null +++ b/docs/en_US/NAS/Overview.md @@ -0,0 +1,66 @@ +# NNI Programming Interface for Neural Architecture Search (NAS) + +*This is an experimental feature, programming APIs are almost done, NAS trainers are under intensive development. ([NAS annotation](../AdvancedFeature/GeneralNasInterfaces.md) will become deprecated in future)* + +Automatic neural architecture search is taking an increasingly important role on finding better models. Recent research works have proved the feasibility of automatic NAS, and also found some models that could beat manually designed and tuned models. Some of representative works are [NASNet][2], [ENAS][1], [DARTS][3], [Network Morphism][4], and [Evolution][5]. There are new innovations keeping emerging. However, it takes great efforts to implement those algorithms, and it is hard to reuse code base of one algorithm for implementing another. + +To facilitate NAS innovations (e.g., design/implement new NAS models, compare different NAS models side-by-side), an easy-to-use and flexible programming interface is crucial. + +## Programming interface + +A new programming interface for designing and searching for a model is often demanded in two scenarios. + + 1. When designing a neural network, the designer may have multiple choices for a layer, sub-model, or connection, and not sure which one or a combination performs the best. It would be appealing to have an easy way to express the candidate layers/sub-models they want to try. + 2. For the researchers who are working on automatic NAS, they want to have an unified way to express the search space of neural architectures. And making unchanged trial code adapted to different searching algorithms. + +For expressing neural architecture search space, we provide two APIs: + +```python +# choose one ``op`` from ``ops``, for pytorch this is a module. +# ops: for pytorch ``ops`` is a list of modules, for tensorflow it is a list of keras layers. An example in pytroch: +# ops = [PoolBN('max', channels, 3, stride, 1, affine=False), +# PoolBN('avg', channels, 3, stride, 1, affine=False), +# FactorizedReduce(channels, channels, affine=False), +# SepConv(channels, channels, 3, stride, 1, affine=False), +# DilConv(channels, channels, 3, stride, 2, 2, affine=False)] +# key: the name of this ``LayerChoice`` instance +nni.nas.LayerChoice(ops, key) +# choose ``n_selected`` from ``n_candidates`` inputs. +# n_candidates: the number of candidate inputs +# n_selected: the number of chosen inputs +# reduction: reduction operation for the chosen inputs +# key: the name of this ``InputChoice`` instance +nni.nas.InputChoice(n_candidates, n_selected, reduction, key) +``` + +After writing your model with search space embedded in the model using the above two APIs, the next step is finding the best model from the search space. Similar to optimizers of deep learning models, the procedure of finding the best model from search space can be viewed as a type of optimizing process, we call it `NAS trainer`. There have been several NAS trainers, for example, `DartsTrainer` which uses SGD to train architecture weights and model weights iteratively, `ENASTrainer` which uses a controller to train the model. New and more efficient NAS trainers keep emerging in research community. + +NNI provides some popular NAS trainers, to use a NAS trainer, users could initialize a trainer after the model is defined: + +```python +# create a DartsTrainer +trainer = DartsTrainer(model, + loss=criterion, + metrics=lambda output, target: accuracy(output, target, topk=(1,)), + model_optim=optim, + lr_scheduler=lr_scheduler, + num_epochs=50, + dataset_train=dataset_train, + dataset_valid=dataset_valid, + batch_size=args.batch_size, + log_frequency=args.log_frequency) +# finding the best model from search space +trainer.train() +# export the best found model +trainer.export_model() +``` + +Different trainers could have different input arguments depending on their algorithms. After training, users could export the best one of the found models through `trainer.export_model()`. + +[Here](https://github.com/microsoft/nni/blob/dev-nas-refactor/examples/nas/darts/main.py) is a trial example using DartsTrainer. + +[1]: https://arxiv.org/abs/1802.03268 +[2]: https://arxiv.org/abs/1707.07012 +[3]: https://arxiv.org/abs/1806.09055 +[4]: https://arxiv.org/abs/1806.10282 +[5]: https://arxiv.org/abs/1703.01041 \ No newline at end of file From bb797e10e460c086a7de192dce2dae6681bbfcf0 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Fri, 8 Nov 2019 15:42:16 +0800 Subject: [PATCH 03/10] Update APIs and add preliminary support for ENAS macro space (#1714) * add enas macro * refactor example directory structure * update docstring --- examples/nas/.gitignore | 1 + examples/nas/darts/{main.py => model.py} | 63 +------- examples/nas/darts/{image_ops.py => ops.py} | 0 examples/nas/darts/search.py | 43 ++++++ examples/nas/darts/utils.py | 18 +++ examples/nas/enas/datasets.py | 25 +++ examples/nas/enas/enas_ops.py | 80 ++++++++++ examples/nas/enas/macro.py | 142 +++++++++++++++++ examples/nas/enas/ops.py | 80 ++++++++++ .../pynni/nni/nas/pytorch/darts/mutator.py | 5 +- .../pynni/nni/nas/pytorch/darts/trainer.py | 11 +- .../pynni/nni/nas/pytorch/enas/__init__.py | 2 + src/sdk/pynni/nni/nas/pytorch/enas/mutator.py | 126 +++++++++++++++ src/sdk/pynni/nni/nas/pytorch/enas/trainer.py | 120 +++++++++++++++ src/sdk/pynni/nni/nas/pytorch/mutables.py | 82 ++++++++-- src/sdk/pynni/nni/nas/pytorch/mutator.py | 145 +++++++++++++++++- src/sdk/pynni/nni/nas/pytorch/trainer.py | 2 +- 17 files changed, 854 insertions(+), 91 deletions(-) create mode 100644 examples/nas/.gitignore rename examples/nas/darts/{main.py => model.py} (72%) rename examples/nas/darts/{image_ops.py => ops.py} (100%) create mode 100644 examples/nas/darts/search.py create mode 100644 examples/nas/darts/utils.py create mode 100644 examples/nas/enas/datasets.py create mode 100644 examples/nas/enas/enas_ops.py create mode 100644 examples/nas/enas/macro.py create mode 100644 examples/nas/enas/ops.py create mode 100644 src/sdk/pynni/nni/nas/pytorch/enas/__init__.py create mode 100644 src/sdk/pynni/nni/nas/pytorch/enas/mutator.py create mode 100644 src/sdk/pynni/nni/nas/pytorch/enas/trainer.py diff --git a/examples/nas/.gitignore b/examples/nas/.gitignore new file mode 100644 index 0000000000..1269488f7f --- /dev/null +++ b/examples/nas/.gitignore @@ -0,0 +1 @@ +data diff --git a/examples/nas/darts/main.py b/examples/nas/darts/model.py similarity index 72% rename from examples/nas/darts/main.py rename to examples/nas/darts/model.py index 36082555da..629831e0b7 100644 --- a/examples/nas/darts/main.py +++ b/examples/nas/darts/model.py @@ -1,11 +1,8 @@ -from argparse import ArgumentParser - -import datasets -import image_ops as ops -import nni.nas.pytorch as nas import torch import torch.nn as nn -from nni.nas.pytorch.darts import DartsTrainer + +import ops +from nni.nas import pytorch as nas class SearchCell(nn.Module): @@ -142,57 +139,3 @@ def forward(self, x): out = out.view(out.size(0), -1) # flatten logits = self.linear(out) return logits - - -def accuracy(output, target, topk=(1,)): - """ Computes the precision@k for the specified values of k """ - maxk = max(topk) - batch_size = target.size(0) - - _, pred = output.topk(maxk, 1, True, True) - pred = pred.t() - # one-hot case - if target.ndimension() > 1: - target = target.max(1)[1] - - correct = pred.eq(target.view(1, -1).expand_as(pred)) - - res = dict() - for k in topk: - correct_k = correct[:k].view(-1).float().sum(0) - res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item() - return res - - -if __name__ == "__main__": - parser = ArgumentParser("darts") - parser.add_argument("--layers", default=4, type=int) - parser.add_argument("--nodes", default=2, type=int) - parser.add_argument("--batch-size", default=3, type=int) - parser.add_argument("--log-frequency", default=1, type=int) - args = parser.parse_args() - - dataset_train, dataset_valid = datasets.get_dataset("cifar10") - - model = SearchCNN(3, 16, 10, args.layers, n_nodes=args.nodes) - criterion = nn.CrossEntropyLoss() - - optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4) - n_epochs = 50 - lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, n_epochs, eta_min=0.001) - - trainer = DartsTrainer(model, - loss=criterion, - metrics=lambda output, target: accuracy(output, target, topk=(1,)), - model_optim=optim, - lr_scheduler=lr_scheduler, - num_epochs=50, - dataset_train=dataset_train, - dataset_valid=dataset_valid, - batch_size=args.batch_size, - log_frequency=args.log_frequency) - trainer.train() - trainer.finalize() - -# augment step -# ... diff --git a/examples/nas/darts/image_ops.py b/examples/nas/darts/ops.py similarity index 100% rename from examples/nas/darts/image_ops.py rename to examples/nas/darts/ops.py diff --git a/examples/nas/darts/search.py b/examples/nas/darts/search.py new file mode 100644 index 0000000000..ad0650d156 --- /dev/null +++ b/examples/nas/darts/search.py @@ -0,0 +1,43 @@ +from argparse import ArgumentParser + +import datasets +import torch +import torch.nn as nn + +from model import SearchCNN +from nni.nas.pytorch.darts import DartsTrainer +from utils import accuracy + + +if __name__ == "__main__": + parser = ArgumentParser("darts") + parser.add_argument("--layers", default=4, type=int) + parser.add_argument("--nodes", default=2, type=int) + parser.add_argument("--batch-size", default=128, type=int) + parser.add_argument("--log-frequency", default=1, type=int) + args = parser.parse_args() + + dataset_train, dataset_valid = datasets.get_dataset("cifar10") + + model = SearchCNN(3, 16, 10, args.layers, n_nodes=args.nodes) + criterion = nn.CrossEntropyLoss() + + optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4) + n_epochs = 50 + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, n_epochs, eta_min=0.001) + + trainer = DartsTrainer(model, + loss=criterion, + metrics=lambda output, target: accuracy(output, target, topk=(1,)), + model_optim=optim, + lr_scheduler=lr_scheduler, + num_epochs=50, + dataset_train=dataset_train, + dataset_valid=dataset_valid, + batch_size=args.batch_size, + log_frequency=args.log_frequency) + trainer.train() + trainer.export() + +# augment step +# ... diff --git a/examples/nas/darts/utils.py b/examples/nas/darts/utils.py new file mode 100644 index 0000000000..2aac457ad1 --- /dev/null +++ b/examples/nas/darts/utils.py @@ -0,0 +1,18 @@ +def accuracy(output, target, topk=(1,)): + """ Computes the precision@k for the specified values of k """ + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + # one-hot case + if target.ndimension() > 1: + target = target.max(1)[1] + + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = dict() + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item() + return res \ No newline at end of file diff --git a/examples/nas/enas/datasets.py b/examples/nas/enas/datasets.py new file mode 100644 index 0000000000..8fe0ab0fbf --- /dev/null +++ b/examples/nas/enas/datasets.py @@ -0,0 +1,25 @@ +from torchvision import transforms +from torchvision.datasets import CIFAR10 + + +def get_dataset(cls): + MEAN = [0.49139968, 0.48215827, 0.44653124] + STD = [0.24703233, 0.24348505, 0.26158768] + transf = [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip() + ] + normalize = [ + transforms.ToTensor(), + transforms.Normalize(MEAN, STD) + ] + + train_transform = transforms.Compose(transf + normalize) + valid_transform = transforms.Compose(normalize) + + if cls == "cifar10": + dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform) + dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform) + else: + raise NotImplementedError + return dataset_train, dataset_valid diff --git a/examples/nas/enas/enas_ops.py b/examples/nas/enas/enas_ops.py new file mode 100644 index 0000000000..2df9088321 --- /dev/null +++ b/examples/nas/enas/enas_ops.py @@ -0,0 +1,80 @@ +import torch +import torch.nn as nn + + +class StdConv(nn.Module): + def __init__(self, C_in, C_out): + super(StdConv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(C_out, affine=False), + nn.ReLU() + ) + + def forward(self, x): + return self.conv(x) + + +class PoolBranch(nn.Module): + def __init__(self, pool_type, C_in, C_out, kernel_size, stride, padding, affine=False): + super().__init__() + self.preproc = StdConv(C_in, C_out) + if pool_type.lower() == 'max': + self.pool = nn.MaxPool2d(kernel_size, stride, padding) + elif pool_type.lower() == 'avg': + self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False) + else: + raise ValueError() + self.bn = nn.BatchNorm2d(C_out, affine=affine) + + def forward(self, x): + out = self.preproc(x) + out = self.pool(out) + out = self.bn(out) + return out + + +class SeparableConv(nn.Module): + def __init__(self, C_in, C_out, kernel_size, stride, padding): + super(SeparableConv, self).__init__() + self.depthwise = nn.Conv2d(C_in, C_in, kernel_size=kernel_size, padding=padding, stride=stride, + groups=C_in, bias=False) + self.pointwise = nn.Conv2d(C_in, C_out, kernel_size=1, bias=False) + + def forward(self, x): + out = self.depthwise(x) + out = self.pointwise(out) + return out + + +class ConvBranch(nn.Module): + def __init__(self, C_in, C_out, kernel_size, stride, padding, separable): + super(ConvBranch, self).__init__() + self.preproc = StdConv(C_in, C_out) + if separable: + self.conv = SeparableConv(C_out, C_out, kernel_size, stride, padding) + else: + self.conv = nn.Conv2d(C_out, C_out, kernel_size, stride=stride, padding=padding) + self.postproc = nn.Sequential( + nn.BatchNorm2d(C_out, affine=False), + nn.ReLU() + ) + + def forward(self, x): + out = self.preproc(x) + out = self.conv(out) + out = self.postproc(out) + return out + + +class FactorizedReduce(nn.Module): + def __init__(self, C_in, C_out, affine=False): + super().__init__() + self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) + self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) + self.bn = nn.BatchNorm2d(C_out, affine=affine) + + def forward(self, x): + out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1) + out = self.bn(out) + return out diff --git a/examples/nas/enas/macro.py b/examples/nas/enas/macro.py new file mode 100644 index 0000000000..8d2ca21522 --- /dev/null +++ b/examples/nas/enas/macro.py @@ -0,0 +1,142 @@ +from argparse import ArgumentParser +import torch +import torch.nn as nn + +import datasets +from ops import FactorizedReduce, ConvBranch, PoolBranch +from nni.nas.pytorch import mutables, enas + + +class ENASLayer(nn.Module): + + def __init__(self, layer_id, in_filters, out_filters): + super().__init__() + self.in_filters = in_filters + self.out_filters = out_filters + self.mutable = mutables.LayerChoice([ + ConvBranch(in_filters, out_filters, 3, 1, 1, separable=False), + ConvBranch(in_filters, out_filters, 3, 1, 1, separable=True), + ConvBranch(in_filters, out_filters, 5, 1, 2, separable=False), + ConvBranch(in_filters, out_filters, 5, 1, 2, separable=True), + PoolBranch('avg', in_filters, out_filters, 3, 1, 1), + PoolBranch('max', in_filters, out_filters, 3, 1, 1) + ]) + if layer_id > 0: + self.skipconnect = mutables.InputChoice(layer_id, n_selected=None, reduction="sum") + else: + self.skipconnect = None + self.batch_norm = nn.BatchNorm2d(out_filters, affine=False) + self.mutable_scope = mutables.MutableScope("layer_{}".format(layer_id)) + + def forward(self, prev_layers): + with self.mutable_scope: + out = self.mutable(prev_layers[-1]) + if self.skipconnect is not None: + connection = self.skipconnect(prev_layers[:-1], + ["layer_{}".format(i) for i in range(len(prev_layers) - 1)]) + if connection is not None: + out += connection + return self.batch_norm(out) + + +class GeneralNetwork(nn.Module): + def __init__(self, num_layers=12, out_filters=24, in_channels=3, num_classes=10, + dropout_rate=0.0): + super().__init__() + self.num_layers = num_layers + self.num_classes = num_classes + self.out_filters = out_filters + + self.stem = nn.Sequential( + nn.Conv2d(in_channels, out_filters, 3, 1, 1, bias=False), + nn.BatchNorm2d(out_filters) + ) + + pool_distance = self.num_layers // 3 + self.pool_layers_idx = [pool_distance - 1, 2 * pool_distance - 1] + self.dropout_rate = dropout_rate + self.dropout = nn.Dropout(self.dropout_rate) + + self.layers = nn.ModuleList() + self.pool_layers = nn.ModuleList() + for layer_id in range(self.num_layers): + if layer_id in self.pool_layers_idx: + self.pool_layers.append(FactorizedReduce(self.out_filters, self.out_filters)) + self.layers.append(ENASLayer(layer_id, self.out_filters, self.out_filters)) + + self.gap = nn.AdaptiveAvgPool2d(1) + self.dense = nn.Linear(self.out_filters, self.num_classes) + + def forward(self, x): + bs = x.size(0) + cur = self.stem(x) + + layers = [cur] + + for layer_id in range(self.num_layers): + cur = self.layers[layer_id](layers) + layers.append(cur) + if layer_id in self.pool_layers_idx: + for i, layer in enumerate(layers): + layers[i] = self.pool_layers[self.pool_layers_idx.index(layer_id)](layer) + cur = layers[-1] + + cur = self.gap(cur).view(bs, -1) + cur = self.dropout(cur) + logits = self.dense(cur) + return logits + + +def accuracy(output, target, topk=(1,)): + """ Computes the precision@k for the specified values of k """ + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + # one-hot case + if target.ndimension() > 1: + target = target.max(1)[1] + + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = dict() + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item() + return res + + +def reward_accuracy(output, target, topk=(1,)): + batch_size = target.size(0) + _, predicted = torch.max(output.data, 1) + return (predicted == target).sum().item() / batch_size + + +if __name__ == "__main__": + parser = ArgumentParser("enas") + parser.add_argument("--batch-size", default=3, type=int) + parser.add_argument("--log-frequency", default=1, type=int) + args = parser.parse_args() + + dataset_train, dataset_valid = datasets.get_dataset("cifar10") + + model = GeneralNetwork() + criterion = nn.CrossEntropyLoss() + + n_epochs = 310 + optim = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4) + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=n_epochs, eta_min=0.001) + + trainer = enas.EnasTrainer(model, + loss=criterion, + metrics=accuracy, + reward_function=reward_accuracy, + optimizer=optim, + lr_scheduler=lr_scheduler, + batch_size=args.batch_size, + num_epochs=n_epochs, + dataset_train=dataset_train, + dataset_valid=dataset_valid, + log_frequency=args.log_frequency) + trainer.train() diff --git a/examples/nas/enas/ops.py b/examples/nas/enas/ops.py new file mode 100644 index 0000000000..2df9088321 --- /dev/null +++ b/examples/nas/enas/ops.py @@ -0,0 +1,80 @@ +import torch +import torch.nn as nn + + +class StdConv(nn.Module): + def __init__(self, C_in, C_out): + super(StdConv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(C_out, affine=False), + nn.ReLU() + ) + + def forward(self, x): + return self.conv(x) + + +class PoolBranch(nn.Module): + def __init__(self, pool_type, C_in, C_out, kernel_size, stride, padding, affine=False): + super().__init__() + self.preproc = StdConv(C_in, C_out) + if pool_type.lower() == 'max': + self.pool = nn.MaxPool2d(kernel_size, stride, padding) + elif pool_type.lower() == 'avg': + self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False) + else: + raise ValueError() + self.bn = nn.BatchNorm2d(C_out, affine=affine) + + def forward(self, x): + out = self.preproc(x) + out = self.pool(out) + out = self.bn(out) + return out + + +class SeparableConv(nn.Module): + def __init__(self, C_in, C_out, kernel_size, stride, padding): + super(SeparableConv, self).__init__() + self.depthwise = nn.Conv2d(C_in, C_in, kernel_size=kernel_size, padding=padding, stride=stride, + groups=C_in, bias=False) + self.pointwise = nn.Conv2d(C_in, C_out, kernel_size=1, bias=False) + + def forward(self, x): + out = self.depthwise(x) + out = self.pointwise(out) + return out + + +class ConvBranch(nn.Module): + def __init__(self, C_in, C_out, kernel_size, stride, padding, separable): + super(ConvBranch, self).__init__() + self.preproc = StdConv(C_in, C_out) + if separable: + self.conv = SeparableConv(C_out, C_out, kernel_size, stride, padding) + else: + self.conv = nn.Conv2d(C_out, C_out, kernel_size, stride=stride, padding=padding) + self.postproc = nn.Sequential( + nn.BatchNorm2d(C_out, affine=False), + nn.ReLU() + ) + + def forward(self, x): + out = self.preproc(x) + out = self.conv(out) + out = self.postproc(out) + return out + + +class FactorizedReduce(nn.Module): + def __init__(self, C_in, C_out, affine=False): + super().__init__() + self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) + self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) + self.bn = nn.BatchNorm2d(C_out, affine=affine) + + def forward(self, x): + out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1) + out = self.bn(out) + return out diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py b/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py index c086570602..ef5dcec806 100644 --- a/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py +++ b/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py @@ -14,6 +14,5 @@ def before_build(self, model): def on_init_layer_choice(self, mutable: LayerChoice): self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length)) - def on_forward_layer_choice(self, mutable: LayerChoice, ops, *inputs): - weights = F.softmax(self.choices[mutable.key], dim=-1) - return sum(w * op(*inputs) for w, op in zip(weights, ops)) + def on_calc_layer_choice_mask(self, mutable: LayerChoice): + return F.softmax(self.choices[mutable.key], dim=-1) diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py b/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py index 244d36bdf4..72ac427c11 100644 --- a/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py @@ -61,7 +61,8 @@ def train_epoch(self, epoch): # phase 1. child network step self.model_optim.zero_grad() - logits = self.model(trn_X) + with self.mutator.forward_pass(): + logits = self.model(trn_X) loss = self.loss(logits, trn_y) loss.backward() # gradient clipping @@ -117,7 +118,8 @@ def _unrolled_backward(self, trn_X, trn_y, val_X, val_y, backup_model, lr): v_model: backup model before this step lr: learning rate for virtual gradient step (same as net lr) """ - loss = self.loss(self.model(val_X), val_y) + with self.mutator.forward_pass(): + loss = self.loss(self.model(val_X), val_y) w_model = tuple(self.model.parameters()) w_ctrl = tuple(self.mutator.parameters()) w_grads = torch.autograd.grad(loss, w_model + w_ctrl) @@ -148,7 +150,8 @@ def _compute_hessian(self, model, dw, trn_X, trn_y): for p, d in zip(self.model.parameters(), dw): p += eps * d - loss = self.loss(self.model(trn_X), trn_y) # TODO: should use model instead of self.model + with self.mutator.forward_pass(): + loss = self.loss(self.model(trn_X), trn_y) if e > 0: dalpha_pos = torch.autograd.grad(loss, self.mutator.parameters()) # dalpha { L_trn(w+) } elif e < 0: @@ -157,5 +160,5 @@ def _compute_hessian(self, model, dw, trn_X, trn_y): hessian = [(p - n) / 2. * eps for p, n in zip(dalpha_pos, dalpha_neg)] return hessian - def finalize(self): + def export(self): pass diff --git a/src/sdk/pynni/nni/nas/pytorch/enas/__init__.py b/src/sdk/pynni/nni/nas/pytorch/enas/__init__.py new file mode 100644 index 0000000000..78f066ff6c --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/enas/__init__.py @@ -0,0 +1,2 @@ +from .mutator import EnasMutator +from .trainer import EnasTrainer diff --git a/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py b/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py new file mode 100644 index 0000000000..93dad9c77c --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py @@ -0,0 +1,126 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from nni.nas.pytorch.mutator import PyTorchMutator + + +class StackedLSTMCell(nn.Module): + def __init__(self, layers, size, bias): + super().__init__() + self.lstm_num_layers = layers + self.lstm_modules = nn.ModuleList([nn.LSTMCell(size, size, bias=bias) + for _ in range(self.lstm_num_layers)]) + + def forward(self, inputs, hidden): + prev_c, prev_h = hidden + next_c, next_h = [], [] + for i, m in enumerate(self.lstm_modules): + curr_c, curr_h = m(inputs, (prev_c[i], prev_h[i])) + next_c.append(curr_c) + next_h.append(curr_h) + inputs = curr_h[-1] + return next_c, next_h + + +class EnasMutator(PyTorchMutator): + def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, anchor_extra_step=False, + skip_target=0.4): + self.lstm_size = lstm_size + self.lstm_num_layers = lstm_num_layers + self.tanh_constant = tanh_constant + self.max_layer_choice = 0 + self.anchor_extra_step = anchor_extra_step + self.skip_target = skip_target + super().__init__(model) + + def before_build(self, model): + self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False) + self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False) + self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False) + self.v_attn = nn.Linear(self.lstm_size, 1, bias=False) + self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1) + self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False) + self.cross_entropy_loss = nn.CrossEntropyLoss() + + def after_build(self, model): + self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size) + self.soft = nn.Linear(self.lstm_size, self.max_layer_choice) + + def before_pass(self): + super().before_pass() + self._anchors_hid = dict() + self._selected_layers = [] + self._selected_inputs = [] + self._inputs = self.g_emb.data + self._c = [torch.zeros((1, self.lstm_size), + dtype=self._inputs.dtype, + device=self._inputs.device) for _ in range(self.lstm_num_layers)] + self._h = [torch.zeros((1, self.lstm_size), + dtype=self._inputs.dtype, + device=self._inputs.device) for _ in range(self.lstm_num_layers)] + self.sample_log_prob = 0 + self.sample_entropy = 0 + self.sample_skip_penalty = 0 + + def _lstm_next_step(self): + self._c, self._h = self.lstm(self._inputs, (self._c, self._h)) + + def _mark_anchor(self, key): + self._anchors_hid[key] = self._h[-1] + + def on_init_layer_choice(self, mutable): + if self.max_layer_choice == 0: + self.max_layer_choice = mutable.length + assert self.max_layer_choice == mutable.length, \ + "ENAS mutator requires all layer choice have the same number of candidates." + + def on_calc_layer_choice_mask(self, mutable): + self._lstm_next_step() + logit = self.soft(self._h[-1]) + if self.tanh_constant is not None: + logit = self.tanh_constant * torch.tanh(logit) + branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1) + log_prob = self.cross_entropy_loss(logit, branch_id) + self.sample_log_prob += log_prob + entropy = (log_prob * torch.exp(-log_prob)).detach() + self.sample_entropy += entropy + self._inputs = self.embedding(branch_id) + self._selected_layers.append(branch_id.item()) + return F.one_hot(branch_id).bool().view(-1) + + def on_calc_input_choice_mask(self, mutable, semantic_labels): + if mutable.n_selected is None: + query, anchors = [], [] + for label in semantic_labels: + if label not in self._anchors_hid: + self._lstm_next_step() + self._mark_anchor(label) # empty loop, fill not found + query.append(self.attn_anchor(self._anchors_hid[label])) + anchors.append(self._anchors_hid[label]) + query = torch.cat(query, 0) + query = torch.tanh(query + self.attn_query(self._h[-1])) + query = self.v_attn(query) + logit = torch.cat([-query, query], 1) + if self.tanh_constant is not None: + logit = self.tanh_constant * torch.tanh(logit) + + skip = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1) + skip_prob = torch.sigmoid(logit) + kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets)) + self.sample_skip_penalty += kl + + log_prob = self.cross_entropy_loss(logit, skip) + self.sample_log_prob += torch.sum(log_prob) + entropy = (log_prob * torch.exp(-log_prob)).detach() + self.sample_entropy += torch.sum(entropy) + + self.inputs = torch.matmul(skip.float(), torch.cat(anchors, 0)) / (1. + torch.sum(skip)) + self._selected_inputs.append(skip) + return skip.bool() + else: + assert mutable.n_selected == 1, "Input choice must select exactly one or any in ENAS." + raise NotImplementedError + + def exit_mutable_scope(self, mutable_scope): + self._mark_anchor(mutable_scope.key) diff --git a/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py b/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py new file mode 100644 index 0000000000..7bc24ad16f --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py @@ -0,0 +1,120 @@ +import torch +import torch.optim as optim + +from nni.nas.pytorch.trainer import Trainer +from nni.nas.utils import AverageMeterGroup, auto_device +from .mutator import EnasMutator + + +class EnasTrainer(Trainer): + def __init__(self, model, loss, metrics, reward_function, + optimizer, num_epochs, dataset_train, dataset_valid, lr_scheduler=None, + mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, + entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999, + mutator_lr=0.00035): + self.model = model + self.loss = loss + self.metrics = metrics + self.reward_function = reward_function + self.mutator = mutator + if self.mutator is None: + self.mutator = EnasMutator(model) + self.optim = optimizer + self.mut_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr) + self.lr_scheduler = lr_scheduler + self.num_epochs = num_epochs + self.dataset_train = dataset_train + self.dataset_valid = dataset_valid + self.device = auto_device() if device is None else device + self.log_frequency = log_frequency + self.entropy_weight = entropy_weight + self.skip_weight = skip_weight + self.baseline_decay = baseline_decay + self.baseline = 0. + + self.model.to(self.device) + self.loss.to(self.device) + self.mutator.to(self.device) + + n_train = len(self.dataset_train) + split = n_train // 10 + indices = list(range(n_train)) + train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split]) + valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:]) + self.train_loader = torch.utils.data.DataLoader(self.dataset_train, + batch_size=batch_size, + sampler=train_sampler, + num_workers=workers) + self.valid_loader = torch.utils.data.DataLoader(self.dataset_train, + batch_size=batch_size, + sampler=valid_sampler, + num_workers=workers) + self.test_loader = torch.utils.data.DataLoader(self.dataset_valid, + batch_size=batch_size, + num_workers=workers) + + def train_epoch(self, epoch): + self.model.train() + self.mutator.train() + + for phase in ["model", "mutator"]: + if phase == "model": + self.model.train() + self.mutator.eval() + else: + self.model.eval() + self.mutator.train() + loader = self.train_loader if phase == "model" else self.valid_loader + meters = AverageMeterGroup() + for step, (x, y) in enumerate(loader): + x, y = x.to(self.device), y.to(self.device) + self.optim.zero_grad() + self.mut_optim.zero_grad() + + with self.mutator.forward_pass(): + logits = self.model(x) + metrics = self.metrics(logits, y) + + if phase == "model": + loss = self.loss(logits, y) + loss.backward() + self.optim.step() + else: + reward = self.reward_function(logits, y) + if self.entropy_weight is not None: + reward += self.entropy_weight * self.mutator.sample_entropy + self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay) + self.baseline = self.baseline.detach().item() + loss = self.mutator.sample_log_prob * (reward - self.baseline) + if self.skip_weight: + loss += self.skip_weight * self.mutator.sample_skip_penalty + loss.backward() + self.mut_optim.step() + metrics["reward"] = reward + metrics["loss"] = loss.item() + meters.update(metrics) + + if self.log_frequency is not None and step % self.log_frequency == 0: + print("Epoch {} {} Step [{}/{}] {}".format(epoch, phase.capitalize(), step, + len(loader), meters)) + # print(self.mutator._selected_layers) + # print(self.mutator._selected_inputs) + + if self.lr_scheduler is not None: + self.lr_scheduler.step() + + def validate_epoch(self, epoch): + pass + + def train(self): + for epoch in range(self.num_epochs): + # training + print("Epoch {} Training".format(epoch)) + self.train_epoch(epoch) + + # validation + print("Epoch {} Validating".format(epoch)) + self.validate_epoch(epoch) + + def export(self): + pass diff --git a/src/sdk/pynni/nni/nas/pytorch/mutables.py b/src/sdk/pynni/nni/nas/pytorch/mutables.py index 4d2ecc1cce..456fe7a498 100644 --- a/src/sdk/pynni/nni/nas/pytorch/mutables.py +++ b/src/sdk/pynni/nni/nas/pytorch/mutables.py @@ -18,51 +18,101 @@ class PyTorchMutable(nn.Module): def __init__(self, key=None): super().__init__() if key is not None: - self.key = key + if not isinstance(key, str): + key = str(key) + print("Warning: key \"{}\" is not string, converted to string.".format(key)) + self._key = key else: - self.key = self.__class__.__name__ + str(global_mutable_counting()) + self._key = self.__class__.__name__ + str(global_mutable_counting()) self.name = self.key def __deepcopy__(self, memodict=None): - raise NotImplementedError + raise NotImplementedError("Deep copy doesn't work for mutables.") + + def __enter__(self): + self._check_built() + return super().__enter__() + + def __call__(self, *args, **kwargs): + self._check_built() + return super().__call__(*args, **kwargs) def set_mutator(self, mutator): self.__dict__["mutator"] = mutator def forward(self, *inputs): - raise NotImplementedError("Mutable forward must be implemented") + raise NotImplementedError("Mutable forward must be implemented.") - def __repr__(self): - return "{} ({})".format(self.name, self.key) + @property + def key(self): + return self._key def similar(self, other): return self == other + def _check_built(self): + if not hasattr(self, "mutator"): + raise ValueError( + "Mutator not set for {}. Did you initialize a mutable on the fly in forward pass? Move to __init__" + "so that trainer can locate all your mutables. See NNI docs for more details.".format(self)) + + def __repr__(self): + return "{} ({})".format(self.name, self.key) + + +class MutableScope(PyTorchMutable): + """ + Mutable scope labels a subgraph to help mutators make better decisions. Mutators get notified when a mutable scope + is entered and exited. Mutators can override ``enter_mutable_scope`` and ``exit_mutable_scope`` to catch + corresponding events, and do status dump or update. + """ + + def __init__(self, key): + super().__init__(key=key) + + def __enter__(self): + self.mutator.enter_mutable_scope(self) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.mutator.exit_mutable_scope(self) + class LayerChoice(PyTorchMutable): - def __init__(self, ops, key=None): + def __init__(self, op_candidates, reduction="mean", return_mask=False, key=None): super().__init__(key=key) - self.length = len(ops) - self.choices = nn.ModuleList(ops) + self.length = len(op_candidates) + self.choices = nn.ModuleList(op_candidates) + self.reduction = reduction + self.return_mask = return_mask def forward(self, *inputs): - return self.mutator.on_forward(self, self.choices, *inputs) + out, mask = self.mutator.on_forward(self, *inputs) + if self.return_mask: + return out, mask + return out def similar(self, other): return type(self) == type(other) and self.length == other.length class InputChoice(PyTorchMutable): - def __init__(self, n_candidates, n_selected=None, reduction="mean", return_index=False, key=None): + def __init__(self, n_candidates, n_selected=None, reduction="mean", return_mask=False, key=None): super().__init__(key=key) + assert n_candidates > 0, "Number of candidates must be greater than 0." self.n_candidates = n_candidates self.n_selected = n_selected self.reduction = reduction - self.return_index = return_index - - def forward(self, *inputs): - assert len(inputs) == self.n_candidates, "Length of the input list must be equal to number of candidates." - return self.mutator.on_forward(self, *inputs) + self.return_mask = return_mask + + def forward(self, optional_inputs, semantic_labels=None): + assert len(optional_inputs) == self.n_candidates, \ + "Length of the input list must be equal to number of candidates." + if semantic_labels is None: + semantic_labels = ["default_label"] * self.n_candidates + out, mask = self.mutator.on_forward(self, optional_inputs, semantic_labels) + if self.return_mask: + return out, mask + return out def similar(self, other): return type(self) == type(other) and \ diff --git a/src/sdk/pynni/nni/nas/pytorch/mutator.py b/src/sdk/pynni/nni/nas/pytorch/mutator.py index 331bdb42f8..55d742e9af 100644 --- a/src/sdk/pynni/nni/nas/pytorch/mutator.py +++ b/src/sdk/pynni/nni/nas/pytorch/mutator.py @@ -1,6 +1,8 @@ import logging +from contextlib import contextmanager -from torch import nn as nn +import torch +import torch.nn as nn from nni.nas.pytorch.mutables import PyTorchMutable from nni.nas.utils import to_snake_case @@ -28,8 +30,8 @@ def named_mutables(self, model): if isinstance(module, PyTorchMutable): distinct = False if module.key in key2module: - assert key2module[module.key].similar(module), "Mutable that share the same key must be similar " \ - "to each other" + assert key2module[module.key].similar(module), \ + "Mutable \"{}\" that share the same key must be similar to each other".format(module.key) else: distinct = True key2module[module.key] = module @@ -56,11 +58,35 @@ def parse_search_space(self, model): def on_init_general(self, mutable): pass - def on_forward_general(self, mutable, *inputs): - raise NotImplementedError("Forward has to be implemented") + @contextmanager + def forward_pass(self): + self.before_pass() + try: + yield self + finally: + self.after_pass() + + def before_pass(self): + self._in_forward_pass = True + self._cache = dict() + + def after_pass(self): + self._in_forward_pass = False + + def enter_mutable_scope(self, mutable_scope): + pass + + def exit_mutable_scope(self, mutable_scope): + pass + + def forward(self, *inputs): + raise NotImplementedError("Mutator is not forward-able") def on_forward(self, mutable, *inputs): """Callback on forwarding a mutable""" + if not hasattr(self, "_in_forward_pass") or not self._in_forward_pass: + raise ValueError("Not in forward pass. Did you forget to call mutator.forward_pass(), or forget to call " + "super().before_pass() and after_pass() in your override method?") forward_method_name = "on_forward_{}".format(to_snake_case(mutable.__class__.__name__)) if hasattr(self, forward_method_name) and callable(getattr(self, forward_method_name)): return getattr(self, forward_method_name)(mutable, *inputs) @@ -68,5 +94,110 @@ def on_forward(self, mutable, *inputs): # fallback to general forward return self.on_forward_general(mutable, *inputs) - def forward(self, *inputs): - raise NotImplementedError("Mutator is not forward-able") + def on_forward_general(self, mutable, *inputs): + raise NotImplementedError("Forward has to be implemented") + + def on_forward_layer_choice(self, mutable, *inputs): + """ + Callback of layer choice forward. Override if you are an advanced user. + On default, this method calls :meth:`on_calc_layer_choice_mask` to get a mask on how to choose between layers + (either by switch or by weights), then it will reduce the list of all tensor outputs with the policy speicified + in `mutable.reduction`. It will also cache the mask with corresponding `mutable.key`. + + Parameters + ---------- + mutable: LayerChoice + inputs: list of torch.Tensor + + Returns + ------- + torch.Tensor + """ + def _map_fn(op, *inputs): + return op(*inputs) + mask = self._cache.setdefault(mutable.key, self.on_calc_layer_choice_mask(mutable)) + out = self._select_with_mask(_map_fn, [(choice, *inputs) for choice in mutable.choices], mask) + return self._tensor_reduction(mutable.reduction, out), mask + + def on_forward_input_choice(self, mutable, tensor_list, semantic_labels): + """ + Callback of input choice forward. Override if you are an advanced user. + On default, this method calls :meth:`on_calc_input_choice_mask` with `semantic_labels` + to get a mask on how to choose between inputs (either by switch or by weights), then it will reduce + the list of all tensor outputs with the policy speicified in `mutable.reduction`. It will also cache the + mask with corresponding `mutable.key`. + + Parameters + ---------- + mutable: InputChoice + inputs: list of torch.Tensor + + Returns + ------- + torch.Tensor + """ + mask = self._cache.setdefault(mutable.key, self.on_calc_input_choice_mask(mutable, semantic_labels)) + out = self._select_with_mask(lambda x: x, [(t, ) for t in tensor_list], mask) + return self._tensor_reduction(mutable.reduction, out), mask + + def on_calc_layer_choice_mask(self, mutable): + """ + Recommended to override. Calculate a mask tensor for a layer choice. + + Parameters + ---------- + mutable: LayerChoice + Corresponding layer choice object. + + Returns + ------- + torch.Tensor + Should be a 1D tensor, either float or bool. If float, the numbers are treated as weights. If bool, + the numbers are treated as switch. + """ + raise NotImplementedError("Layer choice mask calculation must be implemented") + + def on_calc_input_choice_mask(self, mutable, semantic_labels): + """ + Recommended to override. Calculate a mask tensor for a input choice. + + Parameters + ---------- + mutable: InputChoice + Corresponding input choice object. + semantic_labels: list of string + The name of labels of input tensors given by user. Usually it's a + :class:`~nni.nas.pytorch.mutables.MutableScope` marked by user. + + Returns + ------- + torch.Tensor + Should be a 1D tensor, either float or bool. If float, the numbers are treated as weights. If bool, + the numbers are treated as switch. + """ + raise NotImplementedError("Input choice mask calculation must be implemented") + + def _select_with_mask(self, map_fn, candidates, mask): + if "BoolTensor" in mask.type(): + # print(candidates[0], len(mask)) + out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m] + elif "FloatTensor" in mask.type(): + out = [map_fn(*cand) * m for cand, m in zip(candidates, mask)] + else: + raise ValueError("Unrecognized mask") + return out + + def _tensor_reduction(self, reduction_type, tensor_list): + if tensor_list == "none": + return tensor_list + if not tensor_list: + return None # empty. return None for now + if len(tensor_list) == 1: + return tensor_list[0] + if reduction_type == "sum": + return sum(tensor_list) + if reduction_type == "mean": + return sum(tensor_list) / len(tensor_list) + if reduction_type == "concat": + return torch.cat(tensor_list, dim=1) + raise ValueError("Unrecognized reduction policy: \"{}\"".format(reduction_type)) diff --git a/src/sdk/pynni/nni/nas/pytorch/trainer.py b/src/sdk/pynni/nni/nas/pytorch/trainer.py index 27c20efc51..6327e9a229 100644 --- a/src/sdk/pynni/nni/nas/pytorch/trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/trainer.py @@ -8,5 +8,5 @@ def train(self): raise NotImplementedError @abstractmethod - def finalize(self): + def export(self): raise NotImplementedError From d1d10de7684b6a2ebc0ee09d0908978f3827bd56 Mon Sep 17 00:00:00 2001 From: Chi Song <27178119+squirrelsc@users.noreply.github.com> Date: Thu, 14 Nov 2019 17:14:00 +0800 Subject: [PATCH 04/10] pdarts implementation (export is not included) (#1730) --- .gitignore | 1 + examples/nas/.gitignore | 2 +- examples/nas/darts/model.py | 141 ------------------ examples/nas/darts/search.py | 12 +- examples/nas/pdarts/.gitignore | 2 + examples/nas/pdarts/datasets.py | 25 ++++ examples/nas/pdarts/main.py | 65 ++++++++ .../pynni/nni/nas/pytorch/darts/__init__.py | 2 + .../pynni/nni/nas/pytorch/darts/cnn_cell.py | 69 +++++++++ .../nni/nas/pytorch/darts/cnn_network.py | 73 +++++++++ .../pynni/nni/nas/pytorch/darts/cnn_ops.py | 23 +-- .../pynni/nni/nas/pytorch/darts/trainer.py | 3 +- src/sdk/pynni/nni/nas/pytorch/enas/mutator.py | 4 +- src/sdk/pynni/nni/nas/pytorch/modules.py | 9 ++ src/sdk/pynni/nni/nas/pytorch/mutables.py | 8 +- .../pynni/nni/nas/pytorch/pdarts/__init__.py | 1 + .../pynni/nni/nas/pytorch/pdarts/mutator.py | 93 ++++++++++++ .../pynni/nni/nas/pytorch/pdarts/trainer.py | 54 +++++++ 18 files changed, 421 insertions(+), 166 deletions(-) delete mode 100644 examples/nas/darts/model.py create mode 100644 examples/nas/pdarts/.gitignore create mode 100644 examples/nas/pdarts/datasets.py create mode 100644 examples/nas/pdarts/main.py create mode 100644 src/sdk/pynni/nni/nas/pytorch/darts/cnn_cell.py create mode 100644 src/sdk/pynni/nni/nas/pytorch/darts/cnn_network.py rename examples/nas/darts/ops.py => src/sdk/pynni/nni/nas/pytorch/darts/cnn_ops.py (93%) create mode 100644 src/sdk/pynni/nni/nas/pytorch/modules.py create mode 100644 src/sdk/pynni/nni/nas/pytorch/pdarts/__init__.py create mode 100644 src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py create mode 100644 src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py diff --git a/.gitignore b/.gitignore index e96b14efc6..83049a476e 100644 --- a/.gitignore +++ b/.gitignore @@ -80,6 +80,7 @@ venv.bak/ # VSCode .vscode +.vs # In case you place source code in ~/nni/ /experiments diff --git a/examples/nas/.gitignore b/examples/nas/.gitignore index 1269488f7f..8705cba4d6 100644 --- a/examples/nas/.gitignore +++ b/examples/nas/.gitignore @@ -1 +1 @@ -data +data diff --git a/examples/nas/darts/model.py b/examples/nas/darts/model.py deleted file mode 100644 index 629831e0b7..0000000000 --- a/examples/nas/darts/model.py +++ /dev/null @@ -1,141 +0,0 @@ -import torch -import torch.nn as nn - -import ops -from nni.nas import pytorch as nas - - -class SearchCell(nn.Module): - """ - Cell for search. - """ - - def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction): - """ - Initialization a search cell. - - Parameters - ---------- - n_nodes: int - Number of nodes in current DAG. - channels_pp: int - Number of output channels from previous previous cell. - channels_p: int - Number of output channels from previous cell. - channels: int - Number of channels that will be used in the current DAG. - reduction_p: bool - Flag for whether the previous cell is reduction cell or not. - reduction: bool - Flag for whether the current cell is reduction cell or not. - """ - super().__init__() - self.reduction = reduction - self.n_nodes = n_nodes - - # If previous cell is reduction cell, current input size does not match with - # output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing. - if reduction_p: - self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False) - else: - self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False) - self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False) - - # generate dag - self.mutable_ops = nn.ModuleList() - for depth in range(self.n_nodes): - self.mutable_ops.append(nn.ModuleList()) - for i in range(2 + depth): # include 2 input nodes - # reduction should be used only for input node - stride = 2 if reduction and i < 2 else 1 - op = nas.mutables.LayerChoice([ops.PoolBN('max', channels, 3, stride, 1, affine=False), - ops.PoolBN('avg', channels, 3, stride, 1, affine=False), - ops.Identity() if stride == 1 else - ops.FactorizedReduce(channels, channels, affine=False), - ops.SepConv(channels, channels, 3, stride, 1, affine=False), - ops.SepConv(channels, channels, 5, stride, 2, affine=False), - ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False), - ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False), - ops.Zero(stride)], - key="r{}_d{}_i{}".format(reduction, depth, i)) - self.mutable_ops[depth].append(op) - - def forward(self, s0, s1): - # s0, s1 are the outputs of previous previous cell and previous cell, respectively. - tensors = [self.preproc0(s0), self.preproc1(s1)] - for ops in self.mutable_ops: - assert len(ops) == len(tensors) - cur_tensor = sum(op(tensor) for op, tensor in zip(ops, tensors)) - tensors.append(cur_tensor) - - output = torch.cat(tensors[2:], dim=1) - return output - - -class SearchCNN(nn.Module): - """ - Search CNN model - """ - - def __init__(self, in_channels, channels, n_classes, n_layers, n_nodes=4, stem_multiplier=3): - """ - Initializing a search channelsNN. - - Parameters - ---------- - in_channels: int - Number of channels in images. - channels: int - Number of channels used in the network. - n_classes: int - Number of classes. - n_layers: int - Number of cells in the whole network. - n_nodes: int - Number of nodes in a cell. - stem_multiplier: int - Multiplier of channels in STEM. - """ - super().__init__() - self.in_channels = in_channels - self.channels = channels - self.n_classes = n_classes - self.n_layers = n_layers - - c_cur = stem_multiplier * self.channels - self.stem = nn.Sequential( - nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False), - nn.BatchNorm2d(c_cur) - ) - - # for the first cell, stem is used for both s0 and s1 - # [!] channels_pp and channels_p is output channel size, but c_cur is input channel size. - channels_pp, channels_p, c_cur = c_cur, c_cur, channels - - self.cells = nn.ModuleList() - reduction_p, reduction = False, False - for i in range(n_layers): - reduction_p, reduction = reduction, False - # Reduce featuremap size and double channels in 1/3 and 2/3 layer. - if i in [n_layers // 3, 2 * n_layers // 3]: - c_cur *= 2 - reduction = True - - cell = SearchCell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction) - self.cells.append(cell) - c_cur_out = c_cur * n_nodes - channels_pp, channels_p = channels_p, c_cur_out - - self.gap = nn.AdaptiveAvgPool2d(1) - self.linear = nn.Linear(channels_p, n_classes) - - def forward(self, x): - s0 = s1 = self.stem(x) - - for cell in self.cells: - s0, s1 = s1, cell(s0, s1) - - out = self.gap(s1) - out = out.view(out.size(0), -1) # flatten - logits = self.linear(out) - return logits diff --git a/examples/nas/darts/search.py b/examples/nas/darts/search.py index ad0650d156..0d7f995769 100644 --- a/examples/nas/darts/search.py +++ b/examples/nas/darts/search.py @@ -1,25 +1,23 @@ from argparse import ArgumentParser -import datasets import torch import torch.nn as nn -from model import SearchCNN -from nni.nas.pytorch.darts import DartsTrainer +import datasets +from nni.nas.pytorch.darts import CnnNetwork, DartsTrainer from utils import accuracy - if __name__ == "__main__": parser = ArgumentParser("darts") - parser.add_argument("--layers", default=4, type=int) - parser.add_argument("--nodes", default=2, type=int) + parser.add_argument("--layers", default=5, type=int) + parser.add_argument("--nodes", default=4, type=int) parser.add_argument("--batch-size", default=128, type=int) parser.add_argument("--log-frequency", default=1, type=int) args = parser.parse_args() dataset_train, dataset_valid = datasets.get_dataset("cifar10") - model = SearchCNN(3, 16, 10, args.layers, n_nodes=args.nodes) + model = CnnNetwork(3, 16, 10, args.layers, n_nodes=args.nodes) criterion = nn.CrossEntropyLoss() optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4) diff --git a/examples/nas/pdarts/.gitignore b/examples/nas/pdarts/.gitignore new file mode 100644 index 0000000000..054c274eeb --- /dev/null +++ b/examples/nas/pdarts/.gitignore @@ -0,0 +1,2 @@ +data/* +log diff --git a/examples/nas/pdarts/datasets.py b/examples/nas/pdarts/datasets.py new file mode 100644 index 0000000000..8fe0ab0fbf --- /dev/null +++ b/examples/nas/pdarts/datasets.py @@ -0,0 +1,25 @@ +from torchvision import transforms +from torchvision.datasets import CIFAR10 + + +def get_dataset(cls): + MEAN = [0.49139968, 0.48215827, 0.44653124] + STD = [0.24703233, 0.24348505, 0.26158768] + transf = [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip() + ] + normalize = [ + transforms.ToTensor(), + transforms.Normalize(MEAN, STD) + ] + + train_transform = transforms.Compose(transf + normalize) + valid_transform = transforms.Compose(normalize) + + if cls == "cifar10": + dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform) + dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform) + else: + raise NotImplementedError + return dataset_train, dataset_valid diff --git a/examples/nas/pdarts/main.py b/examples/nas/pdarts/main.py new file mode 100644 index 0000000000..68a59c8856 --- /dev/null +++ b/examples/nas/pdarts/main.py @@ -0,0 +1,65 @@ +from argparse import ArgumentParser + +import datasets +import torch +import torch.nn as nn +import nni.nas.pytorch as nas +from nni.nas.pytorch.pdarts import PdartsTrainer +from nni.nas.pytorch.darts import CnnNetwork, CnnCell + + +def accuracy(output, target, topk=(1,)): + """ Computes the precision@k for the specified values of k """ + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + # one-hot case + if target.ndimension() > 1: + target = target.max(1)[1] + + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = dict() + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item() + return res + + +if __name__ == "__main__": + parser = ArgumentParser("darts") + parser.add_argument("--layers", default=5, type=int) + parser.add_argument('--add_layers', action='append', + default=[0, 6, 12], help='add layers') + parser.add_argument("--nodes", default=4, type=int) + parser.add_argument("--batch-size", default=128, type=int) + parser.add_argument("--log-frequency", default=1, type=int) + args = parser.parse_args() + + dataset_train, dataset_valid = datasets.get_dataset("cifar10") + + def model_creator(layers, n_nodes): + model = CnnNetwork(3, 16, 10, layers, n_nodes=n_nodes, cell_type=CnnCell) + loss = nn.CrossEntropyLoss() + + model_optim = torch.optim.SGD(model.parameters(), 0.025, + momentum=0.9, weight_decay=3.0E-4) + n_epochs = 50 + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(model_optim, n_epochs, eta_min=0.001) + return model, loss, model_optim, lr_scheduler + + trainer = PdartsTrainer(model_creator, + metrics=lambda output, target: accuracy(output, target, topk=(1,)), + num_epochs=50, + pdarts_num_layers=[0, 6, 12], + pdarts_num_to_drop=[3, 2, 2], + dataset_train=dataset_train, + dataset_valid=dataset_valid, + layers=args.layers, + n_nodes=args.nodes, + batch_size=args.batch_size, + log_frequency=args.log_frequency) + trainer.train() + trainer.export() diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py b/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py index 34e5f6e81c..f28d2cd73c 100644 --- a/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py +++ b/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py @@ -1,2 +1,4 @@ from .mutator import DartsMutator from .trainer import DartsTrainer +from .cnn_cell import CnnCell +from .cnn_network import CnnNetwork diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/cnn_cell.py b/src/sdk/pynni/nni/nas/pytorch/darts/cnn_cell.py new file mode 100644 index 0000000000..69dc28e8f0 --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/darts/cnn_cell.py @@ -0,0 +1,69 @@ + +import torch +import torch.nn as nn + +import nni.nas.pytorch as nas +from nni.nas.pytorch.modules import RankedModule + +from .cnn_ops import OPS, PRIMITIVES, FactorizedReduce, StdConv + + +class CnnCell(RankedModule): + """ + Cell for search. + """ + + def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction): + """ + Initialization a search cell. + + Parameters + ---------- + n_nodes: int + Number of nodes in current DAG. + channels_pp: int + Number of output channels from previous previous cell. + channels_p: int + Number of output channels from previous cell. + channels: int + Number of channels that will be used in the current DAG. + reduction_p: bool + Flag for whether the previous cell is reduction cell or not. + reduction: bool + Flag for whether the current cell is reduction cell or not. + """ + super(CnnCell, self).__init__(rank=1, reduction=reduction) + self.n_nodes = n_nodes + + # If previous cell is reduction cell, current input size does not match with + # output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing. + if reduction_p: + self.preproc0 = FactorizedReduce(channels_pp, channels, affine=False) + else: + self.preproc0 = StdConv(channels_pp, channels, 1, 1, 0, affine=False) + self.preproc1 = StdConv(channels_p, channels, 1, 1, 0, affine=False) + + # generate dag + self.mutable_ops = nn.ModuleList() + for depth in range(self.n_nodes): + self.mutable_ops.append(nn.ModuleList()) + for i in range(2 + depth): # include 2 input nodes + # reduction should be used only for input node + stride = 2 if reduction and i < 2 else 1 + m_ops = [] + for primitive in PRIMITIVES: + op = OPS[primitive](channels, stride, False) + m_ops.append(op) + op = nas.mutables.LayerChoice(m_ops, key="r{}_d{}_i{}".format(reduction, depth, i)) + self.mutable_ops[depth].append(op) + + def forward(self, s0, s1): + # s0, s1 are the outputs of previous previous cell and previous cell, respectively. + tensors = [self.preproc0(s0), self.preproc1(s1)] + for ops in self.mutable_ops: + assert len(ops) == len(tensors) + cur_tensor = sum(op(tensor) for op, tensor in zip(ops, tensors)) + tensors.append(cur_tensor) + + output = torch.cat(tensors[2:], dim=1) + return output diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/cnn_network.py b/src/sdk/pynni/nni/nas/pytorch/darts/cnn_network.py new file mode 100644 index 0000000000..d126e3353e --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/darts/cnn_network.py @@ -0,0 +1,73 @@ + +import torch.nn as nn + +from .cnn_cell import CnnCell + + +class CnnNetwork(nn.Module): + """ + Search CNN model + """ + + def __init__(self, in_channels, channels, n_classes, n_layers, n_nodes=4, stem_multiplier=3, cell_type=CnnCell): + """ + Initializing a search channelsNN. + + Parameters + ---------- + in_channels: int + Number of channels in images. + channels: int + Number of channels used in the network. + n_classes: int + Number of classes. + n_layers: int + Number of cells in the whole network. + n_nodes: int + Number of nodes in a cell. + stem_multiplier: int + Multiplier of channels in STEM. + """ + super().__init__() + self.in_channels = in_channels + self.channels = channels + self.n_classes = n_classes + self.n_layers = n_layers + + c_cur = stem_multiplier * self.channels + self.stem = nn.Sequential( + nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False), + nn.BatchNorm2d(c_cur) + ) + + # for the first cell, stem is used for both s0 and s1 + # [!] channels_pp and channels_p is output channel size, but c_cur is input channel size. + channels_pp, channels_p, c_cur = c_cur, c_cur, channels + + self.cells = nn.ModuleList() + reduction_p, reduction = False, False + for i in range(n_layers): + reduction_p, reduction = reduction, False + # Reduce featuremap size and double channels in 1/3 and 2/3 layer. + if i in [n_layers // 3, 2 * n_layers // 3]: + c_cur *= 2 + reduction = True + + cell = cell_type(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction) + self.cells.append(cell) + c_cur_out = c_cur * n_nodes + channels_pp, channels_p = channels_p, c_cur_out + + self.gap = nn.AdaptiveAvgPool2d(1) + self.linear = nn.Linear(channels_p, n_classes) + + def forward(self, x): + s0 = s1 = self.stem(x) + + for cell in self.cells: + s0, s1 = s1, cell(s0, s1) + + out = self.gap(s1) + out = out.view(out.size(0), -1) # flatten + logits = self.linear(out) + return logits diff --git a/examples/nas/darts/ops.py b/src/sdk/pynni/nni/nas/pytorch/darts/cnn_ops.py similarity index 93% rename from examples/nas/darts/ops.py rename to src/sdk/pynni/nni/nas/pytorch/darts/cnn_ops.py index ef25a6e830..02b4a3a94c 100644 --- a/examples/nas/darts/ops.py +++ b/src/sdk/pynni/nni/nas/pytorch/darts/cnn_ops.py @@ -1,29 +1,27 @@ import torch import torch.nn as nn - PRIMITIVES = [ + 'none', 'max_pool_3x3', 'avg_pool_3x3', - 'skip_connect', # identity + 'skip_connect', # identity 'sep_conv_3x3', 'sep_conv_5x5', 'dil_conv_3x3', 'dil_conv_5x5', - 'none' ] OPS = { 'none': lambda C, stride, affine: Zero(stride), 'avg_pool_3x3': lambda C, stride, affine: PoolBN('avg', C, 3, stride, 1, affine=affine), 'max_pool_3x3': lambda C, stride, affine: PoolBN('max', C, 3, stride, 1, affine=affine), - 'skip_connect': lambda C, stride, affine: \ - Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine), + 'skip_connect': lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine), 'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine), 'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine), 'sep_conv_7x7': lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine), - 'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine), # 5x5 - 'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine), # 9x9 + 'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine), # 5x5 + 'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine), # 9x9 'conv_7x1_1x7': lambda C, stride, affine: FacConv(C, C, 7, stride, 3, affine=affine) } @@ -60,6 +58,7 @@ class PoolBN(nn.Module): """ AvgPool or MaxPool - BN """ + def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True): """ Args: @@ -85,6 +84,7 @@ class StdConv(nn.Module): """ Standard conv ReLU - Conv - BN """ + def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): super().__init__() self.net = nn.Sequential( @@ -101,6 +101,7 @@ class FacConv(nn.Module): """ Factorized conv ReLU - Conv(Kx1) - Conv(1xK) - BN """ + def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True): super().__init__() self.net = nn.Sequential( @@ -118,14 +119,14 @@ class DilConv(nn.Module): """ (Dilated) depthwise separable conv ReLU - (Dilated) depthwise separable - Pointwise - BN If dilation == 2, 3x3 conv => 5x5 receptive field - 5x5 conv => 9x9 receptive field + 5x5 conv => 9x9 receptive field """ + def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True): super().__init__() self.net = nn.Sequential( nn.ReLU(), - nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in, - bias=False), + nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in, bias=False), nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False), nn.BatchNorm2d(C_out, affine=affine) ) @@ -138,6 +139,7 @@ class SepConv(nn.Module): """ Depthwise separable conv DilConv(dilation=1) * 2 """ + def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): super().__init__() self.net = nn.Sequential( @@ -172,6 +174,7 @@ class FactorizedReduce(nn.Module): """ Reduce feature map size by factorized pointwise(stride=2). """ + def __init__(self, C_in, C_out, affine=True): super().__init__() self.relu = nn.ReLU() diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py b/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py index 72ac427c11..75463ff23f 100644 --- a/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py @@ -94,7 +94,8 @@ def validate_epoch(self, epoch): with torch.no_grad(): for step, (X, y) in enumerate(self.valid_loader): X, y = X.to(self.device), y.to(self.device) - logits = self.model(X) + with self.mutator.forward_pass(): + logits = self.model(X) metrics = self.metrics(logits, y) meters.update(metrics) if self.log_frequency is not None and step % self.log_frequency == 0: diff --git a/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py b/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py index 93dad9c77c..a158886233 100644 --- a/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py +++ b/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py @@ -40,7 +40,7 @@ def before_build(self, model): self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False) self.v_attn = nn.Linear(self.lstm_size, 1, bias=False) self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1) - self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False) + self.skip_targets = nn.Parameter(torch.Tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False) self.cross_entropy_loss = nn.CrossEntropyLoss() def after_build(self, model): @@ -79,7 +79,7 @@ def on_calc_layer_choice_mask(self, mutable): self._lstm_next_step() logit = self.soft(self._h[-1]) if self.tanh_constant is not None: - logit = self.tanh_constant * torch.tanh(logit) + logit = self.tanh_constant * torch.tanh(logit) branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1) log_prob = self.cross_entropy_loss(logit, branch_id) self.sample_log_prob += log_prob diff --git a/src/sdk/pynni/nni/nas/pytorch/modules.py b/src/sdk/pynni/nni/nas/pytorch/modules.py new file mode 100644 index 0000000000..6570220e13 --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/modules.py @@ -0,0 +1,9 @@ + +from torch import nn as nn + + +class RankedModule(nn.Module): + def __init__(self, rank=None, reduction=False): + super(RankedModule, self).__init__() + self.rank = rank + self.reduction = reduction diff --git a/src/sdk/pynni/nni/nas/pytorch/mutables.py b/src/sdk/pynni/nni/nas/pytorch/mutables.py index 456fe7a498..e28af84037 100644 --- a/src/sdk/pynni/nni/nas/pytorch/mutables.py +++ b/src/sdk/pynni/nni/nas/pytorch/mutables.py @@ -56,9 +56,6 @@ def _check_built(self): "Mutator not set for {}. Did you initialize a mutable on the fly in forward pass? Move to __init__" "so that trainer can locate all your mutables. See NNI docs for more details.".format(self)) - def __repr__(self): - return "{} ({})".format(self.name, self.key) - class MutableScope(PyTorchMutable): """ @@ -85,6 +82,9 @@ def __init__(self, op_candidates, reduction="mean", return_mask=False, key=None) self.reduction = reduction self.return_mask = return_mask + def __len__(self): + return self.length + def forward(self, *inputs): out, mask = self.mutator.on_forward(self, *inputs) if self.return_mask: @@ -116,4 +116,4 @@ def forward(self, optional_inputs, semantic_labels=None): def similar(self, other): return type(self) == type(other) and \ - self.n_candidates == other.n_candidates and self.n_selected and other.n_selected + self.n_candidates == other.n_candidates and self.n_selected and other.n_selected diff --git a/src/sdk/pynni/nni/nas/pytorch/pdarts/__init__.py b/src/sdk/pynni/nni/nas/pytorch/pdarts/__init__.py new file mode 100644 index 0000000000..27dd912ab3 --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/pdarts/__init__.py @@ -0,0 +1 @@ +from .trainer import PdartsTrainer diff --git a/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py b/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py new file mode 100644 index 0000000000..6e385b1170 --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py @@ -0,0 +1,93 @@ +import copy + +import numpy as np +import torch +from torch import nn as nn +from torch.nn import functional as F + +from nni.nas.pytorch.darts import DartsMutator +from nni.nas.pytorch.mutables import LayerChoice + + +class PdartsMutator(DartsMutator): + + def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches=None): + self.pdarts_epoch_index = pdarts_epoch_index + self.pdarts_num_to_drop = pdarts_num_to_drop + self.switches = switches + + super(PdartsMutator, self).__init__(model) + + def before_build(self, model): + self.choices = nn.ParameterDict() + if self.switches is None: + self.switches = {} + + def named_mutables(self, model): + key2module = dict() + for name, module in model.named_modules(): + if isinstance(module, LayerChoice): + key2module[module.key] = module + yield name, module, True + + def drop_paths(self): + for key in self.switches: + prob = F.softmax(self.choices[key], dim=-1).data.cpu().numpy() + + switches = self.switches[key] + idxs = [] + for j in range(len(switches)): + if switches[j]: + idxs.append(j) + if self.pdarts_epoch_index == len(self.pdarts_num_to_drop) - 1: + # for the last stage, drop all Zero operations + drop = self.get_min_k_no_zero(prob, idxs, self.pdarts_num_to_drop[self.pdarts_epoch_index]) + else: + drop = self.get_min_k(prob, self.pdarts_num_to_drop[self.pdarts_epoch_index]) + + for idx in drop: + switches[idxs[idx]] = False + return self.switches + + def on_init_layer_choice(self, mutable: LayerChoice): + switches = self.switches.get( + mutable.key, [True for j in range(mutable.length)]) + + for index in range(len(switches)-1, -1, -1): + if switches[index] == False: + del(mutable.choices[index]) + mutable.length -= 1 + + self.switches[mutable.key] = switches + + self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length)) + + def on_calc_layer_choice_mask(self, mutable: LayerChoice): + return F.softmax(self.choices[mutable.key], dim=-1) + + def get_min_k(self, input_in, k): + index = [] + for _ in range(k): + idx = np.argmin(input) + index.append(idx) + + return index + + def get_min_k_no_zero(self, w_in, idxs, k): + w = copy.deepcopy(w_in) + index = [] + if 0 in idxs: + zf = True + else: + zf = False + if zf: + w = w[1:] + index.append(0) + k = k - 1 + for _ in range(k): + idx = np.argmin(w) + w[idx] = 1 + if zf: + idx = idx + 1 + index.append(idx) + return index diff --git a/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py b/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py new file mode 100644 index 0000000000..6425e234d8 --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py @@ -0,0 +1,54 @@ +from nni.nas.pytorch.darts import DartsTrainer +from nni.nas.pytorch.trainer import Trainer + +from .mutator import PdartsMutator + + +class PdartsTrainer(Trainer): + + def __init__(self, model_creator, metrics, num_epochs, dataset_train, dataset_valid, + layers=5, n_nodes=4, pdarts_num_layers=[0, 6, 12], pdarts_num_to_drop=[3, 2, 2], + mutator=None, batch_size=64, workers=4, device=None, log_frequency=None): + self.model_creator = model_creator + self.layers = layers + self.n_nodes = n_nodes + self.pdarts_num_layers = pdarts_num_layers + self.pdarts_num_to_drop = pdarts_num_to_drop + self.pdarts_epoch = len(pdarts_num_to_drop) + self.darts_parameters = { + "metrics": metrics, + "num_epochs": num_epochs, + "dataset_train": dataset_train, + "dataset_valid": dataset_valid, + "batch_size": batch_size, + "workers": workers, + "device": device, + "log_frequency": log_frequency + } + + def train(self): + layers = self.layers + n_nodes = self.n_nodes + switches = None + for epoch in range(self.pdarts_epoch): + + layers = self.layers+self.pdarts_num_layers[epoch] + model, loss, model_optim, lr_scheduler = self.model_creator( + layers, n_nodes) + mutator = PdartsMutator( + model, epoch, self.pdarts_num_to_drop, switches) + + self.trainer = DartsTrainer(model, loss=loss, model_optim=model_optim, + lr_scheduler=lr_scheduler, mutator=mutator, **self.darts_parameters) + print("start pdrats training %s..." % epoch) + + self.trainer.train() + + # with open('log/parameters_%d.txt' % epoch, "w") as f: + # f.write(str(model.parameters)) + + switches = mutator.drop_paths() + + def export(self): + if (self.trainer is not None) and hasattr(self.trainer, "export"): + self.trainer.export() From 1cada380ff768d6e59aa4089734cbed74014b9bb Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Mon, 18 Nov 2019 14:33:29 +0800 Subject: [PATCH 05/10] Extract base mutator/trainer and support ENAS micro search space (#1739) --- examples/nas/darts/datasets.py | 32 ++- examples/nas/darts/model.py | 149 ++++++++++++++ examples/nas/darts/ops.py | 135 +++++++++++++ examples/nas/darts/search.py | 34 ++-- examples/nas/enas/enas_ops.py | 80 -------- examples/nas/enas/macro.py | 95 ++------- examples/nas/enas/micro.py | 183 ++++++++++++++++++ examples/nas/enas/ops.py | 35 +++- examples/nas/enas/search.py | 47 +++++ examples/nas/enas/utils.py | 27 +++ src/__init__.py | 0 src/sdk/pynni/nni/nas/pytorch/base_mutator.py | 70 +++++++ src/sdk/pynni/nni/nas/pytorch/base_trainer.py | 16 ++ src/sdk/pynni/nni/nas/pytorch/callbacks.py | 69 +++++++ .../pynni/nni/nas/pytorch/darts/__init__.py | 3 +- .../pynni/nni/nas/pytorch/darts/mutator.py | 32 ++- src/sdk/pynni/nni/nas/pytorch/darts/scope.py | 11 ++ .../pynni/nni/nas/pytorch/darts/trainer.py | 57 ++---- src/sdk/pynni/nni/nas/pytorch/enas/mutator.py | 105 +++++----- src/sdk/pynni/nni/nas/pytorch/enas/trainer.py | 144 +++++++------- src/sdk/pynni/nni/nas/pytorch/fixed.py | 58 ++++++ src/sdk/pynni/nni/nas/pytorch/mutables.py | 58 ++++-- src/sdk/pynni/nni/nas/pytorch/mutator.py | 115 +++-------- .../pynni/nni/nas/pytorch/pdarts/trainer.py | 6 +- src/sdk/pynni/nni/nas/pytorch/trainer.py | 65 ++++++- src/sdk/pynni/nni/nas/utils.py | 11 -- 26 files changed, 1158 insertions(+), 479 deletions(-) create mode 100644 examples/nas/darts/model.py create mode 100644 examples/nas/darts/ops.py delete mode 100644 examples/nas/enas/enas_ops.py create mode 100644 examples/nas/enas/micro.py create mode 100644 examples/nas/enas/search.py create mode 100644 examples/nas/enas/utils.py delete mode 100644 src/__init__.py create mode 100644 src/sdk/pynni/nni/nas/pytorch/base_mutator.py create mode 100644 src/sdk/pynni/nni/nas/pytorch/base_trainer.py create mode 100644 src/sdk/pynni/nni/nas/pytorch/callbacks.py create mode 100644 src/sdk/pynni/nni/nas/pytorch/darts/scope.py create mode 100644 src/sdk/pynni/nni/nas/pytorch/fixed.py diff --git a/examples/nas/darts/datasets.py b/examples/nas/darts/datasets.py index 8fe0ab0fbf..c5861f16d3 100644 --- a/examples/nas/darts/datasets.py +++ b/examples/nas/darts/datasets.py @@ -1,8 +1,33 @@ +import numpy as np +import torch from torchvision import transforms from torchvision.datasets import CIFAR10 -def get_dataset(cls): +class Cutout(object): + def __init__(self, length): + self.length = length + + def __call__(self, img): + h, w = img.size(1), img.size(2) + mask = np.ones((h, w), np.float32) + y = np.random.randint(h) + x = np.random.randint(w) + + y1 = np.clip(y - self.length // 2, 0, h) + y2 = np.clip(y + self.length // 2, 0, h) + x1 = np.clip(x - self.length // 2, 0, w) + x2 = np.clip(x + self.length // 2, 0, w) + + mask[y1: y2, x1: x2] = 0. + mask = torch.from_numpy(mask) + mask = mask.expand_as(img) + img *= mask + + return img + + +def get_dataset(cls, cutout_length=0): MEAN = [0.49139968, 0.48215827, 0.44653124] STD = [0.24703233, 0.24348505, 0.26158768] transf = [ @@ -13,8 +38,11 @@ def get_dataset(cls): transforms.ToTensor(), transforms.Normalize(MEAN, STD) ] + cutout = [] + if cutout_length > 0: + cutout.append(Cutout(cutout_length)) - train_transform = transforms.Compose(transf + normalize) + train_transform = transforms.Compose(transf + normalize + cutout) valid_transform = transforms.Compose(normalize) if cls == "cifar10": diff --git a/examples/nas/darts/model.py b/examples/nas/darts/model.py new file mode 100644 index 0000000000..5c284b5a46 --- /dev/null +++ b/examples/nas/darts/model.py @@ -0,0 +1,149 @@ +import torch +import torch.nn as nn + +import ops +from nni.nas.pytorch import mutables, darts + + +class AuxiliaryHead(nn.Module): + """ Auxiliary head in 2/3 place of network to let the gradient flow well """ + + def __init__(self, input_size, C, n_classes): + """ assuming input size 7x7 or 8x8 """ + assert input_size in [7, 8] + super().__init__() + self.net = nn.Sequential( + nn.ReLU(inplace=True), + nn.AvgPool2d(5, stride=input_size - 5, padding=0, count_include_pad=False), # 2x2 out + nn.Conv2d(C, 128, kernel_size=1, bias=False), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.Conv2d(128, 768, kernel_size=2, bias=False), # 1x1 out + nn.BatchNorm2d(768), + nn.ReLU(inplace=True) + ) + self.linear = nn.Linear(768, n_classes) + + def forward(self, x): + out = self.net(x) + out = out.view(out.size(0), -1) # flatten + logits = self.linear(out) + return logits + + +class Node(darts.DartsNode): + def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect, drop_path_prob=0.): + super().__init__(node_id, limitation=2) + self.ops = nn.ModuleList() + for i in range(num_prev_nodes): + stride = 2 if i < num_downsample_connect else 1 + self.ops.append( + mutables.LayerChoice( + [ + ops.PoolBN('max', channels, 3, stride, 1, affine=False), + ops.PoolBN('avg', channels, 3, stride, 1, affine=False), + nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False), + ops.SepConv(channels, channels, 3, stride, 1, affine=False), + ops.SepConv(channels, channels, 5, stride, 2, affine=False), + ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False), + ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False), + ], + key="{}_p{}".format(node_id, i))) + self.drop_path = ops.DropPath_(drop_path_prob) + + def forward(self, prev_nodes): + assert len(self.ops) == len(prev_nodes) + out = [op(node) for op, node in zip(self.ops, prev_nodes)] + return sum(self.drop_path(o) for o in out if o is not None) + + +class Cell(nn.Module): + + def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction, drop_path_prob=0.): + super().__init__() + self.reduction = reduction + self.n_nodes = n_nodes + + # If previous cell is reduction cell, current input size does not match with + # output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing. + if reduction_p: + self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False) + else: + self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False) + self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False) + + # generate dag + self.mutable_ops = nn.ModuleList() + for depth in range(self.n_nodes): + self.mutable_ops.append(Node("r{:d}_n{}".format(reduction, depth), + depth + 2, channels, 2 if reduction else 0, + drop_path_prob=drop_path_prob)) + + def forward(self, s0, s1): + # s0, s1 are the outputs of previous previous cell and previous cell, respectively. + tensors = [self.preproc0(s0), self.preproc1(s1)] + for node in self.mutable_ops: + cur_tensor = node(tensors) + tensors.append(cur_tensor) + + output = torch.cat(tensors[2:], dim=1) + return output + + +class CNN(nn.Module): + + def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4, + stem_multiplier=3, auxiliary=False, drop_path_prob=0.): + super().__init__() + self.in_channels = in_channels + self.channels = channels + self.n_classes = n_classes + self.n_layers = n_layers + self.aux_pos = 2 * n_layers // 3 if auxiliary else -1 + + c_cur = stem_multiplier * self.channels + self.stem = nn.Sequential( + nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False), + nn.BatchNorm2d(c_cur) + ) + + # for the first cell, stem is used for both s0 and s1 + # [!] channels_pp and channels_p is output channel size, but c_cur is input channel size. + channels_pp, channels_p, c_cur = c_cur, c_cur, channels + + self.cells = nn.ModuleList() + reduction_p, reduction = False, False + for i in range(n_layers): + reduction_p, reduction = reduction, False + # Reduce featuremap size and double channels in 1/3 and 2/3 layer. + if i in [n_layers // 3, 2 * n_layers // 3]: + c_cur *= 2 + reduction = True + + cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction, drop_path_prob=drop_path_prob) + self.cells.append(cell) + c_cur_out = c_cur * n_nodes + channels_pp, channels_p = channels_p, c_cur_out + + if i == self.aux_pos: + self.aux_head = AuxiliaryHead(input_size // 4, channels_p, n_classes) + + self.gap = nn.AdaptiveAvgPool2d(1) + self.linear = nn.Linear(channels_p, n_classes) + + def forward(self, x): + s0 = s1 = self.stem(x) + + aux_logits = None + for i, cell in enumerate(self.cells): + s0, s1 = s1, cell(s0, s1) + if i == self.aux_pos and self.training: + aux_logits = self.aux_head(s1) + + out = self.gap(s1) + out = out.view(out.size(0), -1) # flatten + logits = self.linear(out) + + if aux_logits is not None: + return logits, aux_logits + return logits diff --git a/examples/nas/darts/ops.py b/examples/nas/darts/ops.py new file mode 100644 index 0000000000..2fef9fec19 --- /dev/null +++ b/examples/nas/darts/ops.py @@ -0,0 +1,135 @@ +import torch +import torch.nn as nn + + +class DropPath_(nn.Module): + def __init__(self, p=0.): + """ [!] DropPath is inplace module + Args: + p: probability of an path to be zeroed. + """ + super().__init__() + self.p = p + + def extra_repr(self): + return 'p={}, inplace'.format(self.p) + + def forward(self, x): + if self.training and self.p > 0.: + keep_prob = 1. - self.p + # per data point mask + mask = torch.zeros((x.size(0), 1, 1, 1), device=x.device).bernoulli_(keep_prob) + x.div_(keep_prob).mul_(mask) + + return x + + +class PoolBN(nn.Module): + """ + AvgPool or MaxPool - BN + """ + def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True): + """ + Args: + pool_type: 'max' or 'avg' + """ + super().__init__() + if pool_type.lower() == 'max': + self.pool = nn.MaxPool2d(kernel_size, stride, padding) + elif pool_type.lower() == 'avg': + self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False) + else: + raise ValueError() + + self.bn = nn.BatchNorm2d(C, affine=affine) + + def forward(self, x): + out = self.pool(x) + out = self.bn(out) + return out + + +class StdConv(nn.Module): + """ Standard conv + ReLU - Conv - BN + """ + def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): + super().__init__() + self.net = nn.Sequential( + nn.ReLU(), + nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False), + nn.BatchNorm2d(C_out, affine=affine) + ) + + def forward(self, x): + return self.net(x) + + +class FacConv(nn.Module): + """ Factorized conv + ReLU - Conv(Kx1) - Conv(1xK) - BN + """ + def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True): + super().__init__() + self.net = nn.Sequential( + nn.ReLU(), + nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False), + nn.Conv2d(C_in, C_out, (1, kernel_length), stride, padding, bias=False), + nn.BatchNorm2d(C_out, affine=affine) + ) + + def forward(self, x): + return self.net(x) + + +class DilConv(nn.Module): + """ (Dilated) depthwise separable conv + ReLU - (Dilated) depthwise separable - Pointwise - BN + If dilation == 2, 3x3 conv => 5x5 receptive field + 5x5 conv => 9x9 receptive field + """ + def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True): + super().__init__() + self.net = nn.Sequential( + nn.ReLU(), + nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in, + bias=False), + nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(C_out, affine=affine) + ) + + def forward(self, x): + return self.net(x) + + +class SepConv(nn.Module): + """ Depthwise separable conv + DilConv(dilation=1) * 2 + """ + def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): + super().__init__() + self.net = nn.Sequential( + DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine), + DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine) + ) + + def forward(self, x): + return self.net(x) + + +class FactorizedReduce(nn.Module): + """ + Reduce feature map size by factorized pointwise(stride=2). + """ + def __init__(self, C_in, C_out, affine=True): + super().__init__() + self.relu = nn.ReLU() + self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) + self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) + self.bn = nn.BatchNorm2d(C_out, affine=affine) + + def forward(self, x): + x = self.relu(x) + out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1) + out = self.bn(out) + return out diff --git a/examples/nas/darts/search.py b/examples/nas/darts/search.py index 0d7f995769..75773cf5e0 100644 --- a/examples/nas/darts/search.py +++ b/examples/nas/darts/search.py @@ -1,41 +1,39 @@ from argparse import ArgumentParser +import datasets import torch import torch.nn as nn -import datasets -from nni.nas.pytorch.darts import CnnNetwork, DartsTrainer +from model import CNN +from nni.nas.pytorch.callbacks import LearningRateScheduler, ArchitectureCheckpoint +from nni.nas.pytorch.darts import DartsTrainer from utils import accuracy + if __name__ == "__main__": parser = ArgumentParser("darts") - parser.add_argument("--layers", default=5, type=int) - parser.add_argument("--nodes", default=4, type=int) - parser.add_argument("--batch-size", default=128, type=int) - parser.add_argument("--log-frequency", default=1, type=int) + parser.add_argument("--layers", default=8, type=int) + parser.add_argument("--batch-size", default=96, type=int) + parser.add_argument("--log-frequency", default=10, type=int) + parser.add_argument("--epochs", default=50, type=int) args = parser.parse_args() dataset_train, dataset_valid = datasets.get_dataset("cifar10") - model = CnnNetwork(3, 16, 10, args.layers, n_nodes=args.nodes) + model = CNN(32, 3, 16, 10, args.layers) criterion = nn.CrossEntropyLoss() optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4) - n_epochs = 50 - lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, n_epochs, eta_min=0.001) + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001) trainer = DartsTrainer(model, loss=criterion, metrics=lambda output, target: accuracy(output, target, topk=(1,)), - model_optim=optim, - lr_scheduler=lr_scheduler, - num_epochs=50, + optimizer=optim, + num_epochs=args.epochs, dataset_train=dataset_train, dataset_valid=dataset_valid, batch_size=args.batch_size, - log_frequency=args.log_frequency) - trainer.train() - trainer.export() - -# augment step -# ... + log_frequency=args.log_frequency, + callbacks=[LearningRateScheduler(lr_scheduler), ArchitectureCheckpoint("./checkpoints")]) + trainer.train_and_validate() diff --git a/examples/nas/enas/enas_ops.py b/examples/nas/enas/enas_ops.py deleted file mode 100644 index 2df9088321..0000000000 --- a/examples/nas/enas/enas_ops.py +++ /dev/null @@ -1,80 +0,0 @@ -import torch -import torch.nn as nn - - -class StdConv(nn.Module): - def __init__(self, C_in, C_out): - super(StdConv, self).__init__() - self.conv = nn.Sequential( - nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False), - nn.BatchNorm2d(C_out, affine=False), - nn.ReLU() - ) - - def forward(self, x): - return self.conv(x) - - -class PoolBranch(nn.Module): - def __init__(self, pool_type, C_in, C_out, kernel_size, stride, padding, affine=False): - super().__init__() - self.preproc = StdConv(C_in, C_out) - if pool_type.lower() == 'max': - self.pool = nn.MaxPool2d(kernel_size, stride, padding) - elif pool_type.lower() == 'avg': - self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False) - else: - raise ValueError() - self.bn = nn.BatchNorm2d(C_out, affine=affine) - - def forward(self, x): - out = self.preproc(x) - out = self.pool(out) - out = self.bn(out) - return out - - -class SeparableConv(nn.Module): - def __init__(self, C_in, C_out, kernel_size, stride, padding): - super(SeparableConv, self).__init__() - self.depthwise = nn.Conv2d(C_in, C_in, kernel_size=kernel_size, padding=padding, stride=stride, - groups=C_in, bias=False) - self.pointwise = nn.Conv2d(C_in, C_out, kernel_size=1, bias=False) - - def forward(self, x): - out = self.depthwise(x) - out = self.pointwise(out) - return out - - -class ConvBranch(nn.Module): - def __init__(self, C_in, C_out, kernel_size, stride, padding, separable): - super(ConvBranch, self).__init__() - self.preproc = StdConv(C_in, C_out) - if separable: - self.conv = SeparableConv(C_out, C_out, kernel_size, stride, padding) - else: - self.conv = nn.Conv2d(C_out, C_out, kernel_size, stride=stride, padding=padding) - self.postproc = nn.Sequential( - nn.BatchNorm2d(C_out, affine=False), - nn.ReLU() - ) - - def forward(self, x): - out = self.preproc(x) - out = self.conv(out) - out = self.postproc(out) - return out - - -class FactorizedReduce(nn.Module): - def __init__(self, C_in, C_out, affine=False): - super().__init__() - self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) - self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) - self.bn = nn.BatchNorm2d(C_out, affine=affine) - - def forward(self, x): - out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1) - out = self.bn(out) - return out diff --git a/examples/nas/enas/macro.py b/examples/nas/enas/macro.py index 8d2ca21522..48fcaaf03d 100644 --- a/examples/nas/enas/macro.py +++ b/examples/nas/enas/macro.py @@ -1,16 +1,13 @@ -from argparse import ArgumentParser -import torch import torch.nn as nn -import datasets +from nni.nas.pytorch import mutables from ops import FactorizedReduce, ConvBranch, PoolBranch -from nni.nas.pytorch import mutables, enas -class ENASLayer(nn.Module): +class ENASLayer(mutables.MutableScope): - def __init__(self, layer_id, in_filters, out_filters): - super().__init__() + def __init__(self, key, num_prev_layers, in_filters, out_filters): + super().__init__(key) self.in_filters = in_filters self.out_filters = out_filters self.mutable = mutables.LayerChoice([ @@ -21,22 +18,19 @@ def __init__(self, layer_id, in_filters, out_filters): PoolBranch('avg', in_filters, out_filters, 3, 1, 1), PoolBranch('max', in_filters, out_filters, 3, 1, 1) ]) - if layer_id > 0: - self.skipconnect = mutables.InputChoice(layer_id, n_selected=None, reduction="sum") + if num_prev_layers > 0: + self.skipconnect = mutables.InputChoice(num_prev_layers, n_selected=None, reduction="sum") else: self.skipconnect = None self.batch_norm = nn.BatchNorm2d(out_filters, affine=False) - self.mutable_scope = mutables.MutableScope("layer_{}".format(layer_id)) - def forward(self, prev_layers): - with self.mutable_scope: - out = self.mutable(prev_layers[-1]) - if self.skipconnect is not None: - connection = self.skipconnect(prev_layers[:-1], - ["layer_{}".format(i) for i in range(len(prev_layers) - 1)]) - if connection is not None: - out += connection - return self.batch_norm(out) + def forward(self, prev_layers, prev_labels): + out = self.mutable(prev_layers[-1]) + if self.skipconnect is not None: + connection = self.skipconnect(prev_layers[:-1], tags=prev_labels) + if connection is not None: + out += connection + return self.batch_norm(out) class GeneralNetwork(nn.Module): @@ -62,7 +56,8 @@ def __init__(self, num_layers=12, out_filters=24, in_channels=3, num_classes=10, for layer_id in range(self.num_layers): if layer_id in self.pool_layers_idx: self.pool_layers.append(FactorizedReduce(self.out_filters, self.out_filters)) - self.layers.append(ENASLayer(layer_id, self.out_filters, self.out_filters)) + self.layers.append(ENASLayer("layer_{}".format(layer_id), layer_id, + self.out_filters, self.out_filters)) self.gap = nn.AdaptiveAvgPool2d(1) self.dense = nn.Linear(self.out_filters, self.num_classes) @@ -71,11 +66,12 @@ def forward(self, x): bs = x.size(0) cur = self.stem(x) - layers = [cur] + layers, labels = [cur], [] for layer_id in range(self.num_layers): - cur = self.layers[layer_id](layers) + cur = self.layers[layer_id](layers, labels) layers.append(cur) + labels.append(self.layers[layer_id].key) if layer_id in self.pool_layers_idx: for i, layer in enumerate(layers): layers[i] = self.pool_layers[self.pool_layers_idx.index(layer_id)](layer) @@ -85,58 +81,3 @@ def forward(self, x): cur = self.dropout(cur) logits = self.dense(cur) return logits - - -def accuracy(output, target, topk=(1,)): - """ Computes the precision@k for the specified values of k """ - maxk = max(topk) - batch_size = target.size(0) - - _, pred = output.topk(maxk, 1, True, True) - pred = pred.t() - # one-hot case - if target.ndimension() > 1: - target = target.max(1)[1] - - correct = pred.eq(target.view(1, -1).expand_as(pred)) - - res = dict() - for k in topk: - correct_k = correct[:k].view(-1).float().sum(0) - res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item() - return res - - -def reward_accuracy(output, target, topk=(1,)): - batch_size = target.size(0) - _, predicted = torch.max(output.data, 1) - return (predicted == target).sum().item() / batch_size - - -if __name__ == "__main__": - parser = ArgumentParser("enas") - parser.add_argument("--batch-size", default=3, type=int) - parser.add_argument("--log-frequency", default=1, type=int) - args = parser.parse_args() - - dataset_train, dataset_valid = datasets.get_dataset("cifar10") - - model = GeneralNetwork() - criterion = nn.CrossEntropyLoss() - - n_epochs = 310 - optim = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4) - lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=n_epochs, eta_min=0.001) - - trainer = enas.EnasTrainer(model, - loss=criterion, - metrics=accuracy, - reward_function=reward_accuracy, - optimizer=optim, - lr_scheduler=lr_scheduler, - batch_size=args.batch_size, - num_epochs=n_epochs, - dataset_train=dataset_train, - dataset_valid=dataset_valid, - log_frequency=args.log_frequency) - trainer.train() diff --git a/examples/nas/enas/micro.py b/examples/nas/enas/micro.py new file mode 100644 index 0000000000..209abf2405 --- /dev/null +++ b/examples/nas/enas/micro.py @@ -0,0 +1,183 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from nni.nas.pytorch import mutables +from ops import FactorizedReduce, StdConv, SepConvBN, Pool + + +class AuxiliaryHead(nn.Module): + def __init__(self, in_channels, num_classes): + super().__init__() + self.in_channels = in_channels + self.num_classes = num_classes + self.pooling = nn.Sequential( + nn.ReLU(), + nn.AvgPool2d(5, 3, 2) + ) + self.proj = nn.Sequential( + StdConv(in_channels, 128), + StdConv(128, 768) + ) + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Linear(768, 10, bias=False) + + def forward(self, x): + bs = x.size(0) + x = self.pooling(x) + x = self.proj(x) + x = self.avg_pool(x).view(bs, -1) + x = self.fc(x) + return x + + +class Cell(nn.Module): + def __init__(self, cell_name, num_prev_layers, channels): + super().__init__() + self.input_choice = mutables.InputChoice(num_prev_layers, n_selected=1, return_mask=True, + key=cell_name + "_input") + self.op_choice = mutables.LayerChoice([ + SepConvBN(channels, channels, 3, 1), + SepConvBN(channels, channels, 5, 2), + Pool("avg", 3, 1, 1), + Pool("max", 3, 1, 1), + nn.Identity() + ], key=cell_name + "_op") + + def forward(self, prev_layers, prev_labels): + chosen_input, chosen_mask = self.input_choice(prev_layers, tags=prev_labels) + cell_out = self.op_choice(chosen_input) + return cell_out, chosen_mask + + +class Node(mutables.MutableScope): + def __init__(self, node_name, num_prev_layers, channels): + super().__init__(node_name) + self.cell_x = Cell(node_name + "_x", num_prev_layers, channels) + self.cell_y = Cell(node_name + "_y", num_prev_layers, channels) + + def forward(self, prev_layers, prev_labels): + out_x, mask_x = self.cell_x(prev_layers, prev_labels) + out_y, mask_y = self.cell_y(prev_layers, prev_labels) + return out_x + out_y, mask_x | mask_y + + +class Calibration(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.process = None + if in_channels != out_channels: + self.process = StdConv(in_channels, out_channels) + + def forward(self, x): + if self.process is None: + return x + return self.process(x) + + +class ReductionLayer(nn.Module): + def __init__(self, in_channels_pp, in_channels_p, out_channels): + super().__init__() + self.reduce0 = FactorizedReduce(in_channels_pp, out_channels, affine=False) + self.reduce1 = FactorizedReduce(in_channels_p, out_channels, affine=False) + + def forward(self, pprev, prev): + return self.reduce0(pprev), self.reduce1(prev) + + +class ENASLayer(nn.Module): + def __init__(self, num_nodes, in_channels_pp, in_channels_p, out_channels, reduction): + super().__init__() + self.preproc0 = Calibration(in_channels_pp, out_channels) + self.preproc1 = Calibration(in_channels_p, out_channels) + + self.num_nodes = num_nodes + name_prefix = "reduce" if reduction else "normal" + self.nodes = nn.ModuleList([Node("{}_node_{}".format(name_prefix, i), + i + 2, out_channels) for i in range(num_nodes)]) + self.final_conv_w = nn.Parameter(torch.zeros(out_channels, self.num_nodes + 2, out_channels, 1, 1), requires_grad=True) + self.bn = nn.BatchNorm2d(out_channels, affine=False) + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_normal_(self.final_conv_w) + + def forward(self, pprev, prev): + pprev_, prev_ = self.preproc0(pprev), self.preproc1(prev) + + prev_nodes_out = [pprev_, prev_] + prev_nodes_labels = ["prev1", "prev2"] + nodes_used_mask = torch.zeros(self.num_nodes + 2, dtype=torch.bool, device=prev.device) + for i in range(self.num_nodes): + node_out, mask = self.nodes[i](prev_nodes_out, prev_nodes_labels) + nodes_used_mask[:mask.size(0)] |= mask + prev_nodes_out.append(node_out) + prev_nodes_labels.append(self.nodes[i].key) + + unused_nodes = torch.cat([out for used, out in zip(nodes_used_mask, prev_nodes_out) if not used], 1) + unused_nodes = F.relu(unused_nodes) + conv_weight = self.final_conv_w[:, ~nodes_used_mask, :, :, :] + conv_weight = conv_weight.view(conv_weight.size(0), -1, 1, 1) + out = F.conv2d(unused_nodes, conv_weight) + return prev, self.bn(out) + + +class MicroNetwork(nn.Module): + def __init__(self, num_layers=2, num_nodes=5, out_channels=24, in_channels=3, num_classes=10, + dropout_rate=0.0, use_aux_heads=False): + super().__init__() + self.num_layers = num_layers + self.use_aux_heads = use_aux_heads + + self.stem = nn.Sequential( + nn.Conv2d(in_channels, out_channels * 3, 3, 1, 1, bias=False), + nn.BatchNorm2d(out_channels * 3) + ) + + pool_distance = self.num_layers // 3 + pool_layers = [pool_distance, 2 * pool_distance + 1] + self.dropout = nn.Dropout(dropout_rate) + + self.layers = nn.ModuleList() + c_pp = c_p = out_channels * 3 + c_cur = out_channels + for layer_id in range(self.num_layers + 2): + reduction = False + if layer_id in pool_layers: + c_cur, reduction = c_p * 2, True + self.layers.append(ReductionLayer(c_pp, c_p, c_cur)) + c_pp = c_p = c_cur + self.layers.append(ENASLayer(num_nodes, c_pp, c_p, c_cur, reduction)) + if self.use_aux_heads and layer_id == pool_layers[-1] + 1: + self.layers.append(AuxiliaryHead(c_cur, num_classes)) + c_pp, c_p = c_p, c_cur + + self.gap = nn.AdaptiveAvgPool2d(1) + self.dense = nn.Linear(c_cur, num_classes) + + self.reset_parameters() + + def reset_parameters(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight) + + def forward(self, x): + bs = x.size(0) + prev = cur = self.stem(x) + aux_logits = None + + for layer in self.layers: + if isinstance(layer, AuxiliaryHead): + if self.training: + aux_logits = layer(cur) + else: + prev, cur = layer(prev, cur) + + cur = self.gap(F.relu(cur)).view(bs, -1) + cur = self.dropout(cur) + logits = self.dense(cur) + + if aux_logits is not None: + return logits, aux_logits + return logits diff --git a/examples/nas/enas/ops.py b/examples/nas/enas/ops.py index 2df9088321..2b9df8069b 100644 --- a/examples/nas/enas/ops.py +++ b/examples/nas/enas/ops.py @@ -19,12 +19,7 @@ class PoolBranch(nn.Module): def __init__(self, pool_type, C_in, C_out, kernel_size, stride, padding, affine=False): super().__init__() self.preproc = StdConv(C_in, C_out) - if pool_type.lower() == 'max': - self.pool = nn.MaxPool2d(kernel_size, stride, padding) - elif pool_type.lower() == 'avg': - self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False) - else: - raise ValueError() + self.pool = Pool(pool_type, kernel_size, stride, padding) self.bn = nn.BatchNorm2d(C_out, affine=affine) def forward(self, x): @@ -78,3 +73,31 @@ def forward(self, x): out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1) out = self.bn(out) return out + + +class Pool(nn.Module): + def __init__(self, pool_type, kernel_size, stride, padding): + super().__init__() + if pool_type.lower() == 'max': + self.pool = nn.MaxPool2d(kernel_size, stride, padding) + elif pool_type.lower() == 'avg': + self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False) + else: + raise ValueError() + + def forward(self, x): + return self.pool(x) + + +class SepConvBN(nn.Module): + def __init__(self, C_in, C_out, kernel_size, padding): + super().__init__() + self.relu = nn.ReLU() + self.conv = SeparableConv(C_in, C_out, kernel_size, 1, padding) + self.bn = nn.BatchNorm2d(C_out, affine=True) + + def forward(self, x): + x = self.relu(x) + x = self.conv(x) + x = self.bn(x) + return x diff --git a/examples/nas/enas/search.py b/examples/nas/enas/search.py new file mode 100644 index 0000000000..6e1bdec34c --- /dev/null +++ b/examples/nas/enas/search.py @@ -0,0 +1,47 @@ +from argparse import ArgumentParser + +import torch +import torch.nn as nn + +import datasets +from macro import GeneralNetwork +from micro import MicroNetwork +from nni.nas.pytorch import enas +from nni.nas.pytorch.callbacks import LearningRateScheduler, ArchitectureCheckpoint +from utils import accuracy, reward_accuracy + +if __name__ == "__main__": + parser = ArgumentParser("enas") + parser.add_argument("--batch-size", default=128, type=int) + parser.add_argument("--log-frequency", default=1, type=int) + parser.add_argument("--search-for", choices=["macro", "micro"], default="macro") + args = parser.parse_args() + + dataset_train, dataset_valid = datasets.get_dataset("cifar10") + if args.search_for == "macro": + model = GeneralNetwork() + num_epochs = 310 + mutator = None + elif args.search_for == "micro": + model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1, use_aux_heads=True) + num_epochs = 150 + mutator = enas.EnasMutator(model, tanh_constant=1.1, cell_exit_extra_step=True) + else: + raise AssertionError + + criterion = nn.CrossEntropyLoss() + optimizer = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4) + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0.001) + + trainer = enas.EnasTrainer(model, + loss=criterion, + metrics=accuracy, + reward_function=reward_accuracy, + optimizer=optimizer, + callbacks=[LearningRateScheduler(lr_scheduler), ArchitectureCheckpoint("./checkpoints")], + batch_size=args.batch_size, + num_epochs=num_epochs, + dataset_train=dataset_train, + dataset_valid=dataset_valid, + log_frequency=args.log_frequency) + trainer.train_and_validate() diff --git a/examples/nas/enas/utils.py b/examples/nas/enas/utils.py new file mode 100644 index 0000000000..22bc62819f --- /dev/null +++ b/examples/nas/enas/utils.py @@ -0,0 +1,27 @@ +import torch + + +def accuracy(output, target, topk=(1,)): + """ Computes the precision@k for the specified values of k """ + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + # one-hot case + if target.ndimension() > 1: + target = target.max(1)[1] + + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = dict() + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item() + return res + + +def reward_accuracy(output, target, topk=(1,)): + batch_size = target.size(0) + _, predicted = torch.max(output.data, 1) + return (predicted == target).sum().item() / batch_size diff --git a/src/__init__.py b/src/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/sdk/pynni/nni/nas/pytorch/base_mutator.py b/src/sdk/pynni/nni/nas/pytorch/base_mutator.py new file mode 100644 index 0000000000..dd2b844d24 --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/base_mutator.py @@ -0,0 +1,70 @@ +import logging + +import torch.nn as nn + +from nni.nas.pytorch.mutables import Mutable + +logger = logging.getLogger(__name__) + + +class BaseMutator(nn.Module): + def __init__(self, model): + super().__init__() + self.__dict__["model"] = model + self.before_parse_search_space() + self._parse_search_space() + self.after_parse_search_space() + + def before_parse_search_space(self): + pass + + def after_parse_search_space(self): + pass + + def _parse_search_space(self): + for name, mutable, _ in self.named_mutables(distinct=False): + mutable.name = name + mutable.set_mutator(self) + + def named_mutables(self, root=None, distinct=True): + if root is None: + root = self.model + # if distinct is true, the method will filter out those with duplicated keys + key2module = dict() + for name, module in root.named_modules(): + if isinstance(module, Mutable): + module_distinct = False + if module.key in key2module: + assert key2module[module.key].similar(module), \ + "Mutable \"{}\" that share the same key must be similar to each other".format(module.key) + else: + module_distinct = True + key2module[module.key] = module + if distinct: + if module_distinct: + yield name, module + else: + yield name, module, module_distinct + + def __setattr__(self, key, value): + if key in ["model", "net", "network"]: + logger.warning("Think twice if you are including the network into mutator.") + return super().__setattr__(key, value) + + def forward(self, *inputs): + raise NotImplementedError("Mutator is not forward-able") + + def enter_mutable_scope(self, mutable_scope): + pass + + def exit_mutable_scope(self, mutable_scope): + pass + + def on_forward_layer_choice(self, mutable, *inputs): + raise NotImplementedError + + def on_forward_input_choice(self, mutable, tensor_list, tags): + raise NotImplementedError + + def export(self): + raise NotImplementedError diff --git a/src/sdk/pynni/nni/nas/pytorch/base_trainer.py b/src/sdk/pynni/nni/nas/pytorch/base_trainer.py new file mode 100644 index 0000000000..1248cc09e2 --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/base_trainer.py @@ -0,0 +1,16 @@ +from abc import ABC, abstractmethod + + +class BaseTrainer(ABC): + + @abstractmethod + def train(self): + raise NotImplementedError + + @abstractmethod + def validate(self): + raise NotImplementedError + + @abstractmethod + def train_and_validate(self): + raise NotImplementedError diff --git a/src/sdk/pynni/nni/nas/pytorch/callbacks.py b/src/sdk/pynni/nni/nas/pytorch/callbacks.py new file mode 100644 index 0000000000..2a76b3dab8 --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/callbacks.py @@ -0,0 +1,69 @@ +import json +import logging +import os + +import torch + +_logger = logging.getLogger(__name__) + + +class Callback: + + def __init__(self): + self.model = None + self.mutator = None + self.trainer = None + + def build(self, model, mutator, trainer): + self.model = model + self.mutator = mutator + self.trainer = trainer + + def on_epoch_begin(self, epoch): + pass + + def on_epoch_end(self, epoch): + pass + + def on_batch_begin(self, epoch): + pass + + def on_batch_end(self, epoch): + pass + + +class LearningRateScheduler(Callback): + def __init__(self, scheduler, mode="epoch"): + super().__init__() + assert mode == "epoch" + self.scheduler = scheduler + self.mode = mode + + def on_epoch_end(self, epoch): + self.scheduler.step() + + +class ArchitectureCheckpoint(Callback): + class TorchTensorEncoder(json.JSONEncoder): + def default(self, o): # pylint: disable=method-hidden + if isinstance(o, torch.Tensor): + olist = o.tolist() + if "bool" not in o.type().lower() and all(map(lambda d: d == 0 or d == 1, olist)): + _logger.warning("Every element in %s is either 0 or 1. " + "You might consider convert it into bool.", olist) + return olist + return super().default(o) + + def __init__(self, checkpoint_dir, every="epoch"): + super().__init__() + assert every == "epoch" + self.checkpoint_dir = checkpoint_dir + os.makedirs(self.checkpoint_dir, exist_ok=True) + + def _export_to_file(self, file): + mutator_export = self.mutator.export() + with open(file, "w") as f: + json.dump(mutator_export, f, indent=2, sort_keys=True, cls=self.TorchTensorEncoder) + + def on_epoch_end(self, epoch): + self._export_to_file(os.path.join(self.checkpoint_dir, "epoch_{}.json".format(epoch))) diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py b/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py index f28d2cd73c..7f2c9f9675 100644 --- a/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py +++ b/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py @@ -1,4 +1,3 @@ from .mutator import DartsMutator from .trainer import DartsTrainer -from .cnn_cell import CnnCell -from .cnn_network import CnnNetwork +from .scope import DartsNode \ No newline at end of file diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py b/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py index ef5dcec806..589847d2b6 100644 --- a/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py +++ b/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py @@ -3,16 +3,34 @@ from torch.nn import functional as F from nni.nas.pytorch.mutables import LayerChoice -from nni.nas.pytorch.mutator import PyTorchMutator +from nni.nas.pytorch.mutator import Mutator +from .scope import DartsNode -class DartsMutator(PyTorchMutator): +class DartsMutator(Mutator): - def before_build(self, model): + def after_parse_search_space(self): self.choices = nn.ParameterDict() - - def on_init_layer_choice(self, mutable: LayerChoice): - self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length)) + for _, mutable in self.named_mutables(): + if isinstance(mutable, LayerChoice): + self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(len(mutable) + 1)) def on_calc_layer_choice_mask(self, mutable: LayerChoice): - return F.softmax(self.choices[mutable.key], dim=-1) + return F.softmax(self.choices[mutable.key], dim=-1)[:-1] + + def export(self): + result = super().export() + for _, darts_node in self.named_mutables(): + if isinstance(darts_node, DartsNode): + keys, edges_max = [], [] # key of all the layer choices in current node, and their best edge weight + for _, choice in self.named_mutables(darts_node): + if isinstance(choice, LayerChoice): + keys.append(choice.key) + max_val, index = torch.max(result[choice.key], 0) + edges_max.append(max_val) + result[choice.key] = F.one_hot(index, num_classes=len(result[choice.key])).view(-1).bool() + _, topk_edge_indices = torch.topk(torch.tensor(edges_max).view(-1), darts_node.limitation) # pylint: disable=not-callable + for i, key in enumerate(keys): + if i not in topk_edge_indices: + result[key] = torch.zeros_like(result[key]) + return result diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/scope.py b/src/sdk/pynni/nni/nas/pytorch/darts/scope.py new file mode 100644 index 0000000000..a2bf2b3cff --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/darts/scope.py @@ -0,0 +1,11 @@ +from nni.nas.pytorch.mutables import MutableScope + + +class DartsNode(MutableScope): + """ + At most `limitation` choice is activated in a `DartsNode` when exporting. + """ + + def __init__(self, key, limitation): + super().__init__(key) + self.limitation = limitation diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py b/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py index 75463ff23f..464832eadf 100644 --- a/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py @@ -4,32 +4,18 @@ from torch import nn as nn from nni.nas.pytorch.trainer import Trainer -from nni.nas.utils import AverageMeterGroup, auto_device +from nni.nas.utils import AverageMeterGroup from .mutator import DartsMutator class DartsTrainer(Trainer): def __init__(self, model, loss, metrics, - model_optim, lr_scheduler, num_epochs, dataset_train, dataset_valid, - mutator=None, batch_size=64, workers=4, device=None, log_frequency=None): - self.model = model - self.loss = loss - self.metrics = metrics - self.mutator = mutator - if self.mutator is None: - self.mutator = DartsMutator(model) - self.model_optim = model_optim - self.lr_scheduler = lr_scheduler - self.num_epochs = num_epochs - self.dataset_train = dataset_train - self.dataset_valid = dataset_valid - self.device = auto_device() if device is None else device - self.log_frequency = log_frequency - - self.model.to(self.device) - self.loss.to(self.device) - self.mutator.to(self.device) - + optimizer, num_epochs, dataset_train, dataset_valid, + mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, + callbacks=None): + super().__init__(model, loss, metrics, optimizer, num_epochs, + dataset_train, dataset_valid, batch_size, workers, device, log_frequency, + mutator if mutator is not None else DartsMutator(model), callbacks) self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), 3.0E-4, betas=(0.5, 0.999), weight_decay=1.0E-3) n_train = len(self.dataset_train) @@ -46,10 +32,10 @@ def __init__(self, model, loss, metrics, sampler=valid_sampler, num_workers=workers) - def train_epoch(self, epoch): + def train_one_epoch(self, epoch): self.model.train() self.mutator.train() - lr = self.lr_scheduler.get_lr()[0] + lr = self.optimizer.param_groups[0]["lr"] meters = AverageMeterGroup() for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)): trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device) @@ -60,14 +46,14 @@ def train_epoch(self, epoch): # cannot deepcopy model because it will break the reference # phase 1. child network step - self.model_optim.zero_grad() + self.optimizer.zero_grad() with self.mutator.forward_pass(): logits = self.model(trn_X) loss = self.loss(logits, trn_y) loss.backward() # gradient clipping nn.utils.clip_grad_norm_(self.model.parameters(), 5.) - self.model_optim.step() + self.optimizer.step() new_model = copy.deepcopy(self.model.state_dict()) @@ -83,11 +69,9 @@ def train_epoch(self, epoch): metrics["loss"] = loss.item() meters.update(metrics) if self.log_frequency is not None and step % self.log_frequency == 0: - print("Epoch {} Step [{}/{}] {}".format(epoch, step, len(self.train_loader), meters)) - - self.lr_scheduler.step() + print("Epoch [{}/{}] Step [{}/{}] {}".format(epoch, self.num_epochs, step, len(self.train_loader), meters)) - def validate_epoch(self, epoch): + def validate_one_epoch(self, epoch): self.model.eval() self.mutator.eval() meters = AverageMeterGroup() @@ -99,17 +83,7 @@ def validate_epoch(self, epoch): metrics = self.metrics(logits, y) meters.update(metrics) if self.log_frequency is not None and step % self.log_frequency == 0: - print("Epoch {} Step [{}/{}] {}".format(epoch, step, len(self.valid_loader), meters)) - - def train(self): - for epoch in range(self.num_epochs): - # training - print("Epoch {} Training".format(epoch)) - self.train_epoch(epoch) - - # validation - print("Epoch {} Validating".format(epoch)) - self.validate_epoch(epoch) + print("Epoch [{}/{}] Step [{}/{}] {}".format(epoch, self.num_epochs, step, len(self.valid_loader), meters)) def _unrolled_backward(self, trn_X, trn_y, val_X, val_y, backup_model, lr): """ @@ -160,6 +134,3 @@ def _compute_hessian(self, model, dw, trn_X, trn_y): hessian = [(p - n) / 2. * eps for p, n in zip(dalpha_pos, dalpha_neg)] return hessian - - def export(self): - pass diff --git a/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py b/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py index a158886233..3bd32459b4 100644 --- a/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py +++ b/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py @@ -2,7 +2,8 @@ import torch.nn as nn import torch.nn.functional as F -from nni.nas.pytorch.mutator import PyTorchMutator +from nni.nas.pytorch.mutables import LayerChoice +from nni.nas.pytorch.mutator import Mutator class StackedLSTMCell(nn.Module): @@ -23,35 +24,49 @@ def forward(self, inputs, hidden): return next_c, next_h -class EnasMutator(PyTorchMutator): - def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, anchor_extra_step=False, - skip_target=0.4): +class EnasMutator(Mutator): + def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False, + skip_target=0.4, branch_bias=0.25): self.lstm_size = lstm_size self.lstm_num_layers = lstm_num_layers self.tanh_constant = tanh_constant - self.max_layer_choice = 0 - self.anchor_extra_step = anchor_extra_step + self.cell_exit_extra_step = cell_exit_extra_step self.skip_target = skip_target + self.branch_bias = branch_bias super().__init__(model) - def before_build(self, model): + def before_parse_search_space(self): self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False) self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False) self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False) self.v_attn = nn.Linear(self.lstm_size, 1, bias=False) self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1) - self.skip_targets = nn.Parameter(torch.Tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False) - self.cross_entropy_loss = nn.CrossEntropyLoss() + self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False) # pylint: disable=not-callable + self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none") + self.bias_dict = nn.ParameterDict() + + def after_parse_search_space(self): + self.max_layer_choice = 0 + for _, mutable in self.named_mutables(): + if isinstance(mutable, LayerChoice): + if self.max_layer_choice == 0: + self.max_layer_choice = mutable.length + assert self.max_layer_choice == mutable.length, \ + "ENAS mutator requires all layer choice have the same number of candidates." + # NOTE(yuge): We might implement an interface later. Judging by key now. + if "reduce" in mutable.key: + def is_conv(choice): + return "conv" in str(type(choice)).lower() + bias = torch.tensor([self.branch_bias if is_conv(choice) else -self.branch_bias # pylint: disable=not-callable + for choice in mutable.choices]) + self.bias_dict[mutable.key] = nn.Parameter(bias, requires_grad=False) - def after_build(self, model): self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size) - self.soft = nn.Linear(self.lstm_size, self.max_layer_choice) + self.soft = nn.Linear(self.lstm_size, self.max_layer_choice, bias=False) def before_pass(self): super().before_pass() self._anchors_hid = dict() - self._selected_layers = [] - self._selected_inputs = [] self._inputs = self.g_emb.data self._c = [torch.zeros((1, self.lstm_size), dtype=self._inputs.dtype, @@ -69,58 +84,58 @@ def _lstm_next_step(self): def _mark_anchor(self, key): self._anchors_hid[key] = self._h[-1] - def on_init_layer_choice(self, mutable): - if self.max_layer_choice == 0: - self.max_layer_choice = mutable.length - assert self.max_layer_choice == mutable.length, \ - "ENAS mutator requires all layer choice have the same number of candidates." - def on_calc_layer_choice_mask(self, mutable): self._lstm_next_step() logit = self.soft(self._h[-1]) if self.tanh_constant is not None: logit = self.tanh_constant * torch.tanh(logit) + if mutable.key in self.bias_dict: + logit += self.bias_dict[mutable.key] branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1) log_prob = self.cross_entropy_loss(logit, branch_id) - self.sample_log_prob += log_prob + self.sample_log_prob += torch.sum(log_prob) entropy = (log_prob * torch.exp(-log_prob)).detach() - self.sample_entropy += entropy + self.sample_entropy += torch.sum(entropy) self._inputs = self.embedding(branch_id) - self._selected_layers.append(branch_id.item()) - return F.one_hot(branch_id).bool().view(-1) + return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1) + + def on_calc_input_choice_mask(self, mutable, tags): + query, anchors = [], [] + for label in tags: + if label not in self._anchors_hid: + self._lstm_next_step() + self._mark_anchor(label) # empty loop, fill not found + query.append(self.attn_anchor(self._anchors_hid[label])) + anchors.append(self._anchors_hid[label]) + query = torch.cat(query, 0) + query = torch.tanh(query + self.attn_query(self._h[-1])) + query = self.v_attn(query) + if self.tanh_constant is not None: + query = self.tanh_constant * torch.tanh(query) - def on_calc_input_choice_mask(self, mutable, semantic_labels): if mutable.n_selected is None: - query, anchors = [], [] - for label in semantic_labels: - if label not in self._anchors_hid: - self._lstm_next_step() - self._mark_anchor(label) # empty loop, fill not found - query.append(self.attn_anchor(self._anchors_hid[label])) - anchors.append(self._anchors_hid[label]) - query = torch.cat(query, 0) - query = torch.tanh(query + self.attn_query(self._h[-1])) - query = self.v_attn(query) logit = torch.cat([-query, query], 1) - if self.tanh_constant is not None: - logit = self.tanh_constant * torch.tanh(logit) skip = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1) skip_prob = torch.sigmoid(logit) kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets)) self.sample_skip_penalty += kl - log_prob = self.cross_entropy_loss(logit, skip) - self.sample_log_prob += torch.sum(log_prob) - entropy = (log_prob * torch.exp(-log_prob)).detach() - self.sample_entropy += torch.sum(entropy) - - self.inputs = torch.matmul(skip.float(), torch.cat(anchors, 0)) / (1. + torch.sum(skip)) - self._selected_inputs.append(skip) - return skip.bool() + self._inputs = (torch.matmul(skip.float(), torch.cat(anchors, 0)) / (1. + torch.sum(skip))).unsqueeze(0) else: assert mutable.n_selected == 1, "Input choice must select exactly one or any in ENAS." - raise NotImplementedError + logit = query.view(1, -1) + index = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1) + skip = F.one_hot(index).view(-1) + log_prob = self.cross_entropy_loss(logit, index) + self._inputs = anchors[index.item()] + + self.sample_log_prob += torch.sum(log_prob) + entropy = (log_prob * torch.exp(-log_prob)).detach() + self.sample_entropy += torch.sum(entropy) + return skip.bool() def exit_mutable_scope(self, mutable_scope): + if self.cell_exit_extra_step: + self._lstm_next_step() self._mark_anchor(mutable_scope.key) diff --git a/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py b/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py index 7bc24ad16f..7d3e493782 100644 --- a/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py @@ -2,39 +2,29 @@ import torch.optim as optim from nni.nas.pytorch.trainer import Trainer -from nni.nas.utils import AverageMeterGroup, auto_device +from nni.nas.utils import AverageMeterGroup from .mutator import EnasMutator class EnasTrainer(Trainer): def __init__(self, model, loss, metrics, reward_function, - optimizer, num_epochs, dataset_train, dataset_valid, lr_scheduler=None, - mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, + optimizer, num_epochs, dataset_train, dataset_valid, + mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None, entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999, - mutator_lr=0.00035): - self.model = model - self.loss = loss - self.metrics = metrics + mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4): + super().__init__(model, loss, metrics, optimizer, num_epochs, + dataset_train, dataset_valid, batch_size, workers, device, log_frequency, + mutator if mutator is not None else EnasMutator(model), callbacks) self.reward_function = reward_function - self.mutator = mutator - if self.mutator is None: - self.mutator = EnasMutator(model) - self.optim = optimizer - self.mut_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr) - self.lr_scheduler = lr_scheduler - self.num_epochs = num_epochs - self.dataset_train = dataset_train - self.dataset_valid = dataset_valid - self.device = auto_device() if device is None else device - self.log_frequency = log_frequency + self.mutator_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr) + self.entropy_weight = entropy_weight self.skip_weight = skip_weight self.baseline_decay = baseline_decay self.baseline = 0. - - self.model.to(self.device) - self.loss.to(self.device) - self.mutator.to(self.device) + self.mutator_steps_aggregate = mutator_steps_aggregate + self.mutator_steps = mutator_steps + self.aux_weight = aux_weight n_train = len(self.dataset_train) split = n_train // 10 @@ -53,68 +43,76 @@ def __init__(self, model, loss, metrics, reward_function, batch_size=batch_size, num_workers=workers) - def train_epoch(self, epoch): + def train_one_epoch(self, epoch): + # Sample model and train self.model.train() - self.mutator.train() + self.mutator.eval() + meters = AverageMeterGroup() + for step, (x, y) in enumerate(self.train_loader): + x, y = x.to(self.device), y.to(self.device) + self.optimizer.zero_grad() + + with self.mutator.forward_pass(): + logits = self.model(x) - for phase in ["model", "mutator"]: - if phase == "model": - self.model.train() - self.mutator.eval() + if isinstance(logits, tuple): + logits, aux_logits = logits + aux_loss = self.loss(aux_logits, y) else: - self.model.eval() - self.mutator.train() - loader = self.train_loader if phase == "model" else self.valid_loader - meters = AverageMeterGroup() - for step, (x, y) in enumerate(loader): + aux_loss = 0. + metrics = self.metrics(logits, y) + loss = self.loss(logits, y) + loss = loss + self.aux_weight * aux_loss + loss.backward() + self.optimizer.step() + metrics["loss"] = loss.item() + meters.update(metrics) + + if self.log_frequency is not None and step % self.log_frequency == 0: + print("Model Epoch [{}/{}] Step [{}/{}] {}".format(epoch, self.num_epochs, + step, len(self.train_loader), meters)) + + # Train sampler (mutator) + self.model.eval() + self.mutator.train() + meters = AverageMeterGroup() + mutator_step, total_mutator_steps = 0, self.mutator_steps * self.mutator_steps_aggregate + while mutator_step < total_mutator_steps: + for step, (x, y) in enumerate(self.valid_loader): x, y = x.to(self.device), y.to(self.device) - self.optim.zero_grad() - self.mut_optim.zero_grad() with self.mutator.forward_pass(): logits = self.model(x) metrics = self.metrics(logits, y) - - if phase == "model": - loss = self.loss(logits, y) - loss.backward() - self.optim.step() - else: - reward = self.reward_function(logits, y) - if self.entropy_weight is not None: - reward += self.entropy_weight * self.mutator.sample_entropy - self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay) - self.baseline = self.baseline.detach().item() - loss = self.mutator.sample_log_prob * (reward - self.baseline) - if self.skip_weight: - loss += self.skip_weight * self.mutator.sample_skip_penalty - loss.backward() - self.mut_optim.step() - metrics["reward"] = reward + reward = self.reward_function(logits, y) + if self.entropy_weight is not None: + reward += self.entropy_weight * self.mutator.sample_entropy + self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay) + self.baseline = self.baseline.detach().item() + loss = self.mutator.sample_log_prob * (reward - self.baseline) + if self.skip_weight: + loss += self.skip_weight * self.mutator.sample_skip_penalty + metrics["reward"] = reward metrics["loss"] = loss.item() - meters.update(metrics) - - if self.log_frequency is not None and step % self.log_frequency == 0: - print("Epoch {} {} Step [{}/{}] {}".format(epoch, phase.capitalize(), step, - len(loader), meters)) - # print(self.mutator._selected_layers) - # print(self.mutator._selected_inputs) - - if self.lr_scheduler is not None: - self.lr_scheduler.step() + metrics["ent"] = self.mutator.sample_entropy.item() + metrics["baseline"] = self.baseline + metrics["skip"] = self.mutator.sample_skip_penalty - def validate_epoch(self, epoch): - pass + loss = loss / self.mutator_steps_aggregate + loss.backward() + meters.update(metrics) - def train(self): - for epoch in range(self.num_epochs): - # training - print("Epoch {} Training".format(epoch)) - self.train_epoch(epoch) + if mutator_step % self.mutator_steps_aggregate == 0: + self.mutator_optim.step() + self.mutator_optim.zero_grad() - # validation - print("Epoch {} Validating".format(epoch)) - self.validate_epoch(epoch) + if self.log_frequency is not None and step % self.log_frequency == 0: + print("Mutator Epoch [{}/{}] Step [{}/{}] {}".format(epoch, self.num_epochs, + mutator_step // self.mutator_steps_aggregate, + self.mutator_steps, meters)) + mutator_step += 1 + if mutator_step >= total_mutator_steps: + break - def export(self): + def validate_one_epoch(self, epoch): pass diff --git a/src/sdk/pynni/nni/nas/pytorch/fixed.py b/src/sdk/pynni/nni/nas/pytorch/fixed.py new file mode 100644 index 0000000000..526d66b610 --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/fixed.py @@ -0,0 +1,58 @@ +import json + +import torch + +from nni.nas.pytorch.mutator import Mutator + + +class FixedArchitecture(Mutator): + def __init__(self, model, fixed_arc, strict=True): + """ + Initialize a fixed architecture mutator. + + Parameters + ---------- + model: nn.Module + A mutable network. + fixed_arc: str or dict + Path to the architecture checkpoint (a string), or preloaded architecture object (a dict). + strict: bool + Force everything that appears in `fixed_arc` to be used at least once. + """ + super().__init__(model) + if isinstance(fixed_arc, str): + with open(fixed_arc, "r") as f: + fixed_arc = json.load(f.read()) + self._fixed_arc = fixed_arc + self._strict = strict + + def _encode_tensor(self, data): + if isinstance(data, list): + if all(map(lambda o: isinstance(o, bool), data)): + return torch.tensor(data, dtype=torch.bool) # pylint: disable=not-callable + else: + return torch.tensor(data, dtype=torch.float) # pylint: disable=not-callable + if isinstance(data, dict): + return {k: self._encode_tensor(v) for k, v in data.items()} + return data + + def before_pass(self): + self._unused_key = set(self._fixed_arc.keys()) + + def after_pass(self): + if self._strict: + if self._unused_key: + raise ValueError("{} are never used by the network. " + "Set strict=False if you want to disable this check.".format(self._unused_key)) + + def _check_key(self, key): + if key not in self._fixed_arc: + raise ValueError("\"{}\" is demanded by the network, but not found in saved architecture.".format(key)) + + def on_calc_layer_choice_mask(self, mutable): + self._check_key(mutable.key) + return self._fixed_arc[mutable.key] + + def on_calc_input_choice_mask(self, mutable, tags): + self._check_key(mutable.key) + return self._fixed_arc[mutable.key] diff --git a/src/sdk/pynni/nni/nas/pytorch/mutables.py b/src/sdk/pynni/nni/nas/pytorch/mutables.py index e28af84037..16b73b903d 100644 --- a/src/sdk/pynni/nni/nas/pytorch/mutables.py +++ b/src/sdk/pynni/nni/nas/pytorch/mutables.py @@ -3,7 +3,7 @@ from nni.nas.utils import global_mutable_counting -class PyTorchMutable(nn.Module): +class Mutable(nn.Module): """ Mutable is designed to function as a normal layer, with all necessary operators' weights. States and weights of architectures should be included in mutator, instead of the layer itself. @@ -24,15 +24,11 @@ def __init__(self, key=None): self._key = key else: self._key = self.__class__.__name__ + str(global_mutable_counting()) - self.name = self.key + self.init_hook = self.forward_hook = None def __deepcopy__(self, memodict=None): raise NotImplementedError("Deep copy doesn't work for mutables.") - def __enter__(self): - self._check_built() - return super().__enter__() - def __call__(self, *args, **kwargs): self._check_built() return super().__call__(*args, **kwargs) @@ -47,8 +43,16 @@ def forward(self, *inputs): def key(self): return self._key + @property + def name(self): + return self._name if hasattr(self, "_name") else "_key" + + @name.setter + def name(self, name): + self._name = name + def similar(self, other): - return self == other + return type(self) == type(other) def _check_built(self): if not hasattr(self, "mutator"): @@ -56,8 +60,11 @@ def _check_built(self): "Mutator not set for {}. Did you initialize a mutable on the fly in forward pass? Move to __init__" "so that trainer can locate all your mutables. See NNI docs for more details.".format(self)) + def __repr__(self): + return "{} ({})".format(self.name, self.key) -class MutableScope(PyTorchMutable): + +class MutableScope(Mutable): """ Mutable scope labels a subgraph to help mutators make better decisions. Mutators get notified when a mutable scope is entered and exited. Mutators can override ``enter_mutable_scope`` and ``exit_mutable_scope`` to catch @@ -67,14 +74,18 @@ class MutableScope(PyTorchMutable): def __init__(self, key): super().__init__(key=key) - def __enter__(self): - self.mutator.enter_mutable_scope(self) + def build(self): + self.mutator.on_init_mutable_scope(self) - def __exit__(self, exc_type, exc_val, exc_tb): - self.mutator.exit_mutable_scope(self) + def __call__(self, *args, **kwargs): + try: + self.mutator.enter_mutable_scope(self) + return super().__call__(*args, **kwargs) + finally: + self.mutator.exit_mutable_scope(self) -class LayerChoice(PyTorchMutable): +class LayerChoice(Mutable): def __init__(self, op_candidates, reduction="mean", return_mask=False, key=None): super().__init__(key=key) self.length = len(op_candidates) @@ -83,10 +94,10 @@ def __init__(self, op_candidates, reduction="mean", return_mask=False, key=None) self.return_mask = return_mask def __len__(self): - return self.length + return len(self.choices) def forward(self, *inputs): - out, mask = self.mutator.on_forward(self, *inputs) + out, mask = self.mutator.on_forward_layer_choice(self, *inputs) if self.return_mask: return out, mask return out @@ -95,7 +106,7 @@ def similar(self, other): return type(self) == type(other) and self.length == other.length -class InputChoice(PyTorchMutable): +class InputChoice(Mutable): def __init__(self, n_candidates, n_selected=None, reduction="mean", return_mask=False, key=None): super().__init__(key=key) assert n_candidates > 0, "Number of candidates must be greater than 0." @@ -104,16 +115,21 @@ def __init__(self, n_candidates, n_selected=None, reduction="mean", return_mask= self.reduction = reduction self.return_mask = return_mask - def forward(self, optional_inputs, semantic_labels=None): + def build(self): + self.mutator.on_init_input_choice(self) + + def forward(self, optional_inputs, tags=None): assert len(optional_inputs) == self.n_candidates, \ "Length of the input list must be equal to number of candidates." - if semantic_labels is None: - semantic_labels = ["default_label"] * self.n_candidates - out, mask = self.mutator.on_forward(self, optional_inputs, semantic_labels) + if tags is None: + tags = [""] * self.n_candidates + else: + assert len(tags) == self.n_candidates, "Length of tags must be equal to number of candidates." + out, mask = self.mutator.on_forward_input_choice(self, optional_inputs, tags) if self.return_mask: return out, mask return out def similar(self, other): return type(self) == type(other) and \ - self.n_candidates == other.n_candidates and self.n_selected and other.n_selected + self.n_candidates == other.n_candidates and self.n_selected and other.n_selected diff --git a/src/sdk/pynni/nni/nas/pytorch/mutator.py b/src/sdk/pynni/nni/nas/pytorch/mutator.py index 55d742e9af..21d39545e7 100644 --- a/src/sdk/pynni/nni/nas/pytorch/mutator.py +++ b/src/sdk/pynni/nni/nas/pytorch/mutator.py @@ -1,107 +1,48 @@ -import logging from contextlib import contextmanager import torch import torch.nn as nn -from nni.nas.pytorch.mutables import PyTorchMutable -from nni.nas.utils import to_snake_case +from nni.nas.pytorch.base_mutator import BaseMutator -logger = logging.getLogger(__name__) +class Mutator(BaseMutator, nn.Module): -class PyTorchMutator(nn.Module): - def __init__(self, model): - super().__init__() - self.before_build(model) - self.parse_search_space(model) - self.after_build(model) - - def before_build(self, model): - pass - - def after_build(self, model): - pass - - def named_mutables(self, model): - # if distinct is true, the method will filter out those with duplicated keys - key2module = dict() - for name, module in model.named_modules(): - if isinstance(module, PyTorchMutable): - distinct = False - if module.key in key2module: - assert key2module[module.key].similar(module), \ - "Mutable \"{}\" that share the same key must be similar to each other".format(module.key) - else: - distinct = True - key2module[module.key] = module - yield name, module, distinct - - def __setattr__(self, key, value): - if key in ["model", "net", "network"]: - logger.warning("Think twice if you are including the network into mutator.") - return super().__setattr__(key, value) - - def parse_search_space(self, model): - for name, mutable, distinct in self.named_mutables(model): - mutable.name = name - mutable.set_mutator(self) - if not distinct: - continue - init_method_name = "on_init_{}".format(to_snake_case(mutable.__class__.__name__)) - if hasattr(self, init_method_name) and callable(getattr(self, init_method_name)): - getattr(self, init_method_name)(mutable) - else: - # fallback to general init - self.on_init_general(mutable) - - def on_init_general(self, mutable): - pass + def export(self): + if self._in_forward_pass: + raise RuntimeError("Still in forward pass. Exporting might induce incompleteness.") + if not self._cache: + raise RuntimeError("No running history found. You need to call your model at least once before exporting. " + "You might also want to check if there are no valid mutables in your model.") + return self._cache @contextmanager def forward_pass(self): + self._in_forward_pass = True + self._cache = dict() self.before_pass() try: yield self finally: self.after_pass() + self._in_forward_pass = False def before_pass(self): - self._in_forward_pass = True - self._cache = dict() - - def after_pass(self): - self._in_forward_pass = False - - def enter_mutable_scope(self, mutable_scope): pass - def exit_mutable_scope(self, mutable_scope): + def after_pass(self): pass - def forward(self, *inputs): - raise NotImplementedError("Mutator is not forward-able") - - def on_forward(self, mutable, *inputs): - """Callback on forwarding a mutable""" + def _check_in_forward_pass(self): if not hasattr(self, "_in_forward_pass") or not self._in_forward_pass: raise ValueError("Not in forward pass. Did you forget to call mutator.forward_pass(), or forget to call " "super().before_pass() and after_pass() in your override method?") - forward_method_name = "on_forward_{}".format(to_snake_case(mutable.__class__.__name__)) - if hasattr(self, forward_method_name) and callable(getattr(self, forward_method_name)): - return getattr(self, forward_method_name)(mutable, *inputs) - else: - # fallback to general forward - return self.on_forward_general(mutable, *inputs) - - def on_forward_general(self, mutable, *inputs): - raise NotImplementedError("Forward has to be implemented") def on_forward_layer_choice(self, mutable, *inputs): """ Callback of layer choice forward. Override if you are an advanced user. On default, this method calls :meth:`on_calc_layer_choice_mask` to get a mask on how to choose between layers - (either by switch or by weights), then it will reduce the list of all tensor outputs with the policy speicified + (either by switch or by weights), then it will reduce the list of all tensor outputs with the policy specified in `mutable.reduction`. It will also cache the mask with corresponding `mutable.key`. Parameters @@ -111,33 +52,38 @@ def on_forward_layer_choice(self, mutable, *inputs): Returns ------- - torch.Tensor + tuple of torch.Tensor and torch.Tensor """ + self._check_in_forward_pass() + def _map_fn(op, *inputs): return op(*inputs) + mask = self._cache.setdefault(mutable.key, self.on_calc_layer_choice_mask(mutable)) out = self._select_with_mask(_map_fn, [(choice, *inputs) for choice in mutable.choices], mask) return self._tensor_reduction(mutable.reduction, out), mask - def on_forward_input_choice(self, mutable, tensor_list, semantic_labels): + def on_forward_input_choice(self, mutable, tensor_list, tags): """ Callback of input choice forward. Override if you are an advanced user. - On default, this method calls :meth:`on_calc_input_choice_mask` with `semantic_labels` + On default, this method calls :meth:`on_calc_input_choice_mask` with `tags` to get a mask on how to choose between inputs (either by switch or by weights), then it will reduce - the list of all tensor outputs with the policy speicified in `mutable.reduction`. It will also cache the + the list of all tensor outputs with the policy specified in `mutable.reduction`. It will also cache the mask with corresponding `mutable.key`. Parameters ---------- mutable: InputChoice - inputs: list of torch.Tensor + tensor_list: list of torch.Tensor + tags: list of string Returns ------- - torch.Tensor + tuple of torch.Tensor and torch.Tensor """ - mask = self._cache.setdefault(mutable.key, self.on_calc_input_choice_mask(mutable, semantic_labels)) - out = self._select_with_mask(lambda x: x, [(t, ) for t in tensor_list], mask) + self._check_in_forward_pass() + mask = self._cache.setdefault(mutable.key, self.on_calc_input_choice_mask(mutable, tags)) + out = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask) return self._tensor_reduction(mutable.reduction, out), mask def on_calc_layer_choice_mask(self, mutable): @@ -157,7 +103,7 @@ def on_calc_layer_choice_mask(self, mutable): """ raise NotImplementedError("Layer choice mask calculation must be implemented") - def on_calc_input_choice_mask(self, mutable, semantic_labels): + def on_calc_input_choice_mask(self, mutable, tags): """ Recommended to override. Calculate a mask tensor for a input choice. @@ -165,7 +111,7 @@ def on_calc_input_choice_mask(self, mutable, semantic_labels): ---------- mutable: InputChoice Corresponding input choice object. - semantic_labels: list of string + tags: list of string The name of labels of input tensors given by user. Usually it's a :class:`~nni.nas.pytorch.mutables.MutableScope` marked by user. @@ -179,7 +125,6 @@ def on_calc_input_choice_mask(self, mutable, semantic_labels): def _select_with_mask(self, map_fn, candidates, mask): if "BoolTensor" in mask.type(): - # print(candidates[0], len(mask)) out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m] elif "FloatTensor" in mask.type(): out = [map_fn(*cand) * m for cand, m in zip(candidates, mask)] diff --git a/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py b/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py index 6425e234d8..4d9c231143 100644 --- a/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py @@ -33,13 +33,13 @@ def train(self): for epoch in range(self.pdarts_epoch): layers = self.layers+self.pdarts_num_layers[epoch] - model, loss, model_optim, lr_scheduler = self.model_creator( + model, loss, model_optim, _ = self.model_creator( layers, n_nodes) mutator = PdartsMutator( model, epoch, self.pdarts_num_to_drop, switches) - self.trainer = DartsTrainer(model, loss=loss, model_optim=model_optim, - lr_scheduler=lr_scheduler, mutator=mutator, **self.darts_parameters) + self.trainer = DartsTrainer(model, loss=loss, optimizer=model_optim, + mutator=mutator, **self.darts_parameters) print("start pdrats training %s..." % epoch) self.trainer.train() diff --git a/src/sdk/pynni/nni/nas/pytorch/trainer.py b/src/sdk/pynni/nni/nas/pytorch/trainer.py index 6327e9a229..ab18e6c6e5 100644 --- a/src/sdk/pynni/nni/nas/pytorch/trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/trainer.py @@ -1,12 +1,65 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod +import torch -class Trainer(ABC): +from .base_trainer import BaseTrainer + + +class Trainer(BaseTrainer): + def __init__(self, model, loss, metrics, optimizer, num_epochs, + dataset_train, dataset_valid, batch_size, workers, device, log_frequency, + mutator, callbacks): + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device + self.model = model + self.loss = loss + self.metrics = metrics + self.optimizer = optimizer + self.mutator = mutator + + self.model.to(self.device) + self.loss.to(self.device) + self.mutator.to(self.device) + + self.num_epochs = num_epochs + self.dataset_train = dataset_train + self.dataset_valid = dataset_valid + self.batch_size = batch_size + self.workers = workers + self.log_frequency = log_frequency + self.callbacks = callbacks if callbacks is not None else [] + for callback in self.callbacks: + callback.build(self.model, self.mutator, self) @abstractmethod - def train(self): - raise NotImplementedError + def train_one_epoch(self, epoch): + pass @abstractmethod - def export(self): - raise NotImplementedError + def validate_one_epoch(self, epoch): + pass + + def _train(self, validate): + for epoch in range(self.num_epochs): + for callback in self.callbacks: + callback.on_epoch_begin(epoch) + + # training + print("Epoch {} Training".format(epoch)) + self.train_one_epoch(epoch) + + if validate: + # validation + print("Epoch {} Validating".format(epoch)) + self.validate_one_epoch(epoch) + + for callback in self.callbacks: + callback.on_epoch_end(epoch) + + def train_and_validate(self): + self._train(True) + + def train(self): + self._train(False) + + def validate(self): + self.validate_one_epoch(-1) diff --git a/src/sdk/pynni/nni/nas/utils.py b/src/sdk/pynni/nni/nas/utils.py index f6d5dfef65..5000946e7e 100644 --- a/src/sdk/pynni/nni/nas/utils.py +++ b/src/sdk/pynni/nni/nas/utils.py @@ -1,8 +1,5 @@ -import re from collections import OrderedDict -import torch - _counter = 0 @@ -12,14 +9,6 @@ def global_mutable_counting(): return _counter -def to_snake_case(camel_case): - return re.sub('(?!^)([A-Z]+)', r'_\1', camel_case).lower() - - -def auto_device(): - return torch.device("cuda" if torch.cuda.is_available() else "cpu") - - class AverageMeterGroup(object): def __init__(self): From 9dda5370405d33ade5966f8c840bc15f2893c205 Mon Sep 17 00:00:00 2001 From: Chi Song <27178119+squirrelsc@users.noreply.github.com> Date: Mon, 18 Nov 2019 20:59:18 +0800 Subject: [PATCH 06/10] update overview document of NAS (#1744) --- docs/en_US/NAS/Overview.md | 108 ++++++++++++++++++------------------- 1 file changed, 52 insertions(+), 56 deletions(-) diff --git a/docs/en_US/NAS/Overview.md b/docs/en_US/NAS/Overview.md index bedf503b79..92b06b413f 100644 --- a/docs/en_US/NAS/Overview.md +++ b/docs/en_US/NAS/Overview.md @@ -1,66 +1,62 @@ -# NNI Programming Interface for Neural Architecture Search (NAS) - -*This is an experimental feature, programming APIs are almost done, NAS trainers are under intensive development. ([NAS annotation](../AdvancedFeature/GeneralNasInterfaces.md) will become deprecated in future)* - -Automatic neural architecture search is taking an increasingly important role on finding better models. Recent research works have proved the feasibility of automatic NAS, and also found some models that could beat manually designed and tuned models. Some of representative works are [NASNet][2], [ENAS][1], [DARTS][3], [Network Morphism][4], and [Evolution][5]. There are new innovations keeping emerging. However, it takes great efforts to implement those algorithms, and it is hard to reuse code base of one algorithm for implementing another. - -To facilitate NAS innovations (e.g., design/implement new NAS models, compare different NAS models side-by-side), an easy-to-use and flexible programming interface is crucial. - -## Programming interface - -A new programming interface for designing and searching for a model is often demanded in two scenarios. - - 1. When designing a neural network, the designer may have multiple choices for a layer, sub-model, or connection, and not sure which one or a combination performs the best. It would be appealing to have an easy way to express the candidate layers/sub-models they want to try. - 2. For the researchers who are working on automatic NAS, they want to have an unified way to express the search space of neural architectures. And making unchanged trial code adapted to different searching algorithms. - -For expressing neural architecture search space, we provide two APIs: - -```python -# choose one ``op`` from ``ops``, for pytorch this is a module. -# ops: for pytorch ``ops`` is a list of modules, for tensorflow it is a list of keras layers. An example in pytroch: -# ops = [PoolBN('max', channels, 3, stride, 1, affine=False), -# PoolBN('avg', channels, 3, stride, 1, affine=False), -# FactorizedReduce(channels, channels, affine=False), -# SepConv(channels, channels, 3, stride, 1, affine=False), -# DilConv(channels, channels, 3, stride, 2, 2, affine=False)] -# key: the name of this ``LayerChoice`` instance -nni.nas.LayerChoice(ops, key) -# choose ``n_selected`` from ``n_candidates`` inputs. -# n_candidates: the number of candidate inputs -# n_selected: the number of chosen inputs -# reduction: reduction operation for the chosen inputs -# key: the name of this ``InputChoice`` instance -nni.nas.InputChoice(n_candidates, n_selected, reduction, key) +# Neural Architecture Search (NAS) on NNI + +Automatic neural architecture search is taking an increasingly important role on finding better models. Recent research works have proved the feasibility of automatic NAS, and also found some models that could beat manually designed and tuned models. Some of representative works are [NASNet][2], [ENAS][1], [DARTS][3], [Network Morphism][4], and [Evolution][5]. There are new innovations keeping emerging. + +However, it takes great efforts to implement NAS algorithms, and it is hard to reuse code base of existing algorithms in new one. To facilitate NAS innovations (e.g., design and implement new NAS models, compare different NAS models side-by-side), an easy-to-use and flexible programming interface is crucial. + +With this motivation, our ambition is to provide a unified architecture in NNI, to accelerate innovations on NAS, and apply state-of-art algorithms on real world problems faster. + +## Supported algorithms + +NNI supports below NAS algorithms now, and being adding more. User can reproduce an algorithm, or use it on owned dataset. we also encourage user to implement other algorithms with [NNI API](#use-nni-api), to benefit more people. + +Note, these algorithms run standalone without nnictl, and supports PyTorch only. + +### DARTS + +The main contribution of [DARTS: Differentiable Architecture Search][3] on algorithm is to introduce a novel algorithm for differentiable network architecture search on bilevel optimization. + +#### Usage + +```bash +### In case NNI code is not cloned. +git clone https://github.com/Microsoft/nni.git + +cd examples/nas/darts +python search.py ``` -After writing your model with search space embedded in the model using the above two APIs, the next step is finding the best model from the search space. Similar to optimizers of deep learning models, the procedure of finding the best model from search space can be viewed as a type of optimizing process, we call it `NAS trainer`. There have been several NAS trainers, for example, `DartsTrainer` which uses SGD to train architecture weights and model weights iteratively, `ENASTrainer` which uses a controller to train the model. New and more efficient NAS trainers keep emerging in research community. - -NNI provides some popular NAS trainers, to use a NAS trainer, users could initialize a trainer after the model is defined: - -```python -# create a DartsTrainer -trainer = DartsTrainer(model, - loss=criterion, - metrics=lambda output, target: accuracy(output, target, topk=(1,)), - model_optim=optim, - lr_scheduler=lr_scheduler, - num_epochs=50, - dataset_train=dataset_train, - dataset_valid=dataset_valid, - batch_size=args.batch_size, - log_frequency=args.log_frequency) -# finding the best model from search space -trainer.train() -# export the best found model -trainer.export_model() +### P-DARTS + +[Progressive Differentiable Architecture Search: Bridging the Depth Gap between Search and Evaluation](https://arxiv.org/abs/1904.12760) bases on DARTS(#DARTS). It main contribution on algorithm is to introduce an efficient algorithm which allows the depth of searched architectures to grow gradually during the training procedure. + +#### Usage + +```bash +### In case NNI code is not cloned. +git clone https://github.com/Microsoft/nni.git + +cd examples/nas/pdarts +python main.py ``` -Different trainers could have different input arguments depending on their algorithms. After training, users could export the best one of the found models through `trainer.export_model()`. +## Use NNI API + +NOTE, we are trying to support various NAS algorithms with unified programming interface, and it's in very experimental stage. It means the current programing interface may be updated significantly. + +*previous [NAS annotation](../AdvancedFeature/GeneralNasInterfaces.md) interface will be deprecated soon.* + +### Programming interface + +The programming interface of designing and searching a model is often demanded in two scenarios. + +1. When designing a neural network, there may be multiple operation choices on a layer, sub-model, or connection, and it's undetermined which one or combination performs best. So it needs an easy way to express the candidate layers or sub-models. +2. When applying NAS on a neural network, it needs an unified way to express the search space of architectures, so that it doesn't need to update trial code for different searching algorithms. -[Here](https://github.com/microsoft/nni/blob/dev-nas-refactor/examples/nas/darts/main.py) is a trial example using DartsTrainer. +NNI proposed API is [here](https://github.com/microsoft/nni/tree/dev-nas-refactor/src/sdk/pynni/nni/nas/pytorch). And [here](https://github.com/microsoft/nni/tree/dev-nas-refactor/examples/nas/darts) is an example of NAS implementation, which bases on NNI proposed interface. [1]: https://arxiv.org/abs/1802.03268 [2]: https://arxiv.org/abs/1707.07012 [3]: https://arxiv.org/abs/1806.09055 [4]: https://arxiv.org/abs/1806.10282 -[5]: https://arxiv.org/abs/1703.01041 \ No newline at end of file +[5]: https://arxiv.org/abs/1703.01041 From 77e91e8bc57cefadc8607a0866ab51ba33b01762 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Thu, 21 Nov 2019 15:19:29 +0800 Subject: [PATCH 07/10] Extract controller from mutator to make offline decisions (#1758) --- examples/nas/.gitignore | 3 +- examples/nas/darts/model.py | 35 +++-- examples/nas/darts/retrain.py | 143 ++++++++++++++++++ examples/nas/darts/search.py | 4 +- examples/nas/enas/macro.py | 20 +-- examples/nas/enas/micro.py | 33 ++-- examples/nas/enas/search.py | 7 +- src/sdk/pynni/nni/nas/pytorch/base_mutator.py | 139 ++++++++++++----- src/sdk/pynni/nni/nas/pytorch/base_trainer.py | 6 +- src/sdk/pynni/nni/nas/pytorch/callbacks.py | 20 +-- .../pynni/nni/nas/pytorch/darts/__init__.py | 3 +- .../pynni/nni/nas/pytorch/darts/mutator.py | 58 ++++--- src/sdk/pynni/nni/nas/pytorch/darts/scope.py | 11 -- .../pynni/nni/nas/pytorch/darts/trainer.py | 29 ++-- src/sdk/pynni/nni/nas/pytorch/enas/mutator.py | 56 ++++--- src/sdk/pynni/nni/nas/pytorch/enas/trainer.py | 22 +-- src/sdk/pynni/nni/nas/pytorch/fixed.py | 90 ++++++----- src/sdk/pynni/nni/nas/pytorch/mutables.py | 108 +++++++++---- src/sdk/pynni/nni/nas/pytorch/mutator.py | 138 ++++++++--------- .../pynni/nni/nas/pytorch/pdarts/mutator.py | 6 +- .../pynni/nni/nas/pytorch/pdarts/trainer.py | 3 +- src/sdk/pynni/nni/nas/pytorch/trainer.py | 41 +++-- src/sdk/pynni/nni/nas/pytorch/utils.py | 107 +++++++++++++ 23 files changed, 738 insertions(+), 344 deletions(-) create mode 100644 examples/nas/darts/retrain.py delete mode 100644 src/sdk/pynni/nni/nas/pytorch/darts/scope.py create mode 100644 src/sdk/pynni/nni/nas/pytorch/utils.py diff --git a/examples/nas/.gitignore b/examples/nas/.gitignore index 8705cba4d6..9ba06a7ca3 100644 --- a/examples/nas/.gitignore +++ b/examples/nas/.gitignore @@ -1 +1,2 @@ -data +data +checkpoints diff --git a/examples/nas/darts/model.py b/examples/nas/darts/model.py index 5c284b5a46..6a9afe6ff3 100644 --- a/examples/nas/darts/model.py +++ b/examples/nas/darts/model.py @@ -2,7 +2,7 @@ import torch.nn as nn import ops -from nni.nas.pytorch import mutables, darts +from nni.nas.pytorch import mutables class AuxiliaryHead(nn.Module): @@ -31,12 +31,14 @@ def forward(self, x): return logits -class Node(darts.DartsNode): - def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect, drop_path_prob=0.): - super().__init__(node_id, limitation=2) +class Node(nn.Module): + def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect): + super().__init__() self.ops = nn.ModuleList() + choice_keys = [] for i in range(num_prev_nodes): stride = 2 if i < num_downsample_connect else 1 + choice_keys.append("{}_p{}".format(node_id, i)) self.ops.append( mutables.LayerChoice( [ @@ -48,18 +50,19 @@ def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect, dr ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False), ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False), ], - key="{}_p{}".format(node_id, i))) - self.drop_path = ops.DropPath_(drop_path_prob) + key=choice_keys[-1])) + self.drop_path = ops.DropPath_() + self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id)) def forward(self, prev_nodes): assert len(self.ops) == len(prev_nodes) out = [op(node) for op, node in zip(self.ops, prev_nodes)] - return sum(self.drop_path(o) for o in out if o is not None) + return self.input_switch(out) class Cell(nn.Module): - def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction, drop_path_prob=0.): + def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction): super().__init__() self.reduction = reduction self.n_nodes = n_nodes @@ -74,10 +77,9 @@ def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, redu # generate dag self.mutable_ops = nn.ModuleList() - for depth in range(self.n_nodes): - self.mutable_ops.append(Node("r{:d}_n{}".format(reduction, depth), - depth + 2, channels, 2 if reduction else 0, - drop_path_prob=drop_path_prob)) + for depth in range(2, self.n_nodes + 2): + self.mutable_ops.append(Node("{}_n{}".format("reduce" if reduction else "normal", depth), + depth, channels, 2 if reduction else 0)) def forward(self, s0, s1): # s0, s1 are the outputs of previous previous cell and previous cell, respectively. @@ -93,7 +95,7 @@ def forward(self, s0, s1): class CNN(nn.Module): def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4, - stem_multiplier=3, auxiliary=False, drop_path_prob=0.): + stem_multiplier=3, auxiliary=False): super().__init__() self.in_channels = in_channels self.channels = channels @@ -120,7 +122,7 @@ def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nod c_cur *= 2 reduction = True - cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction, drop_path_prob=drop_path_prob) + cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction) self.cells.append(cell) c_cur_out = c_cur * n_nodes channels_pp, channels_p = channels_p, c_cur_out @@ -147,3 +149,8 @@ def forward(self, x): if aux_logits is not None: return logits, aux_logits return logits + + def drop_path_prob(self, p): + for module in self.modules(): + if isinstance(module, ops.DropPath_): + module.p = p diff --git a/examples/nas/darts/retrain.py b/examples/nas/darts/retrain.py new file mode 100644 index 0000000000..5c8fabf8d0 --- /dev/null +++ b/examples/nas/darts/retrain.py @@ -0,0 +1,143 @@ +import logging +from argparse import ArgumentParser + +import torch +import torch.nn as nn + +import datasets +import utils +from model import CNN +from nni.nas.pytorch.fixed import apply_fixed_architecture +from nni.nas.pytorch.utils import AverageMeter + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def train(config, train_loader, model, optimizer, criterion, epoch): + top1 = AverageMeter("top1") + top5 = AverageMeter("top5") + losses = AverageMeter("losses") + + cur_step = epoch * len(train_loader) + cur_lr = optimizer.param_groups[0]['lr'] + logger.info("Epoch %d LR %.6f", epoch, cur_lr) + + model.train() + + for step, (x, y) in enumerate(train_loader): + x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True) + bs = x.size(0) + + optimizer.zero_grad() + logits, aux_logits = model(x) + loss = criterion(logits, y) + if config.aux_weight > 0.: + loss += config.aux_weight * criterion(aux_logits, y) + loss.backward() + # gradient clipping + nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) + optimizer.step() + + accuracy = utils.accuracy(logits, y, topk=(1, 5)) + losses.update(loss.item(), bs) + top1.update(accuracy["acc1"], bs) + top5.update(accuracy["acc5"], bs) + + if step % config.log_frequency == 0 or step == len(train_loader) - 1: + logger.info( + "Train: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} " + "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format( + epoch + 1, config.epochs, step, len(train_loader) - 1, losses=losses, + top1=top1, top5=top5)) + + cur_step += 1 + + logger.info("Train: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, config.epochs, top1.avg)) + + +def validate(config, valid_loader, model, criterion, epoch, cur_step): + top1 = AverageMeter("top1") + top5 = AverageMeter("top5") + losses = AverageMeter("losses") + + model.eval() + + with torch.no_grad(): + for step, (X, y) in enumerate(valid_loader): + X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True) + N = X.size(0) + + logits = model(X) + loss = criterion(logits, y) + + accuracy = utils.accuracy(logits, y, topk=(1, 5)) + losses.update(loss.item(), N) + top1.update(accuracy["acc1"], N) + top5.update(accuracy["acc5"], N) + + if step % config.log_frequency == 0 or step == len(valid_loader) - 1: + logger.info( + "Valid: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} " + "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format( + epoch + 1, config.epochs, step, len(valid_loader) - 1, losses=losses, + top1=top1, top5=top5)) + + logger.info("Valid: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, config.epochs, top1.avg)) + + return top1.avg + + +if __name__ == "__main__": + parser = ArgumentParser("darts") + parser.add_argument("--layers", default=20, type=int) + parser.add_argument("--batch-size", default=96, type=int) + parser.add_argument("--log-frequency", default=10, type=int) + parser.add_argument("--epochs", default=600, type=int) + parser.add_argument("--aux-weight", default=0.4, type=float) + parser.add_argument("--drop-path-prob", default=0.2, type=float) + parser.add_argument("--workers", default=4) + parser.add_argument("--grad-clip", default=5., type=float) + parser.add_argument("--arc-checkpoint", default="./checkpoints/epoch_0.json") + + args = parser.parse_args() + dataset_train, dataset_valid = datasets.get_dataset("cifar10", cutout_length=16) + + model = CNN(32, 3, 36, 10, args.layers, auxiliary=True) + apply_fixed_architecture(model, args.arc_checkpoint, device=device) + criterion = nn.CrossEntropyLoss() + + model.to(device) + criterion.to(device) + + optimizer = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4) + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=1E-6) + + train_loader = torch.utils.data.DataLoader(dataset_train, + batch_size=args.batch_size, + shuffle=True, + num_workers=args.workers, + pin_memory=True) + valid_loader = torch.utils.data.DataLoader(dataset_valid, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.workers, + pin_memory=True) + + best_top1 = 0. + for epoch in range(args.epochs): + drop_prob = args.drop_path_prob * epoch / args.epochs + model.drop_path_prob(drop_prob) + + # training + train(args, train_loader, model, optimizer, criterion, epoch) + + # validation + cur_step = (epoch + 1) * len(train_loader) + top1 = validate(args, valid_loader, model, criterion, epoch, cur_step) + best_top1 = max(best_top1, top1) + + lr_scheduler.step() + + logger.info("Final best Prec@1 = {:.4%}".format(best_top1)) diff --git a/examples/nas/darts/search.py b/examples/nas/darts/search.py index 75773cf5e0..02c720a60c 100644 --- a/examples/nas/darts/search.py +++ b/examples/nas/darts/search.py @@ -13,7 +13,7 @@ if __name__ == "__main__": parser = ArgumentParser("darts") parser.add_argument("--layers", default=8, type=int) - parser.add_argument("--batch-size", default=96, type=int) + parser.add_argument("--batch-size", default=64, type=int) parser.add_argument("--log-frequency", default=10, type=int) parser.add_argument("--epochs", default=50, type=int) args = parser.parse_args() @@ -36,4 +36,4 @@ batch_size=args.batch_size, log_frequency=args.log_frequency, callbacks=[LearningRateScheduler(lr_scheduler), ArchitectureCheckpoint("./checkpoints")]) - trainer.train_and_validate() + trainer.train() diff --git a/examples/nas/enas/macro.py b/examples/nas/enas/macro.py index 48fcaaf03d..a9309f9079 100644 --- a/examples/nas/enas/macro.py +++ b/examples/nas/enas/macro.py @@ -6,7 +6,7 @@ class ENASLayer(mutables.MutableScope): - def __init__(self, key, num_prev_layers, in_filters, out_filters): + def __init__(self, key, prev_labels, in_filters, out_filters): super().__init__(key) self.in_filters = in_filters self.out_filters = out_filters @@ -18,16 +18,16 @@ def __init__(self, key, num_prev_layers, in_filters, out_filters): PoolBranch('avg', in_filters, out_filters, 3, 1, 1), PoolBranch('max', in_filters, out_filters, 3, 1, 1) ]) - if num_prev_layers > 0: - self.skipconnect = mutables.InputChoice(num_prev_layers, n_selected=None, reduction="sum") + if len(prev_labels) > 0: + self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None, reduction="sum") else: self.skipconnect = None self.batch_norm = nn.BatchNorm2d(out_filters, affine=False) - def forward(self, prev_layers, prev_labels): + def forward(self, prev_layers): out = self.mutable(prev_layers[-1]) if self.skipconnect is not None: - connection = self.skipconnect(prev_layers[:-1], tags=prev_labels) + connection = self.skipconnect(prev_layers[:-1]) if connection is not None: out += connection return self.batch_norm(out) @@ -53,11 +53,12 @@ def __init__(self, num_layers=12, out_filters=24, in_channels=3, num_classes=10, self.layers = nn.ModuleList() self.pool_layers = nn.ModuleList() + labels = [] for layer_id in range(self.num_layers): + labels.append("layer_{}".format(layer_id)) if layer_id in self.pool_layers_idx: self.pool_layers.append(FactorizedReduce(self.out_filters, self.out_filters)) - self.layers.append(ENASLayer("layer_{}".format(layer_id), layer_id, - self.out_filters, self.out_filters)) + self.layers.append(ENASLayer(labels[-1], labels[:-1], self.out_filters, self.out_filters)) self.gap = nn.AdaptiveAvgPool2d(1) self.dense = nn.Linear(self.out_filters, self.num_classes) @@ -66,12 +67,11 @@ def forward(self, x): bs = x.size(0) cur = self.stem(x) - layers, labels = [cur], [] + layers = [cur] for layer_id in range(self.num_layers): - cur = self.layers[layer_id](layers, labels) + cur = self.layers[layer_id](layers) layers.append(cur) - labels.append(self.layers[layer_id].key) if layer_id in self.pool_layers_idx: for i, layer in enumerate(layers): layers[i] = self.pool_layers[self.pool_layers_idx.index(layer_id)](layer) diff --git a/examples/nas/enas/micro.py b/examples/nas/enas/micro.py index 209abf2405..fabd3919ca 100644 --- a/examples/nas/enas/micro.py +++ b/examples/nas/enas/micro.py @@ -32,9 +32,9 @@ def forward(self, x): class Cell(nn.Module): - def __init__(self, cell_name, num_prev_layers, channels): + def __init__(self, cell_name, prev_labels, channels): super().__init__() - self.input_choice = mutables.InputChoice(num_prev_layers, n_selected=1, return_mask=True, + self.input_choice = mutables.InputChoice(choose_from=prev_labels, n_chosen=1, return_mask=True, key=cell_name + "_input") self.op_choice = mutables.LayerChoice([ SepConvBN(channels, channels, 3, 1), @@ -44,21 +44,21 @@ def __init__(self, cell_name, num_prev_layers, channels): nn.Identity() ], key=cell_name + "_op") - def forward(self, prev_layers, prev_labels): - chosen_input, chosen_mask = self.input_choice(prev_layers, tags=prev_labels) + def forward(self, prev_layers): + chosen_input, chosen_mask = self.input_choice(prev_layers) cell_out = self.op_choice(chosen_input) return cell_out, chosen_mask class Node(mutables.MutableScope): - def __init__(self, node_name, num_prev_layers, channels): + def __init__(self, node_name, prev_node_names, channels): super().__init__(node_name) - self.cell_x = Cell(node_name + "_x", num_prev_layers, channels) - self.cell_y = Cell(node_name + "_y", num_prev_layers, channels) + self.cell_x = Cell(node_name + "_x", prev_node_names, channels) + self.cell_y = Cell(node_name + "_y", prev_node_names, channels) - def forward(self, prev_layers, prev_labels): - out_x, mask_x = self.cell_x(prev_layers, prev_labels) - out_y, mask_y = self.cell_y(prev_layers, prev_labels) + def forward(self, prev_layers): + out_x, mask_x = self.cell_x(prev_layers) + out_y, mask_y = self.cell_y(prev_layers) return out_x + out_y, mask_x | mask_y @@ -93,8 +93,11 @@ def __init__(self, num_nodes, in_channels_pp, in_channels_p, out_channels, reduc self.num_nodes = num_nodes name_prefix = "reduce" if reduction else "normal" - self.nodes = nn.ModuleList([Node("{}_node_{}".format(name_prefix, i), - i + 2, out_channels) for i in range(num_nodes)]) + self.nodes = nn.ModuleList() + node_labels = [mutables.InputChoice.NO_KEY, mutables.InputChoice.NO_KEY] + for i in range(num_nodes): + node_labels.append("{}_node_{}".format(name_prefix, i)) + self.nodes.append(Node(node_labels[-1], node_labels[:-1], out_channels)) self.final_conv_w = nn.Parameter(torch.zeros(out_channels, self.num_nodes + 2, out_channels, 1, 1), requires_grad=True) self.bn = nn.BatchNorm2d(out_channels, affine=False) self.reset_parameters() @@ -106,14 +109,12 @@ def forward(self, pprev, prev): pprev_, prev_ = self.preproc0(pprev), self.preproc1(prev) prev_nodes_out = [pprev_, prev_] - prev_nodes_labels = ["prev1", "prev2"] nodes_used_mask = torch.zeros(self.num_nodes + 2, dtype=torch.bool, device=prev.device) for i in range(self.num_nodes): - node_out, mask = self.nodes[i](prev_nodes_out, prev_nodes_labels) + node_out, mask = self.nodes[i](prev_nodes_out) nodes_used_mask[:mask.size(0)] |= mask prev_nodes_out.append(node_out) - prev_nodes_labels.append(self.nodes[i].key) - + unused_nodes = torch.cat([out for used, out in zip(nodes_used_mask, prev_nodes_out) if not used], 1) unused_nodes = F.relu(unused_nodes) conv_weight = self.final_conv_w[:, ~nodes_used_mask, :, :, :] diff --git a/examples/nas/enas/search.py b/examples/nas/enas/search.py index 6e1bdec34c..35bc930333 100644 --- a/examples/nas/enas/search.py +++ b/examples/nas/enas/search.py @@ -13,7 +13,7 @@ if __name__ == "__main__": parser = ArgumentParser("enas") parser.add_argument("--batch-size", default=128, type=int) - parser.add_argument("--log-frequency", default=1, type=int) + parser.add_argument("--log-frequency", default=10, type=int) parser.add_argument("--search-for", choices=["macro", "micro"], default="macro") args = parser.parse_args() @@ -43,5 +43,6 @@ num_epochs=num_epochs, dataset_train=dataset_train, dataset_valid=dataset_valid, - log_frequency=args.log_frequency) - trainer.train_and_validate() + log_frequency=args.log_frequency, + mutator=mutator) + trainer.train() diff --git a/src/sdk/pynni/nni/nas/pytorch/base_mutator.py b/src/sdk/pynni/nni/nas/pytorch/base_mutator.py index dd2b844d24..550e449dfc 100644 --- a/src/sdk/pynni/nni/nas/pytorch/base_mutator.py +++ b/src/sdk/pynni/nni/nas/pytorch/base_mutator.py @@ -1,70 +1,127 @@ import logging import torch.nn as nn - -from nni.nas.pytorch.mutables import Mutable +from nni.nas.pytorch.mutables import Mutable, MutableScope, InputChoice +from nni.nas.pytorch.utils import StructuredMutableTreeNode logger = logging.getLogger(__name__) class BaseMutator(nn.Module): + """ + A mutator is responsible for mutating a graph by obtaining the search space from the network and implementing + callbacks that are called in ``forward`` in Mutables. + """ + def __init__(self, model): super().__init__() self.__dict__["model"] = model - self.before_parse_search_space() - self._parse_search_space() - self.after_parse_search_space() - - def before_parse_search_space(self): - pass - - def after_parse_search_space(self): - pass - - def _parse_search_space(self): - for name, mutable, _ in self.named_mutables(distinct=False): - mutable.name = name - mutable.set_mutator(self) + self._structured_mutables = self._parse_search_space(self.model) - def named_mutables(self, root=None, distinct=True): + def _parse_search_space(self, module, root=None, prefix="", memo=None, nested_detection=None): + if memo is None: + memo = set() if root is None: - root = self.model - # if distinct is true, the method will filter out those with duplicated keys - key2module = dict() - for name, module in root.named_modules(): + root = StructuredMutableTreeNode(None) + if module not in memo: + memo.add(module) if isinstance(module, Mutable): - module_distinct = False - if module.key in key2module: - assert key2module[module.key].similar(module), \ - "Mutable \"{}\" that share the same key must be similar to each other".format(module.key) - else: - module_distinct = True - key2module[module.key] = module - if distinct: - if module_distinct: - yield name, module - else: - yield name, module, module_distinct - - def __setattr__(self, key, value): - if key in ["model", "net", "network"]: - logger.warning("Think twice if you are including the network into mutator.") - return super().__setattr__(key, value) - + if nested_detection is not None: + raise RuntimeError("Cannot have nested search space. Error at {} in {}" + .format(module, nested_detection)) + module.name = prefix + module.set_mutator(self) + root = root.add_child(module) + if not isinstance(module, MutableScope): + nested_detection = module + if isinstance(module, InputChoice): + for k in module.choose_from: + if k != InputChoice.NO_KEY and k not in [m.key for m in memo if isinstance(m, Mutable)]: + raise RuntimeError("'{}' required by '{}' not found in keys that appeared before, and is not NO_KEY." + .format(k, module.key)) + for name, submodule in module._modules.items(): + if submodule is None: + continue + submodule_prefix = prefix + ("." if prefix else "") + name + self._parse_search_space(submodule, root, submodule_prefix, memo=memo, + nested_detection=nested_detection) + return root + + @property + def mutables(self): + return self._structured_mutables + + @property def forward(self, *inputs): - raise NotImplementedError("Mutator is not forward-able") + raise RuntimeError("Forward is undefined for mutators.") def enter_mutable_scope(self, mutable_scope): + """ + Callback when forward of a MutableScope is entered. + + Parameters + ---------- + mutable_scope: MutableScope + + Returns + ------- + None + """ pass def exit_mutable_scope(self, mutable_scope): + """ + Callback when forward of a MutableScope is exited. + + Parameters + ---------- + mutable_scope: MutableScope + + Returns + ------- + None + """ pass def on_forward_layer_choice(self, mutable, *inputs): + """ + Callbacks of forward in LayerChoice. + + Parameters + ---------- + mutable: LayerChoice + inputs: list of torch.Tensor + + Returns + ------- + tuple of torch.Tensor and torch.Tensor + output tensor and mask + """ raise NotImplementedError - def on_forward_input_choice(self, mutable, tensor_list, tags): + def on_forward_input_choice(self, mutable, tensor_list): + """ + Callbacks of forward in InputChoice. + + Parameters + ---------- + mutable: InputChoice + tensor_list: list of torch.Tensor + + Returns + ------- + tuple of torch.Tensor and torch.Tensor + output tensor and mask + """ raise NotImplementedError def export(self): + """ + Export the data of all decisions. This should output the decisions of all the mutables, so that the whole + network can be fully determined with these decisions for further training from scratch. + + Returns + ------- + dict + """ raise NotImplementedError diff --git a/src/sdk/pynni/nni/nas/pytorch/base_trainer.py b/src/sdk/pynni/nni/nas/pytorch/base_trainer.py index 1248cc09e2..db1b033073 100644 --- a/src/sdk/pynni/nni/nas/pytorch/base_trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/base_trainer.py @@ -12,5 +12,9 @@ def validate(self): raise NotImplementedError @abstractmethod - def train_and_validate(self): + def export(self, file): + raise NotImplementedError + + @abstractmethod + def checkpoint(self): raise NotImplementedError diff --git a/src/sdk/pynni/nni/nas/pytorch/callbacks.py b/src/sdk/pynni/nni/nas/pytorch/callbacks.py index 2a76b3dab8..83ae62cde0 100644 --- a/src/sdk/pynni/nni/nas/pytorch/callbacks.py +++ b/src/sdk/pynni/nni/nas/pytorch/callbacks.py @@ -1,9 +1,6 @@ -import json import logging import os -import torch - _logger = logging.getLogger(__name__) @@ -44,26 +41,11 @@ def on_epoch_end(self, epoch): class ArchitectureCheckpoint(Callback): - class TorchTensorEncoder(json.JSONEncoder): - def default(self, o): # pylint: disable=method-hidden - if isinstance(o, torch.Tensor): - olist = o.tolist() - if "bool" not in o.type().lower() and all(map(lambda d: d == 0 or d == 1, olist)): - _logger.warning("Every element in %s is either 0 or 1. " - "You might consider convert it into bool.", olist) - return olist - return super().default(o) - def __init__(self, checkpoint_dir, every="epoch"): super().__init__() assert every == "epoch" self.checkpoint_dir = checkpoint_dir os.makedirs(self.checkpoint_dir, exist_ok=True) - def _export_to_file(self, file): - mutator_export = self.mutator.export() - with open(file, "w") as f: - json.dump(mutator_export, f, indent=2, sort_keys=True, cls=self.TorchTensorEncoder) - def on_epoch_end(self, epoch): - self._export_to_file(os.path.join(self.checkpoint_dir, "epoch_{}.json".format(epoch))) + self.trainer.export(os.path.join(self.checkpoint_dir, "epoch_{}.json".format(epoch))) diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py b/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py index 7f2c9f9675..3bf08d285c 100644 --- a/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py +++ b/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py @@ -1,3 +1,2 @@ from .mutator import DartsMutator -from .trainer import DartsTrainer -from .scope import DartsNode \ No newline at end of file +from .trainer import DartsTrainer \ No newline at end of file diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py b/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py index 589847d2b6..91d739c0a3 100644 --- a/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py +++ b/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py @@ -2,35 +2,47 @@ from torch import nn as nn from torch.nn import functional as F -from nni.nas.pytorch.mutables import LayerChoice from nni.nas.pytorch.mutator import Mutator -from .scope import DartsNode +from nni.nas.pytorch.mutables import LayerChoice, InputChoice class DartsMutator(Mutator): - - def after_parse_search_space(self): + def __init__(self, model): + super().__init__(model) self.choices = nn.ParameterDict() - for _, mutable in self.named_mutables(): + for mutable in self.mutables: if isinstance(mutable, LayerChoice): - self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(len(mutable) + 1)) + self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length + 1)) - def on_calc_layer_choice_mask(self, mutable: LayerChoice): - return F.softmax(self.choices[mutable.key], dim=-1)[:-1] + def device(self): + for v in self.choices.values(): + return v.device - def export(self): - result = super().export() - for _, darts_node in self.named_mutables(): - if isinstance(darts_node, DartsNode): - keys, edges_max = [], [] # key of all the layer choices in current node, and their best edge weight - for _, choice in self.named_mutables(darts_node): - if isinstance(choice, LayerChoice): - keys.append(choice.key) - max_val, index = torch.max(result[choice.key], 0) - edges_max.append(max_val) - result[choice.key] = F.one_hot(index, num_classes=len(result[choice.key])).view(-1).bool() - _, topk_edge_indices = torch.topk(torch.tensor(edges_max).view(-1), darts_node.limitation) # pylint: disable=not-callable - for i, key in enumerate(keys): - if i not in topk_edge_indices: - result[key] = torch.zeros_like(result[key]) + def sample_search(self): + result = dict() + for mutable in self.mutables: + if isinstance(mutable, LayerChoice): + result[mutable.key] = F.softmax(self.choices[mutable.key], dim=-1)[:-1] + elif isinstance(mutable, InputChoice): + result[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device()) + return result + + def sample_final(self): + result = dict() + edges_max = dict() + for mutable in self.mutables: + if isinstance(mutable, LayerChoice): + max_val, index = torch.max(F.softmax(self.choices[mutable.key], dim=-1)[:-1], 0) + edges_max[mutable.key] = max_val + result[mutable.key] = F.one_hot(index, num_classes=mutable.length).view(-1).bool() + for mutable in self.mutables: + if isinstance(mutable, InputChoice): + weights = torch.tensor([edges_max.get(src_key, 0.) for src_key in mutable.choose_from]) # pylint: disable=not-callable + _, topk_edge_indices = torch.topk(weights, mutable.n_chosen or mutable.n_candidates) + selected_multihot = [] + for i, src_key in enumerate(mutable.choose_from): + if i not in topk_edge_indices and src_key in result: + result[src_key] = torch.zeros_like(result[src_key]) # clear this choice to optimize calc graph + selected_multihot.append(i in topk_edge_indices) + result[mutable.key] = torch.tensor(selected_multihot, dtype=torch.bool, device=self.device()) # pylint: disable=not-callable return result diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/scope.py b/src/sdk/pynni/nni/nas/pytorch/darts/scope.py deleted file mode 100644 index a2bf2b3cff..0000000000 --- a/src/sdk/pynni/nni/nas/pytorch/darts/scope.py +++ /dev/null @@ -1,11 +0,0 @@ -from nni.nas.pytorch.mutables import MutableScope - - -class DartsNode(MutableScope): - """ - At most `limitation` choice is activated in a `DartsNode` when exporting. - """ - - def __init__(self, key, limitation): - super().__init__(key) - self.limitation = limitation diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py b/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py index 464832eadf..c6b29de04a 100644 --- a/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py @@ -4,7 +4,7 @@ from torch import nn as nn from nni.nas.pytorch.trainer import Trainer -from nni.nas.utils import AverageMeterGroup +from nni.nas.pytorch.utils import AverageMeterGroup from .mutator import DartsMutator @@ -13,9 +13,9 @@ def __init__(self, model, loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid, mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None): - super().__init__(model, loss, metrics, optimizer, num_epochs, - dataset_train, dataset_valid, batch_size, workers, device, log_frequency, - mutator if mutator is not None else DartsMutator(model), callbacks) + super().__init__(model, mutator if mutator is not None else DartsMutator(model), + loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid, + batch_size, workers, device, log_frequency, callbacks) self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), 3.0E-4, betas=(0.5, 0.999), weight_decay=1.0E-3) n_train = len(self.dataset_train) @@ -31,6 +31,9 @@ def __init__(self, model, loss, metrics, batch_size=batch_size, sampler=valid_sampler, num_workers=workers) + self.test_loader = torch.utils.data.DataLoader(self.dataset_valid, + batch_size=batch_size, + num_workers=workers) def train_one_epoch(self, epoch): self.model.train() @@ -47,8 +50,8 @@ def train_one_epoch(self, epoch): # phase 1. child network step self.optimizer.zero_grad() - with self.mutator.forward_pass(): - logits = self.model(trn_X) + self.mutator.reset() + logits = self.model(trn_X) loss = self.loss(logits, trn_y) loss.backward() # gradient clipping @@ -76,10 +79,10 @@ def validate_one_epoch(self, epoch): self.mutator.eval() meters = AverageMeterGroup() with torch.no_grad(): - for step, (X, y) in enumerate(self.valid_loader): + self.mutator.reset() + for step, (X, y) in enumerate(self.test_loader): X, y = X.to(self.device), y.to(self.device) - with self.mutator.forward_pass(): - logits = self.model(X) + logits = self.model(X) metrics = self.metrics(logits, y) meters.update(metrics) if self.log_frequency is not None and step % self.log_frequency == 0: @@ -93,8 +96,8 @@ def _unrolled_backward(self, trn_X, trn_y, val_X, val_y, backup_model, lr): v_model: backup model before this step lr: learning rate for virtual gradient step (same as net lr) """ - with self.mutator.forward_pass(): - loss = self.loss(self.model(val_X), val_y) + self.mutator.reset() + loss = self.loss(self.model(val_X), val_y) w_model = tuple(self.model.parameters()) w_ctrl = tuple(self.mutator.parameters()) w_grads = torch.autograd.grad(loss, w_model + w_ctrl) @@ -125,8 +128,8 @@ def _compute_hessian(self, model, dw, trn_X, trn_y): for p, d in zip(self.model.parameters(), dw): p += eps * d - with self.mutator.forward_pass(): - loss = self.loss(self.model(trn_X), trn_y) + self.mutator.reset() + loss = self.loss(self.model(trn_X), trn_y) if e > 0: dalpha_pos = torch.autograd.grad(loss, self.mutator.parameters()) # dalpha { L_trn(w+) } elif e < 0: diff --git a/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py b/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py index 3bd32459b4..9d9a176352 100644 --- a/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py +++ b/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py @@ -2,8 +2,8 @@ import torch.nn as nn import torch.nn.functional as F -from nni.nas.pytorch.mutables import LayerChoice from nni.nas.pytorch.mutator import Mutator +from nni.nas.pytorch.mutables import LayerChoice, InputChoice, MutableScope class StackedLSTMCell(nn.Module): @@ -27,15 +27,14 @@ def forward(self, inputs, hidden): class EnasMutator(Mutator): def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False, skip_target=0.4, branch_bias=0.25): + super().__init__(model) self.lstm_size = lstm_size self.lstm_num_layers = lstm_num_layers self.tanh_constant = tanh_constant self.cell_exit_extra_step = cell_exit_extra_step self.skip_target = skip_target self.branch_bias = branch_bias - super().__init__(model) - def before_parse_search_space(self): self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False) self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False) self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False) @@ -45,9 +44,8 @@ def before_parse_search_space(self): self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none") self.bias_dict = nn.ParameterDict() - def after_parse_search_space(self): self.max_layer_choice = 0 - for _, mutable in self.named_mutables(): + for mutable in self.mutables: if isinstance(mutable, LayerChoice): if self.max_layer_choice == 0: self.max_layer_choice = mutable.length @@ -64,8 +62,29 @@ def is_conv(choice): self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size) self.soft = nn.Linear(self.lstm_size, self.max_layer_choice, bias=False) - def before_pass(self): - super().before_pass() + def sample_search(self): + self._initialize() + self._sample(self.mutables) + return self._choices + + def sample_final(self): + return self.sample_search() + + def _sample(self, tree): + mutable = tree.mutable + if isinstance(mutable, LayerChoice) and mutable.key not in self._choices: + self._choices[mutable.key] = self._sample_layer_choice(mutable) + elif isinstance(mutable, InputChoice) and mutable.key not in self._choices: + self._choices[mutable.key] = self._sample_input_choice(mutable) + for child in tree.children: + self._sample(child) + if isinstance(mutable, MutableScope) and mutable.key not in self._anchors_hid: + if self.cell_exit_extra_step: + self._lstm_next_step() + self._mark_anchor(mutable.key) + + def _initialize(self): + self._choices = dict() self._anchors_hid = dict() self._inputs = self.g_emb.data self._c = [torch.zeros((1, self.lstm_size), @@ -84,7 +103,7 @@ def _lstm_next_step(self): def _mark_anchor(self, key): self._anchors_hid[key] = self._h[-1] - def on_calc_layer_choice_mask(self, mutable): + def _sample_layer_choice(self, mutable): self._lstm_next_step() logit = self.soft(self._h[-1]) if self.tanh_constant is not None: @@ -94,14 +113,14 @@ def on_calc_layer_choice_mask(self, mutable): branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1) log_prob = self.cross_entropy_loss(logit, branch_id) self.sample_log_prob += torch.sum(log_prob) - entropy = (log_prob * torch.exp(-log_prob)).detach() + entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type self.sample_entropy += torch.sum(entropy) self._inputs = self.embedding(branch_id) return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1) - def on_calc_input_choice_mask(self, mutable, tags): + def _sample_input_choice(self, mutable): query, anchors = [], [] - for label in tags: + for label in mutable.choose_from: if label not in self._anchors_hid: self._lstm_next_step() self._mark_anchor(label) # empty loop, fill not found @@ -113,8 +132,8 @@ def on_calc_input_choice_mask(self, mutable, tags): if self.tanh_constant is not None: query = self.tanh_constant * torch.tanh(query) - if mutable.n_selected is None: - logit = torch.cat([-query, query], 1) + if mutable.n_chosen is None: + logit = torch.cat([-query, query], 1) # pylint: disable=invalid-unary-operand-type skip = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1) skip_prob = torch.sigmoid(logit) @@ -123,19 +142,14 @@ def on_calc_input_choice_mask(self, mutable, tags): log_prob = self.cross_entropy_loss(logit, skip) self._inputs = (torch.matmul(skip.float(), torch.cat(anchors, 0)) / (1. + torch.sum(skip))).unsqueeze(0) else: - assert mutable.n_selected == 1, "Input choice must select exactly one or any in ENAS." + assert mutable.n_chosen == 1, "Input choice must select exactly one or any in ENAS." logit = query.view(1, -1) index = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1) - skip = F.one_hot(index).view(-1) + skip = F.one_hot(index, num_classes=mutable.n_candidates).view(-1) log_prob = self.cross_entropy_loss(logit, index) self._inputs = anchors[index.item()] self.sample_log_prob += torch.sum(log_prob) - entropy = (log_prob * torch.exp(-log_prob)).detach() + entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type self.sample_entropy += torch.sum(entropy) return skip.bool() - - def exit_mutable_scope(self, mutable_scope): - if self.cell_exit_extra_step: - self._lstm_next_step() - self._mark_anchor(mutable_scope.key) diff --git a/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py b/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py index 7d3e493782..1ed302ac7b 100644 --- a/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py @@ -2,7 +2,7 @@ import torch.optim as optim from nni.nas.pytorch.trainer import Trainer -from nni.nas.utils import AverageMeterGroup +from nni.nas.pytorch.utils import AverageMeterGroup from .mutator import EnasMutator @@ -12,9 +12,9 @@ def __init__(self, model, loss, metrics, reward_function, mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None, entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999, mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4): - super().__init__(model, loss, metrics, optimizer, num_epochs, - dataset_train, dataset_valid, batch_size, workers, device, log_frequency, - mutator if mutator is not None else EnasMutator(model), callbacks) + super().__init__(model, mutator if mutator is not None else EnasMutator(model), + loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid, + batch_size, workers, device, log_frequency, callbacks) self.reward_function = reward_function self.mutator_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr) @@ -52,8 +52,9 @@ def train_one_epoch(self, epoch): x, y = x.to(self.device), y.to(self.device) self.optimizer.zero_grad() - with self.mutator.forward_pass(): - logits = self.model(x) + with torch.no_grad(): + self.mutator.reset() + logits = self.model(x) if isinstance(logits, tuple): logits, aux_logits = logits @@ -81,7 +82,8 @@ def train_one_epoch(self, epoch): for step, (x, y) in enumerate(self.valid_loader): x, y = x.to(self.device), y.to(self.device) - with self.mutator.forward_pass(): + self.mutator.reset() + with torch.no_grad(): logits = self.model(x) metrics = self.metrics(logits, y) reward = self.reward_function(logits, y) @@ -107,9 +109,9 @@ def train_one_epoch(self, epoch): self.mutator_optim.zero_grad() if self.log_frequency is not None and step % self.log_frequency == 0: - print("Mutator Epoch [{}/{}] Step [{}/{}] {}".format(epoch, self.num_epochs, - mutator_step // self.mutator_steps_aggregate, - self.mutator_steps, meters)) + print("RL Epoch [{}/{}] Step [{}/{}] {}".format(epoch, self.num_epochs, + mutator_step // self.mutator_steps_aggregate, + self.mutator_steps, meters)) mutator_step += 1 if mutator_step >= total_mutator_steps: break diff --git a/src/sdk/pynni/nni/nas/pytorch/fixed.py b/src/sdk/pynni/nni/nas/pytorch/fixed.py index 526d66b610..6b83aa0800 100644 --- a/src/sdk/pynni/nni/nas/pytorch/fixed.py +++ b/src/sdk/pynni/nni/nas/pytorch/fixed.py @@ -2,10 +2,12 @@ import torch +from nni.nas.pytorch.mutables import MutableScope from nni.nas.pytorch.mutator import Mutator class FixedArchitecture(Mutator): + def __init__(self, model, fixed_arc, strict=True): """ Initialize a fixed architecture mutator. @@ -20,39 +22,57 @@ def __init__(self, model, fixed_arc, strict=True): Force everything that appears in `fixed_arc` to be used at least once. """ super().__init__(model) - if isinstance(fixed_arc, str): - with open(fixed_arc, "r") as f: - fixed_arc = json.load(f.read()) self._fixed_arc = fixed_arc - self._strict = strict - - def _encode_tensor(self, data): - if isinstance(data, list): - if all(map(lambda o: isinstance(o, bool), data)): - return torch.tensor(data, dtype=torch.bool) # pylint: disable=not-callable - else: - return torch.tensor(data, dtype=torch.float) # pylint: disable=not-callable - if isinstance(data, dict): - return {k: self._encode_tensor(v) for k, v in data.items()} - return data - - def before_pass(self): - self._unused_key = set(self._fixed_arc.keys()) - - def after_pass(self): - if self._strict: - if self._unused_key: - raise ValueError("{} are never used by the network. " - "Set strict=False if you want to disable this check.".format(self._unused_key)) - - def _check_key(self, key): - if key not in self._fixed_arc: - raise ValueError("\"{}\" is demanded by the network, but not found in saved architecture.".format(key)) - - def on_calc_layer_choice_mask(self, mutable): - self._check_key(mutable.key) - return self._fixed_arc[mutable.key] - - def on_calc_input_choice_mask(self, mutable, tags): - self._check_key(mutable.key) - return self._fixed_arc[mutable.key] + + mutable_keys = set([mutable.key for mutable in self.mutables if not isinstance(mutable, MutableScope)]) + fixed_arc_keys = set(self._fixed_arc.keys()) + if fixed_arc_keys - mutable_keys: + raise RuntimeError("Unexpected keys found in fixed architecture: {}.".format(fixed_arc_keys - mutable_keys)) + if mutable_keys - fixed_arc_keys: + raise RuntimeError("Missing keys in fixed architecture: {}.".format(mutable_keys - fixed_arc_keys)) + + def sample_search(self): + return self._fixed_arc + + def sample_final(self): + return self._fixed_arc + + +def _encode_tensor(data, device): + if isinstance(data, list): + if all(map(lambda o: isinstance(o, bool), data)): + return torch.tensor(data, dtype=torch.bool, device=device) # pylint: disable=not-callable + else: + return torch.tensor(data, dtype=torch.float, device=device) # pylint: disable=not-callable + if isinstance(data, dict): + return {k: _encode_tensor(v, device) for k, v in data.items()} + return data + + +def apply_fixed_architecture(model, fixed_arc_path, device=None): + """ + Load architecture from `fixed_arc_path` and apply to model. + + Parameters + ---------- + model: torch.nn.Module + Model with mutables. + fixed_arc_path: str + Path to the JSON that stores the architecture. + device: torch.device + Architecture weights will be transfered to `device`. + + Returns + ------- + FixedArchitecture + """ + + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if isinstance(fixed_arc_path, str): + with open(fixed_arc_path, "r") as f: + fixed_arc = json.load(f) + fixed_arc = _encode_tensor(fixed_arc, device) + architecture = FixedArchitecture(model, fixed_arc) + architecture.to(device) + architecture.reset() diff --git a/src/sdk/pynni/nni/nas/pytorch/mutables.py b/src/sdk/pynni/nni/nas/pytorch/mutables.py index 16b73b903d..79cde1cf3f 100644 --- a/src/sdk/pynni/nni/nas/pytorch/mutables.py +++ b/src/sdk/pynni/nni/nas/pytorch/mutables.py @@ -1,6 +1,6 @@ import torch.nn as nn -from nni.nas.utils import global_mutable_counting +from nni.nas.pytorch.utils import global_mutable_counting class Mutable(nn.Module): @@ -37,7 +37,7 @@ def set_mutator(self, mutator): self.__dict__["mutator"] = mutator def forward(self, *inputs): - raise NotImplementedError("Mutable forward must be implemented.") + raise NotImplementedError @property def key(self): @@ -51,9 +51,6 @@ def name(self): def name(self, name): self._name = name - def similar(self, other): - return type(self) == type(other) - def _check_built(self): if not hasattr(self, "mutator"): raise ValueError( @@ -66,19 +63,17 @@ def __repr__(self): class MutableScope(Mutable): """ - Mutable scope labels a subgraph to help mutators make better decisions. Mutators get notified when a mutable scope - is entered and exited. Mutators can override ``enter_mutable_scope`` and ``exit_mutable_scope`` to catch - corresponding events, and do status dump or update. + Mutable scope labels a subgraph/submodule to help mutators make better decisions. + Mutators get notified when a mutable scope is entered and exited. Mutators can override ``enter_mutable_scope`` + and ``exit_mutable_scope`` to catch corresponding events, and do status dump or update. """ def __init__(self, key): super().__init__(key=key) - def build(self): - self.mutator.on_init_mutable_scope(self) - def __call__(self, *args, **kwargs): try: + self._check_built() self.mutator.enter_mutable_scope(self) return super().__call__(*args, **kwargs) finally: @@ -93,43 +88,92 @@ def __init__(self, op_candidates, reduction="mean", return_mask=False, key=None) self.reduction = reduction self.return_mask = return_mask - def __len__(self): - return len(self.choices) - def forward(self, *inputs): out, mask = self.mutator.on_forward_layer_choice(self, *inputs) if self.return_mask: return out, mask return out - def similar(self, other): - return type(self) == type(other) and self.length == other.length - class InputChoice(Mutable): - def __init__(self, n_candidates, n_selected=None, reduction="mean", return_mask=False, key=None): + """ + Input choice selects `n_chosen` inputs from `choose_from` (contains `n_candidates` keys). For beginners, + use `n_candidates` instead of `choose_from` is a safe option. To get the most power out of it, you might want to + know about `choose_from`. + + The keys in `choose_from` can be keys that appear in past mutables, or ``NO_KEY`` if there are no suitable ones. + The keys are designed to be the keys of the sources. To help mutators make better decisions, + mutators might be interested in how the tensors to choose from come into place. For example, the tensor is the + output of some operator, some node, some cell, or some module. If this operator happens to be a mutable (e.g., + ``LayerChoice`` or ``InputChoice``), it has a key naturally that can be used as a source key. If it's a + module/submodule, it needs to be annotated with a key: that's where a ``MutableScope`` is needed. + """ + + NO_KEY = "" + + def __init__(self, n_candidates=None, choose_from=None, n_chosen=None, + reduction="mean", return_mask=False, key=None): + """ + Initialization. + + Parameters + ---------- + n_candidates: int + Number of inputs to choose from. + choose_from: list of str + List of source keys to choose from. At least of one of `choose_from` and `n_candidates` must be fulfilled. + If `n_candidates` has a value but `choose_from` is None, it will be automatically treated as `n_candidates` + number of empty string. + n_chosen: int + Recommended inputs to choose. If None, mutator is instructed to select any. + reduction: str + `mean`, `concat`, `sum` or `none`. + return_mask: bool + If `return_mask`, return output tensor and a mask. Otherwise return tensor only. + key: str + Key of the input choice. + """ super().__init__(key=key) + # precondition check + assert n_candidates is not None or choose_from is not None, "At least one of `n_candidates` and `choose_from`" \ + "must be not None." + if choose_from is not None and n_candidates is None: + n_candidates = len(choose_from) + elif choose_from is None and n_candidates is not None: + choose_from = [self.NO_KEY] * n_candidates + assert n_candidates == len(choose_from), "Number of candidates must be equal to the length of `choose_from`." assert n_candidates > 0, "Number of candidates must be greater than 0." + assert n_chosen is None or 0 <= n_chosen <= n_candidates, "Expected selected number must be None or no more " \ + "than number of candidates." + self.n_candidates = n_candidates - self.n_selected = n_selected + self.choose_from = choose_from + self.n_chosen = n_chosen self.reduction = reduction self.return_mask = return_mask - def build(self): - self.mutator.on_init_input_choice(self) - - def forward(self, optional_inputs, tags=None): + def forward(self, optional_inputs): + """ + Forward method of LayerChoice. + + Parameters + ---------- + optional_inputs: list or dict + Recommended to be a dict. As a dict, inputs will be converted to a list that follows the order of + `choose_from` in initialization. As a list, inputs must follow the semantic order that is the same as + `choose_from`. + + Returns + ------- + tuple of torch.Tensor and torch.Tensor or torch.Tensor + """ + optional_input_list = optional_inputs + if isinstance(optional_inputs, dict): + optional_input_list = [optional_inputs[tag] for tag in self.choose_from] + assert isinstance(optional_input_list, list), "Optional input list must be a list" assert len(optional_inputs) == self.n_candidates, \ "Length of the input list must be equal to number of candidates." - if tags is None: - tags = [""] * self.n_candidates - else: - assert len(tags) == self.n_candidates, "Length of tags must be equal to number of candidates." - out, mask = self.mutator.on_forward_input_choice(self, optional_inputs, tags) + out, mask = self.mutator.on_forward_input_choice(self, optional_input_list) if self.return_mask: return out, mask return out - - def similar(self, other): - return type(self) == type(other) and \ - self.n_candidates == other.n_candidates and self.n_selected and other.n_selected diff --git a/src/sdk/pynni/nni/nas/pytorch/mutator.py b/src/sdk/pynni/nni/nas/pytorch/mutator.py index 21d39545e7..80608c6925 100644 --- a/src/sdk/pynni/nni/nas/pytorch/mutator.py +++ b/src/sdk/pynni/nni/nas/pytorch/mutator.py @@ -1,46 +1,59 @@ -from contextlib import contextmanager - import torch -import torch.nn as nn from nni.nas.pytorch.base_mutator import BaseMutator -class Mutator(BaseMutator, nn.Module): +class Mutator(BaseMutator): - def export(self): - if self._in_forward_pass: - raise RuntimeError("Still in forward pass. Exporting might induce incompleteness.") - if not self._cache: - raise RuntimeError("No running history found. You need to call your model at least once before exporting. " - "You might also want to check if there are no valid mutables in your model.") - return self._cache - - @contextmanager - def forward_pass(self): - self._in_forward_pass = True + def __init__(self, model): + super().__init__(model) self._cache = dict() - self.before_pass() - try: - yield self - finally: - self.after_pass() - self._in_forward_pass = False - def before_pass(self): - pass + def sample_search(self): + """ + Override to implement this method to iterate over mutables and make decisions. + + Returns + ------- + dict + A mapping from key of mutables to decisions. + """ + raise NotImplementedError + + def sample_final(self): + """ + Override to implement this method to iterate over mutables and make decisions that is final + for export and retraining. - def after_pass(self): - pass + Returns + ------- + dict + A mapping from key of mutables to decisions. + """ + raise NotImplementedError - def _check_in_forward_pass(self): - if not hasattr(self, "_in_forward_pass") or not self._in_forward_pass: - raise ValueError("Not in forward pass. Did you forget to call mutator.forward_pass(), or forget to call " - "super().before_pass() and after_pass() in your override method?") + def reset(self): + """ + Reset the mutator by call the `sample_search` to resample (for search). + + Returns + ------- + None + """ + self._cache = self.sample_search() + + def export(self): + """ + Resample (for final) and return results. + + Returns + ------- + dict + """ + return self.sample_final() def on_forward_layer_choice(self, mutable, *inputs): """ - Callback of layer choice forward. Override if you are an advanced user. On default, this method calls :meth:`on_calc_layer_choice_mask` to get a mask on how to choose between layers (either by switch or by weights), then it will reduce the list of all tensor outputs with the policy specified in `mutable.reduction`. It will also cache the mask with corresponding `mutable.key`. @@ -54,18 +67,17 @@ def on_forward_layer_choice(self, mutable, *inputs): ------- tuple of torch.Tensor and torch.Tensor """ - self._check_in_forward_pass() def _map_fn(op, *inputs): return op(*inputs) - mask = self._cache.setdefault(mutable.key, self.on_calc_layer_choice_mask(mutable)) + mask = self._get_decision(mutable) + assert len(mask) == len(mutable.choices) out = self._select_with_mask(_map_fn, [(choice, *inputs) for choice in mutable.choices], mask) return self._tensor_reduction(mutable.reduction, out), mask - def on_forward_input_choice(self, mutable, tensor_list, tags): + def on_forward_input_choice(self, mutable, tensor_list): """ - Callback of input choice forward. Override if you are an advanced user. On default, this method calls :meth:`on_calc_input_choice_mask` with `tags` to get a mask on how to choose between inputs (either by switch or by weights), then it will reduce the list of all tensor outputs with the policy specified in `mutable.reduction`. It will also cache the @@ -81,48 +93,11 @@ def on_forward_input_choice(self, mutable, tensor_list, tags): ------- tuple of torch.Tensor and torch.Tensor """ - self._check_in_forward_pass() - mask = self._cache.setdefault(mutable.key, self.on_calc_input_choice_mask(mutable, tags)) + mask = self._get_decision(mutable) + assert len(mask) == mutable.n_candidates out = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask) return self._tensor_reduction(mutable.reduction, out), mask - def on_calc_layer_choice_mask(self, mutable): - """ - Recommended to override. Calculate a mask tensor for a layer choice. - - Parameters - ---------- - mutable: LayerChoice - Corresponding layer choice object. - - Returns - ------- - torch.Tensor - Should be a 1D tensor, either float or bool. If float, the numbers are treated as weights. If bool, - the numbers are treated as switch. - """ - raise NotImplementedError("Layer choice mask calculation must be implemented") - - def on_calc_input_choice_mask(self, mutable, tags): - """ - Recommended to override. Calculate a mask tensor for a input choice. - - Parameters - ---------- - mutable: InputChoice - Corresponding input choice object. - tags: list of string - The name of labels of input tensors given by user. Usually it's a - :class:`~nni.nas.pytorch.mutables.MutableScope` marked by user. - - Returns - ------- - torch.Tensor - Should be a 1D tensor, either float or bool. If float, the numbers are treated as weights. If bool, - the numbers are treated as switch. - """ - raise NotImplementedError("Input choice mask calculation must be implemented") - def _select_with_mask(self, map_fn, candidates, mask): if "BoolTensor" in mask.type(): out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m] @@ -146,3 +121,20 @@ def _tensor_reduction(self, reduction_type, tensor_list): if reduction_type == "concat": return torch.cat(tensor_list, dim=1) raise ValueError("Unrecognized reduction policy: \"{}\"".format(reduction_type)) + + def _get_decision(self, mutable): + """ + By default, this method checks whether `mutable.key` is already in the decision cache, + and returns the result without double-check. + + Parameters + ---------- + mutable: Mutable + + Returns + ------- + any + """ + if mutable.key not in self._cache: + raise ValueError("\"{}\" not found in decision cache.".format(mutable.key)) + return self._cache[mutable.key] diff --git a/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py b/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py index 6e385b1170..da31b3cc69 100644 --- a/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py +++ b/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py @@ -11,14 +11,14 @@ class PdartsMutator(DartsMutator): - def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches=None): + def __init__(self, pdarts_epoch_index, pdarts_num_to_drop, switches=None): self.pdarts_epoch_index = pdarts_epoch_index self.pdarts_num_to_drop = pdarts_num_to_drop self.switches = switches - super(PdartsMutator, self).__init__(model) + super(PdartsMutator, self).__init__() - def before_build(self, model): + def before_build(self): self.choices = nn.ParameterDict() if self.switches is None: self.switches = {} diff --git a/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py b/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py index 4d9c231143..d4ef2bbb8e 100644 --- a/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py @@ -35,8 +35,7 @@ def train(self): layers = self.layers+self.pdarts_num_layers[epoch] model, loss, model_optim, _ = self.model_creator( layers, n_nodes) - mutator = PdartsMutator( - model, epoch, self.pdarts_num_to_drop, switches) + mutator = PdartsMutator(model, epoch, self.pdarts_num_to_drop, switches) # pylint: disable=too-many-function-args self.trainer = DartsTrainer(model, loss=loss, optimizer=model_optim, mutator=mutator, **self.darts_parameters) diff --git a/src/sdk/pynni/nni/nas/pytorch/trainer.py b/src/sdk/pynni/nni/nas/pytorch/trainer.py index ab18e6c6e5..a4954a0747 100644 --- a/src/sdk/pynni/nni/nas/pytorch/trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/trainer.py @@ -1,24 +1,39 @@ +import json +import logging from abc import abstractmethod import torch from .base_trainer import BaseTrainer +_logger = logging.getLogger(__name__) + + +class TorchTensorEncoder(json.JSONEncoder): + def default(self, o): # pylint: disable=method-hidden + if isinstance(o, torch.Tensor): + olist = o.tolist() + if "bool" not in o.type().lower() and all(map(lambda d: d == 0 or d == 1, olist)): + _logger.warning("Every element in %s is either 0 or 1. " + "You might consider convert it into bool.", olist) + return olist + return super().default(o) + class Trainer(BaseTrainer): - def __init__(self, model, loss, metrics, optimizer, num_epochs, - dataset_train, dataset_valid, batch_size, workers, device, log_frequency, - mutator, callbacks): + def __init__(self, model, mutator, loss, metrics, optimizer, num_epochs, + dataset_train, dataset_valid, batch_size, workers, device, log_frequency, callbacks): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device self.model = model + self.mutator = mutator self.loss = loss + self.metrics = metrics self.optimizer = optimizer - self.mutator = mutator self.model.to(self.device) - self.loss.to(self.device) self.mutator.to(self.device) + self.loss.to(self.device) self.num_epochs = num_epochs self.dataset_train = dataset_train @@ -38,7 +53,7 @@ def train_one_epoch(self, epoch): def validate_one_epoch(self, epoch): pass - def _train(self, validate): + def train(self, validate=True): for epoch in range(self.num_epochs): for callback in self.callbacks: callback.on_epoch_begin(epoch) @@ -55,11 +70,13 @@ def _train(self, validate): for callback in self.callbacks: callback.on_epoch_end(epoch) - def train_and_validate(self): - self._train(True) - - def train(self): - self._train(False) - def validate(self): self.validate_one_epoch(-1) + + def export(self, file): + mutator_export = self.mutator.export() + with open(file, "w") as f: + json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder) + + def checkpoint(self): + raise NotImplementedError("Not implemented yet") diff --git a/src/sdk/pynni/nni/nas/pytorch/utils.py b/src/sdk/pynni/nni/nas/pytorch/utils.py new file mode 100644 index 0000000000..d3a4292155 --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/utils.py @@ -0,0 +1,107 @@ +from collections import OrderedDict + +_counter = 0 + + +def global_mutable_counting(): + global _counter + _counter += 1 + return _counter + + +class AverageMeterGroup: + + def __init__(self): + self.meters = OrderedDict() + + def update(self, data): + for k, v in data.items(): + if k not in self.meters: + self.meters[k] = AverageMeter(k, ":4f") + self.meters[k].update(v) + + def __str__(self): + return " ".join(str(v) for _, v in self.meters.items()) + + +class AverageMeter: + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) + + +class StructuredMutableTreeNode: + """ + A structured representation of a search space. + A search space comes with a root (with `None` stored in its `mutable`), and a bunch of children in its `children`. + This tree can be seen as a "flattened" version of the module tree. Since nested mutable entity is not supported yet, + the following must be true: each subtree corresponds to a ``MutableScope`` and each leaf corresponds to a + ``Mutable`` (other than ``MutableScope``). + """ + + def __init__(self, mutable): + self.mutable = mutable + self.children = [] + + def add_child(self, mutable): + self.children.append(StructuredMutableTreeNode(mutable)) + return self.children[-1] + + def type(self): + return type(self.mutable) + + def __iter__(self): + return self.traverse() + + def traverse(self, order="pre", deduplicate=True, memo=None): + """ + Return a generator that generates a list of mutables in this tree. + + Parameters + ---------- + order: str + pre or post. If pre, current mutable is yield before children. Otherwise after. + deduplicate: bool + If true, mutables with the same key will not appear after the first appearance. + memo: dict + An auxiliary variable to make deduplicate happen. + + Returns + ------- + generator of Mutable + """ + if memo is None: + memo = set() + assert order in ["pre", "post"] + if order == "pre": + if self.mutable is not None: + if not deduplicate or self.mutable.key not in memo: + memo.add(self.mutable.key) + yield self.mutable + for child in self.children: + for m in child.traverse(order=order, deduplicate=deduplicate, memo=memo): + yield m + if order == "post": + if self.mutable is not None: + if not deduplicate or self.mutable.key not in memo: + memo.add(self.mutable.key) + yield self.mutable From 6d6f952463fb2eac87f775a8432a2e353c59a1c2 Mon Sep 17 00:00:00 2001 From: Chi Song <27178119+squirrelsc@users.noreply.github.com> Date: Fri, 22 Nov 2019 11:40:03 +0800 Subject: [PATCH 08/10] pdarts update (#1753) --- docs/en_US/AdvancedFeature/MultiPhase.md | 2 +- docs/en_US/NAS/Overview.md | 139 +++++++------ docs/en_US/Tutorial/SearchSpaceSpec.md | 8 - docs/en_US/advanced.rst | 2 - examples/nas/darts/retrain.py | 14 +- examples/nas/darts/search.py | 17 +- examples/nas/enas/search.py | 13 ++ examples/nas/pdarts/datasets.py | 25 --- examples/nas/pdarts/main.py | 65 ------ examples/nas/pdarts/search.py | 69 +++++++ .../pynni/nni/nas/pytorch/darts/cnn_cell.py | 69 ------- .../nni/nas/pytorch/darts/cnn_network.py | 73 ------- .../pynni/nni/nas/pytorch/darts/cnn_ops.py | 189 ------------------ .../pynni/nni/nas/pytorch/darts/trainer.py | 11 +- src/sdk/pynni/nni/nas/pytorch/enas/trainer.py | 14 +- src/sdk/pynni/nni/nas/pytorch/modules.py | 9 - src/sdk/pynni/nni/nas/pytorch/mutables.py | 7 +- .../pynni/nni/nas/pytorch/pdarts/__init__.py | 3 + .../pynni/nni/nas/pytorch/pdarts/mutator.py | 50 ++--- .../pynni/nni/nas/pytorch/pdarts/trainer.py | 56 ++++-- src/sdk/pynni/nni/nas/pytorch/trainer.py | 5 +- src/sdk/pynni/nni/nas/utils.py | 49 ----- 22 files changed, 274 insertions(+), 615 deletions(-) delete mode 100644 examples/nas/pdarts/datasets.py delete mode 100644 examples/nas/pdarts/main.py create mode 100644 examples/nas/pdarts/search.py delete mode 100644 src/sdk/pynni/nni/nas/pytorch/darts/cnn_cell.py delete mode 100644 src/sdk/pynni/nni/nas/pytorch/darts/cnn_network.py delete mode 100644 src/sdk/pynni/nni/nas/pytorch/darts/cnn_ops.py delete mode 100644 src/sdk/pynni/nni/nas/pytorch/modules.py delete mode 100644 src/sdk/pynni/nni/nas/utils.py diff --git a/docs/en_US/AdvancedFeature/MultiPhase.md b/docs/en_US/AdvancedFeature/MultiPhase.md index c9727bcdcc..4cdb3a7a99 100644 --- a/docs/en_US/AdvancedFeature/MultiPhase.md +++ b/docs/en_US/AdvancedFeature/MultiPhase.md @@ -79,7 +79,7 @@ With this information, the tuner could know which trial is requesting a configur ### Tuners support multi-phase experiments: -[TPE](../Tuner/HyperoptTuner.md), [Random](../Tuner/HyperoptTuner.md), [Anneal](../Tuner/HyperoptTuner.md), [Evolution](../Tuner/EvolutionTuner.md), [SMAC](../Tuner/SmacTuner.md), [NetworkMorphism](../Tuner/NetworkmorphismTuner.md), [MetisTuner](../Tuner/MetisTuner.md), [BOHB](../Tuner/BohbAdvisor.md), [Hyperband](../Tuner/HyperbandAdvisor.md), [ENAS tuner](https://github.com/countif/enas_nni/blob/master/nni/examples/tuners/enas/nni_controller_ptb.py). +[TPE](../Tuner/HyperoptTuner.md), [Random](../Tuner/HyperoptTuner.md), [Anneal](../Tuner/HyperoptTuner.md), [Evolution](../Tuner/EvolutionTuner.md), [SMAC](../Tuner/SmacTuner.md), [NetworkMorphism](../Tuner/NetworkmorphismTuner.md), [MetisTuner](../Tuner/MetisTuner.md), [BOHB](../Tuner/BohbAdvisor.md), [Hyperband](../Tuner/HyperbandAdvisor.md). ### Training services support multi-phase experiment: [Local Machine](../TrainingService/LocalMode.md), [Remote Servers](../TrainingService/RemoteMachineMode.md), [OpenPAI](../TrainingService/PaiMode.md) diff --git a/docs/en_US/NAS/Overview.md b/docs/en_US/NAS/Overview.md index 92b06b413f..4e48483df3 100644 --- a/docs/en_US/NAS/Overview.md +++ b/docs/en_US/NAS/Overview.md @@ -1,62 +1,77 @@ -# Neural Architecture Search (NAS) on NNI - -Automatic neural architecture search is taking an increasingly important role on finding better models. Recent research works have proved the feasibility of automatic NAS, and also found some models that could beat manually designed and tuned models. Some of representative works are [NASNet][2], [ENAS][1], [DARTS][3], [Network Morphism][4], and [Evolution][5]. There are new innovations keeping emerging. - -However, it takes great efforts to implement NAS algorithms, and it is hard to reuse code base of existing algorithms in new one. To facilitate NAS innovations (e.g., design and implement new NAS models, compare different NAS models side-by-side), an easy-to-use and flexible programming interface is crucial. - -With this motivation, our ambition is to provide a unified architecture in NNI, to accelerate innovations on NAS, and apply state-of-art algorithms on real world problems faster. - -## Supported algorithms - -NNI supports below NAS algorithms now, and being adding more. User can reproduce an algorithm, or use it on owned dataset. we also encourage user to implement other algorithms with [NNI API](#use-nni-api), to benefit more people. - -Note, these algorithms run standalone without nnictl, and supports PyTorch only. - -### DARTS - -The main contribution of [DARTS: Differentiable Architecture Search][3] on algorithm is to introduce a novel algorithm for differentiable network architecture search on bilevel optimization. - -#### Usage - -```bash -### In case NNI code is not cloned. -git clone https://github.com/Microsoft/nni.git - -cd examples/nas/darts -python search.py -``` - -### P-DARTS - -[Progressive Differentiable Architecture Search: Bridging the Depth Gap between Search and Evaluation](https://arxiv.org/abs/1904.12760) bases on DARTS(#DARTS). It main contribution on algorithm is to introduce an efficient algorithm which allows the depth of searched architectures to grow gradually during the training procedure. - -#### Usage - -```bash -### In case NNI code is not cloned. -git clone https://github.com/Microsoft/nni.git - -cd examples/nas/pdarts -python main.py -``` - -## Use NNI API - -NOTE, we are trying to support various NAS algorithms with unified programming interface, and it's in very experimental stage. It means the current programing interface may be updated significantly. - -*previous [NAS annotation](../AdvancedFeature/GeneralNasInterfaces.md) interface will be deprecated soon.* - -### Programming interface - -The programming interface of designing and searching a model is often demanded in two scenarios. - -1. When designing a neural network, there may be multiple operation choices on a layer, sub-model, or connection, and it's undetermined which one or combination performs best. So it needs an easy way to express the candidate layers or sub-models. -2. When applying NAS on a neural network, it needs an unified way to express the search space of architectures, so that it doesn't need to update trial code for different searching algorithms. - -NNI proposed API is [here](https://github.com/microsoft/nni/tree/dev-nas-refactor/src/sdk/pynni/nni/nas/pytorch). And [here](https://github.com/microsoft/nni/tree/dev-nas-refactor/examples/nas/darts) is an example of NAS implementation, which bases on NNI proposed interface. - -[1]: https://arxiv.org/abs/1802.03268 -[2]: https://arxiv.org/abs/1707.07012 -[3]: https://arxiv.org/abs/1806.09055 -[4]: https://arxiv.org/abs/1806.10282 -[5]: https://arxiv.org/abs/1703.01041 +# Neural Architecture Search (NAS) on NNI + +Automatic neural architecture search is taking an increasingly important role on finding better models. Recent research works have proved the feasibility of automatic NAS, and also found some models that could beat manually designed and tuned models. Some of representative works are [NASNet][2], [ENAS][1], [DARTS][3], [Network Morphism][4], and [Evolution][5]. There are new innovations keeping emerging. + +However, it takes great efforts to implement NAS algorithms, and it is hard to reuse code base of existing algorithms in new one. To facilitate NAS innovations (e.g., design and implement new NAS models, compare different NAS models side-by-side), an easy-to-use and flexible programming interface is crucial. + +With this motivation, our ambition is to provide a unified architecture in NNI, to accelerate innovations on NAS, and apply state-of-art algorithms on real world problems faster. + +## Supported algorithms + +NNI supports below NAS algorithms now and being adding more. User can reproduce an algorithm or use it on owned dataset. we also encourage user to implement other algorithms with [NNI API](#use-nni-api), to benefit more people. + +Note, these algorithms run standalone without nnictl, and supports PyTorch only. + +### Dependencies + +* Install latest NNI +* PyTorch 1.2+ +* git + +### DARTS + +The main contribution of [DARTS: Differentiable Architecture Search][3] on algorithm is to introduce a novel algorithm for differentiable network architecture search on bilevel optimization. + +#### Usage + +```bash +# In case NNI code is not cloned. If the code is cloned already, ignore this line and enter code folder. +git clone https://github.com/Microsoft/nni.git + +# search the best architecture +cd examples/nas/darts +python3 search.py + +# train the best architecture +python3 retrain.py --arc-checkpoint ./checkpoints/epoch_49.json +``` + +### P-DARTS + +[Progressive Differentiable Architecture Search: Bridging the Depth Gap between Search and Evaluation](https://arxiv.org/abs/1904.12760) bases on [DARTS](#DARTS). It's contribution on algorithm is to introduce an efficient algorithm which allows the depth of searched architectures to grow gradually during the training procedure. + +#### Usage + +```bash +# In case NNI code is not cloned. If the code is cloned already, ignore this line and enter code folder. +git clone https://github.com/Microsoft/nni.git + +# search the best architecture +cd examples/nas/pdarts +python3 search.py + +# train the best architecture, it's the same progress as darts. +cd examples/nas/darts +python3 retrain.py --arc-checkpoint ./checkpoints/epoch_2.json +``` + +## Use NNI API + +NOTE, we are trying to support various NAS algorithms with unified programming interface, and it's in very experimental stage. It means the current programing interface may be updated significantly. + +*previous [NAS annotation](../AdvancedFeature/GeneralNasInterfaces.md) interface will be deprecated soon.* + +### Programming interface + +The programming interface of designing and searching a model is often demanded in two scenarios. + +1. When designing a neural network, there may be multiple operation choices on a layer, sub-model, or connection, and it's undetermined which one or combination performs best. So, it needs an easy way to express the candidate layers or sub-models. +2. When applying NAS on a neural network, it needs an unified way to express the search space of architectures, so that it doesn't need to update trial code for different searching algorithms. + +NNI proposed API is [here](https://github.com/microsoft/nni/tree/master/src/sdk/pynni/nni/nas/pytorch). And [here](https://github.com/microsoft/nni/tree/master/examples/nas/darts) is an example of NAS implementation, which bases on NNI proposed interface. + +[1]: https://arxiv.org/abs/1802.03268 +[2]: https://arxiv.org/abs/1707.07012 +[3]: https://arxiv.org/abs/1806.09055 +[4]: https://arxiv.org/abs/1806.10282 +[5]: https://arxiv.org/abs/1703.01041 diff --git a/docs/en_US/Tutorial/SearchSpaceSpec.md b/docs/en_US/Tutorial/SearchSpaceSpec.md index fd1781716f..eb5d39315c 100644 --- a/docs/en_US/Tutorial/SearchSpaceSpec.md +++ b/docs/en_US/Tutorial/SearchSpaceSpec.md @@ -73,12 +73,6 @@ All types of sampling strategies and their parameter are listed here: * Which means the variable value is a value like `round(exp(normal(mu, sigma)) / q) * q` * Suitable for a discrete variable with respect to which the objective is smooth and gets smoother with the size of the variable, which is bounded from one side. -* `{"_type": "mutable_layer", "_value": {mutable_layer_infomation}}` - * Type for [Neural Architecture Search Space][1]. Value is also a dictionary, which contains key-value pairs representing respectively name and search space of each mutable_layer. - * For now, users can only use this type of search space with annotation, which means that there is no need to define a json file for search space since it will be automatically generated according to the annotation in trial code. - * The following HPO tuners can be adapted to tune this search space: TPE, Random, Anneal, Evolution, Grid Search, - Hyperband and BOHB. - * For detailed usage, please refer to [General NAS Interfaces][1]. ## Search Space Types Supported by Each Tuner @@ -105,5 +99,3 @@ Known Limitations: * Only Random Search/TPE/Anneal/Evolution tuner supports nested search space * We do not support nested search space "Hyper Parameter" in visualization now, the enhancement is being considered in [#1110](https://github.com/microsoft/nni/issues/1110), any suggestions or discussions or contributions are warmly welcomed - -[1]: ../AdvancedFeature/GeneralNasInterfaces.md diff --git a/docs/en_US/advanced.rst b/docs/en_US/advanced.rst index d9192cc869..e38f634969 100644 --- a/docs/en_US/advanced.rst +++ b/docs/en_US/advanced.rst @@ -3,5 +3,3 @@ Advanced Features .. toctree:: MultiPhase<./AdvancedFeature/MultiPhase> - AdvancedNas<./AdvancedFeature/AdvancedNas> - NAS Programming Interface<./AdvancedFeature/GeneralNasInterfaces> \ No newline at end of file diff --git a/examples/nas/darts/retrain.py b/examples/nas/darts/retrain.py index 5c8fabf8d0..e3167376f9 100644 --- a/examples/nas/darts/retrain.py +++ b/examples/nas/darts/retrain.py @@ -1,4 +1,5 @@ import logging +import time from argparse import ArgumentParser import torch @@ -10,8 +11,17 @@ from nni.nas.pytorch.fixed import apply_fixed_architecture from nni.nas.pytorch.utils import AverageMeter -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) +logger = logging.getLogger() + +fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s' +logging.Formatter.converter = time.localtime +formatter = logging.Formatter(fmt, '%m/%d/%Y, %I:%M:%S %p') + +std_out_info = logging.StreamHandler() +std_out_info.setFormatter(formatter) +logger.setLevel(logging.INFO) +logger.addHandler(std_out_info) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/examples/nas/darts/search.py b/examples/nas/darts/search.py index 02c720a60c..d9bdf0c7b5 100644 --- a/examples/nas/darts/search.py +++ b/examples/nas/darts/search.py @@ -1,14 +1,27 @@ +import logging +import time from argparse import ArgumentParser -import datasets import torch import torch.nn as nn +import datasets from model import CNN -from nni.nas.pytorch.callbacks import LearningRateScheduler, ArchitectureCheckpoint +from nni.nas.pytorch.callbacks import (ArchitectureCheckpoint, + LearningRateScheduler) from nni.nas.pytorch.darts import DartsTrainer from utils import accuracy +logger = logging.getLogger() + +fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s' +logging.Formatter.converter = time.localtime +formatter = logging.Formatter(fmt, '%m/%d/%Y, %I:%M:%S %p') + +std_out_info = logging.StreamHandler() +std_out_info.setFormatter(formatter) +logger.setLevel(logging.INFO) +logger.addHandler(std_out_info) if __name__ == "__main__": parser = ArgumentParser("darts") diff --git a/examples/nas/enas/search.py b/examples/nas/enas/search.py index 35bc930333..6fade75164 100644 --- a/examples/nas/enas/search.py +++ b/examples/nas/enas/search.py @@ -1,3 +1,5 @@ +import logging +import time from argparse import ArgumentParser import torch @@ -10,6 +12,17 @@ from nni.nas.pytorch.callbacks import LearningRateScheduler, ArchitectureCheckpoint from utils import accuracy, reward_accuracy +logger = logging.getLogger() + +fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s' +logging.Formatter.converter = time.localtime +formatter = logging.Formatter(fmt, '%m/%d/%Y, %I:%M:%S %p') + +std_out_info = logging.StreamHandler() +std_out_info.setFormatter(formatter) +logger.setLevel(logging.INFO) +logger.addHandler(std_out_info) + if __name__ == "__main__": parser = ArgumentParser("enas") parser.add_argument("--batch-size", default=128, type=int) diff --git a/examples/nas/pdarts/datasets.py b/examples/nas/pdarts/datasets.py deleted file mode 100644 index 8fe0ab0fbf..0000000000 --- a/examples/nas/pdarts/datasets.py +++ /dev/null @@ -1,25 +0,0 @@ -from torchvision import transforms -from torchvision.datasets import CIFAR10 - - -def get_dataset(cls): - MEAN = [0.49139968, 0.48215827, 0.44653124] - STD = [0.24703233, 0.24348505, 0.26158768] - transf = [ - transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip() - ] - normalize = [ - transforms.ToTensor(), - transforms.Normalize(MEAN, STD) - ] - - train_transform = transforms.Compose(transf + normalize) - valid_transform = transforms.Compose(normalize) - - if cls == "cifar10": - dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform) - dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform) - else: - raise NotImplementedError - return dataset_train, dataset_valid diff --git a/examples/nas/pdarts/main.py b/examples/nas/pdarts/main.py deleted file mode 100644 index 68a59c8856..0000000000 --- a/examples/nas/pdarts/main.py +++ /dev/null @@ -1,65 +0,0 @@ -from argparse import ArgumentParser - -import datasets -import torch -import torch.nn as nn -import nni.nas.pytorch as nas -from nni.nas.pytorch.pdarts import PdartsTrainer -from nni.nas.pytorch.darts import CnnNetwork, CnnCell - - -def accuracy(output, target, topk=(1,)): - """ Computes the precision@k for the specified values of k """ - maxk = max(topk) - batch_size = target.size(0) - - _, pred = output.topk(maxk, 1, True, True) - pred = pred.t() - # one-hot case - if target.ndimension() > 1: - target = target.max(1)[1] - - correct = pred.eq(target.view(1, -1).expand_as(pred)) - - res = dict() - for k in topk: - correct_k = correct[:k].view(-1).float().sum(0) - res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item() - return res - - -if __name__ == "__main__": - parser = ArgumentParser("darts") - parser.add_argument("--layers", default=5, type=int) - parser.add_argument('--add_layers', action='append', - default=[0, 6, 12], help='add layers') - parser.add_argument("--nodes", default=4, type=int) - parser.add_argument("--batch-size", default=128, type=int) - parser.add_argument("--log-frequency", default=1, type=int) - args = parser.parse_args() - - dataset_train, dataset_valid = datasets.get_dataset("cifar10") - - def model_creator(layers, n_nodes): - model = CnnNetwork(3, 16, 10, layers, n_nodes=n_nodes, cell_type=CnnCell) - loss = nn.CrossEntropyLoss() - - model_optim = torch.optim.SGD(model.parameters(), 0.025, - momentum=0.9, weight_decay=3.0E-4) - n_epochs = 50 - lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(model_optim, n_epochs, eta_min=0.001) - return model, loss, model_optim, lr_scheduler - - trainer = PdartsTrainer(model_creator, - metrics=lambda output, target: accuracy(output, target, topk=(1,)), - num_epochs=50, - pdarts_num_layers=[0, 6, 12], - pdarts_num_to_drop=[3, 2, 2], - dataset_train=dataset_train, - dataset_valid=dataset_valid, - layers=args.layers, - n_nodes=args.nodes, - batch_size=args.batch_size, - log_frequency=args.log_frequency) - trainer.train() - trainer.export() diff --git a/examples/nas/pdarts/search.py b/examples/nas/pdarts/search.py new file mode 100644 index 0000000000..5d38fda0db --- /dev/null +++ b/examples/nas/pdarts/search.py @@ -0,0 +1,69 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import logging +import sys +import time +from argparse import ArgumentParser + +import torch +import torch.nn as nn + +from nni.nas.pytorch.callbacks import ArchitectureCheckpoint +from nni.nas.pytorch.pdarts import PdartsTrainer + +# prevent it to be reordered. +if True: + sys.path.append('../darts') + from utils import accuracy + from model import CNN + import datasets + +logger = logging.getLogger() + +fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s' +logging.Formatter.converter = time.localtime +formatter = logging.Formatter(fmt, '%m/%d/%Y, %I:%M:%S %p') + +std_out_info = logging.StreamHandler() +std_out_info.setFormatter(formatter) +logger.setLevel(logging.INFO) +logger.addHandler(std_out_info) + +if __name__ == "__main__": + parser = ArgumentParser("pdarts") + parser.add_argument('--add_layers', action='append', + default=[0, 6, 12], help='add layers') + parser.add_argument("--nodes", default=4, type=int) + parser.add_argument("--layers", default=5, type=int) + parser.add_argument("--batch-size", default=64, type=int) + parser.add_argument("--log-frequency", default=1, type=int) + parser.add_argument("--epochs", default=50, type=int) + args = parser.parse_args() + + logger.info("loading data") + dataset_train, dataset_valid = datasets.get_dataset("cifar10") + + def model_creator(layers): + model = CNN(32, 3, 16, 10, layers, n_nodes=args.nodes) + criterion = nn.CrossEntropyLoss() + + optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4) + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001) + + return model, criterion, optim, lr_scheduler + + logger.info("initializing trainer") + trainer = PdartsTrainer(model_creator, + layers=args.layers, + metrics=lambda output, target: accuracy(output, target, topk=(1,)), + pdarts_num_layers=[0, 6, 12], + pdarts_num_to_drop=[3, 2, 2], + num_epochs=args.epochs, + dataset_train=dataset_train, + dataset_valid=dataset_valid, + batch_size=args.batch_size, + log_frequency=args.log_frequency, + callbacks=[ArchitectureCheckpoint("./checkpoints")]) + logger.info("training") + trainer.train() diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/cnn_cell.py b/src/sdk/pynni/nni/nas/pytorch/darts/cnn_cell.py deleted file mode 100644 index 69dc28e8f0..0000000000 --- a/src/sdk/pynni/nni/nas/pytorch/darts/cnn_cell.py +++ /dev/null @@ -1,69 +0,0 @@ - -import torch -import torch.nn as nn - -import nni.nas.pytorch as nas -from nni.nas.pytorch.modules import RankedModule - -from .cnn_ops import OPS, PRIMITIVES, FactorizedReduce, StdConv - - -class CnnCell(RankedModule): - """ - Cell for search. - """ - - def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction): - """ - Initialization a search cell. - - Parameters - ---------- - n_nodes: int - Number of nodes in current DAG. - channels_pp: int - Number of output channels from previous previous cell. - channels_p: int - Number of output channels from previous cell. - channels: int - Number of channels that will be used in the current DAG. - reduction_p: bool - Flag for whether the previous cell is reduction cell or not. - reduction: bool - Flag for whether the current cell is reduction cell or not. - """ - super(CnnCell, self).__init__(rank=1, reduction=reduction) - self.n_nodes = n_nodes - - # If previous cell is reduction cell, current input size does not match with - # output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing. - if reduction_p: - self.preproc0 = FactorizedReduce(channels_pp, channels, affine=False) - else: - self.preproc0 = StdConv(channels_pp, channels, 1, 1, 0, affine=False) - self.preproc1 = StdConv(channels_p, channels, 1, 1, 0, affine=False) - - # generate dag - self.mutable_ops = nn.ModuleList() - for depth in range(self.n_nodes): - self.mutable_ops.append(nn.ModuleList()) - for i in range(2 + depth): # include 2 input nodes - # reduction should be used only for input node - stride = 2 if reduction and i < 2 else 1 - m_ops = [] - for primitive in PRIMITIVES: - op = OPS[primitive](channels, stride, False) - m_ops.append(op) - op = nas.mutables.LayerChoice(m_ops, key="r{}_d{}_i{}".format(reduction, depth, i)) - self.mutable_ops[depth].append(op) - - def forward(self, s0, s1): - # s0, s1 are the outputs of previous previous cell and previous cell, respectively. - tensors = [self.preproc0(s0), self.preproc1(s1)] - for ops in self.mutable_ops: - assert len(ops) == len(tensors) - cur_tensor = sum(op(tensor) for op, tensor in zip(ops, tensors)) - tensors.append(cur_tensor) - - output = torch.cat(tensors[2:], dim=1) - return output diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/cnn_network.py b/src/sdk/pynni/nni/nas/pytorch/darts/cnn_network.py deleted file mode 100644 index d126e3353e..0000000000 --- a/src/sdk/pynni/nni/nas/pytorch/darts/cnn_network.py +++ /dev/null @@ -1,73 +0,0 @@ - -import torch.nn as nn - -from .cnn_cell import CnnCell - - -class CnnNetwork(nn.Module): - """ - Search CNN model - """ - - def __init__(self, in_channels, channels, n_classes, n_layers, n_nodes=4, stem_multiplier=3, cell_type=CnnCell): - """ - Initializing a search channelsNN. - - Parameters - ---------- - in_channels: int - Number of channels in images. - channels: int - Number of channels used in the network. - n_classes: int - Number of classes. - n_layers: int - Number of cells in the whole network. - n_nodes: int - Number of nodes in a cell. - stem_multiplier: int - Multiplier of channels in STEM. - """ - super().__init__() - self.in_channels = in_channels - self.channels = channels - self.n_classes = n_classes - self.n_layers = n_layers - - c_cur = stem_multiplier * self.channels - self.stem = nn.Sequential( - nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False), - nn.BatchNorm2d(c_cur) - ) - - # for the first cell, stem is used for both s0 and s1 - # [!] channels_pp and channels_p is output channel size, but c_cur is input channel size. - channels_pp, channels_p, c_cur = c_cur, c_cur, channels - - self.cells = nn.ModuleList() - reduction_p, reduction = False, False - for i in range(n_layers): - reduction_p, reduction = reduction, False - # Reduce featuremap size and double channels in 1/3 and 2/3 layer. - if i in [n_layers // 3, 2 * n_layers // 3]: - c_cur *= 2 - reduction = True - - cell = cell_type(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction) - self.cells.append(cell) - c_cur_out = c_cur * n_nodes - channels_pp, channels_p = channels_p, c_cur_out - - self.gap = nn.AdaptiveAvgPool2d(1) - self.linear = nn.Linear(channels_p, n_classes) - - def forward(self, x): - s0 = s1 = self.stem(x) - - for cell in self.cells: - s0, s1 = s1, cell(s0, s1) - - out = self.gap(s1) - out = out.view(out.size(0), -1) # flatten - logits = self.linear(out) - return logits diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/cnn_ops.py b/src/sdk/pynni/nni/nas/pytorch/darts/cnn_ops.py deleted file mode 100644 index 02b4a3a94c..0000000000 --- a/src/sdk/pynni/nni/nas/pytorch/darts/cnn_ops.py +++ /dev/null @@ -1,189 +0,0 @@ -import torch -import torch.nn as nn - -PRIMITIVES = [ - 'none', - 'max_pool_3x3', - 'avg_pool_3x3', - 'skip_connect', # identity - 'sep_conv_3x3', - 'sep_conv_5x5', - 'dil_conv_3x3', - 'dil_conv_5x5', -] - -OPS = { - 'none': lambda C, stride, affine: Zero(stride), - 'avg_pool_3x3': lambda C, stride, affine: PoolBN('avg', C, 3, stride, 1, affine=affine), - 'max_pool_3x3': lambda C, stride, affine: PoolBN('max', C, 3, stride, 1, affine=affine), - 'skip_connect': lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine), - 'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine), - 'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine), - 'sep_conv_7x7': lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine), - 'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine), # 5x5 - 'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine), # 9x9 - 'conv_7x1_1x7': lambda C, stride, affine: FacConv(C, C, 7, stride, 3, affine=affine) -} - - -def drop_path_(x, drop_prob, training): - if training and drop_prob > 0.: - keep_prob = 1. - drop_prob - # per data point mask; assuming x in cuda. - mask = torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob) - x.div_(keep_prob).mul_(mask) - - return x - - -class DropPath_(nn.Module): - def __init__(self, p=0.): - """ [!] DropPath is inplace module - Args: - p: probability of an path to be zeroed. - """ - super().__init__() - self.p = p - - def extra_repr(self): - return 'p={}, inplace'.format(self.p) - - def forward(self, x): - drop_path_(x, self.p, self.training) - - return x - - -class PoolBN(nn.Module): - """ - AvgPool or MaxPool - BN - """ - - def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True): - """ - Args: - pool_type: 'max' or 'avg' - """ - super().__init__() - if pool_type.lower() == 'max': - self.pool = nn.MaxPool2d(kernel_size, stride, padding) - elif pool_type.lower() == 'avg': - self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False) - else: - raise ValueError() - - self.bn = nn.BatchNorm2d(C, affine=affine) - - def forward(self, x): - out = self.pool(x) - out = self.bn(out) - return out - - -class StdConv(nn.Module): - """ Standard conv - ReLU - Conv - BN - """ - - def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): - super().__init__() - self.net = nn.Sequential( - nn.ReLU(), - nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False), - nn.BatchNorm2d(C_out, affine=affine) - ) - - def forward(self, x): - return self.net(x) - - -class FacConv(nn.Module): - """ Factorized conv - ReLU - Conv(Kx1) - Conv(1xK) - BN - """ - - def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True): - super().__init__() - self.net = nn.Sequential( - nn.ReLU(), - nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False), - nn.Conv2d(C_in, C_out, (1, kernel_length), stride, padding, bias=False), - nn.BatchNorm2d(C_out, affine=affine) - ) - - def forward(self, x): - return self.net(x) - - -class DilConv(nn.Module): - """ (Dilated) depthwise separable conv - ReLU - (Dilated) depthwise separable - Pointwise - BN - If dilation == 2, 3x3 conv => 5x5 receptive field - 5x5 conv => 9x9 receptive field - """ - - def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True): - super().__init__() - self.net = nn.Sequential( - nn.ReLU(), - nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in, bias=False), - nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False), - nn.BatchNorm2d(C_out, affine=affine) - ) - - def forward(self, x): - return self.net(x) - - -class SepConv(nn.Module): - """ Depthwise separable conv - DilConv(dilation=1) * 2 - """ - - def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): - super().__init__() - self.net = nn.Sequential( - DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine), - DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine) - ) - - def forward(self, x): - return self.net(x) - - -class Identity(nn.Module): - - def forward(self, x): - return x - - -class Zero(nn.Module): - def __init__(self, stride): - super().__init__() - self.stride = stride - - def forward(self, x): - if self.stride == 1: - return x * 0. - - # re-sizing by stride - return x[:, :, ::self.stride, ::self.stride] * 0. - - -class FactorizedReduce(nn.Module): - """ - Reduce feature map size by factorized pointwise(stride=2). - """ - - def __init__(self, C_in, C_out, affine=True): - super().__init__() - self.relu = nn.ReLU() - self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) - self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) - self.bn = nn.BatchNorm2d(C_out, affine=affine) - - def forward(self, x): - x = self.relu(x) - out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1) - out = self.bn(out) - return out diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py b/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py index c6b29de04a..6392962111 100644 --- a/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py @@ -1,12 +1,17 @@ import copy +import logging import torch from torch import nn as nn from nni.nas.pytorch.trainer import Trainer from nni.nas.pytorch.utils import AverageMeterGroup + from .mutator import DartsMutator +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + class DartsTrainer(Trainer): def __init__(self, model, loss, metrics, @@ -72,7 +77,8 @@ def train_one_epoch(self, epoch): metrics["loss"] = loss.item() meters.update(metrics) if self.log_frequency is not None and step % self.log_frequency == 0: - print("Epoch [{}/{}] Step [{}/{}] {}".format(epoch, self.num_epochs, step, len(self.train_loader), meters)) + logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch+1, + self.num_epochs, step+1, len(self.train_loader), meters) def validate_one_epoch(self, epoch): self.model.eval() @@ -86,7 +92,8 @@ def validate_one_epoch(self, epoch): metrics = self.metrics(logits, y) meters.update(metrics) if self.log_frequency is not None and step % self.log_frequency == 0: - print("Epoch [{}/{}] Step [{}/{}] {}".format(epoch, self.num_epochs, step, len(self.valid_loader), meters)) + logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch+1, + self.num_epochs, step+1, len(self.test_loader), meters) def _unrolled_backward(self, trn_X, trn_y, val_X, val_y, backup_model, lr): """ diff --git a/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py b/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py index 1ed302ac7b..49052d6b08 100644 --- a/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py @@ -1,3 +1,4 @@ +import logging import torch import torch.optim as optim @@ -6,6 +7,10 @@ from .mutator import EnasMutator +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + class EnasTrainer(Trainer): def __init__(self, model, loss, metrics, reward_function, optimizer, num_epochs, dataset_train, dataset_valid, @@ -70,8 +75,8 @@ def train_one_epoch(self, epoch): meters.update(metrics) if self.log_frequency is not None and step % self.log_frequency == 0: - print("Model Epoch [{}/{}] Step [{}/{}] {}".format(epoch, self.num_epochs, - step, len(self.train_loader), meters)) + logger.info("Model Epoch [%s/%s] Step [%s/%s] %s", epoch, + self.num_epochs, step, len(self.train_loader), meters) # Train sampler (mutator) self.model.eval() @@ -109,9 +114,8 @@ def train_one_epoch(self, epoch): self.mutator_optim.zero_grad() if self.log_frequency is not None and step % self.log_frequency == 0: - print("RL Epoch [{}/{}] Step [{}/{}] {}".format(epoch, self.num_epochs, - mutator_step // self.mutator_steps_aggregate, - self.mutator_steps, meters)) + logger.info("RL Epoch [%s/%s] Step [%s/%s] %s", epoch, self.num_epochs, + mutator_step // self.mutator_steps_aggregate, self.mutator_steps, meters) mutator_step += 1 if mutator_step >= total_mutator_steps: break diff --git a/src/sdk/pynni/nni/nas/pytorch/modules.py b/src/sdk/pynni/nni/nas/pytorch/modules.py deleted file mode 100644 index 6570220e13..0000000000 --- a/src/sdk/pynni/nni/nas/pytorch/modules.py +++ /dev/null @@ -1,9 +0,0 @@ - -from torch import nn as nn - - -class RankedModule(nn.Module): - def __init__(self, rank=None, reduction=False): - super(RankedModule, self).__init__() - self.rank = rank - self.reduction = reduction diff --git a/src/sdk/pynni/nni/nas/pytorch/mutables.py b/src/sdk/pynni/nni/nas/pytorch/mutables.py index 79cde1cf3f..4dbf514af8 100644 --- a/src/sdk/pynni/nni/nas/pytorch/mutables.py +++ b/src/sdk/pynni/nni/nas/pytorch/mutables.py @@ -1,7 +1,12 @@ +import logging + import torch.nn as nn from nni.nas.pytorch.utils import global_mutable_counting +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + class Mutable(nn.Module): """ @@ -20,7 +25,7 @@ def __init__(self, key=None): if key is not None: if not isinstance(key, str): key = str(key) - print("Warning: key \"{}\" is not string, converted to string.".format(key)) + logger.warning("Warning: key \"%s\" is not string, converted to string.", key) self._key = key else: self._key = self.__class__.__name__ + str(global_mutable_counting()) diff --git a/src/sdk/pynni/nni/nas/pytorch/pdarts/__init__.py b/src/sdk/pynni/nni/nas/pytorch/pdarts/__init__.py index 27dd912ab3..d1d17764ba 100644 --- a/src/sdk/pynni/nni/nas/pytorch/pdarts/__init__.py +++ b/src/sdk/pynni/nni/nas/pytorch/pdarts/__init__.py @@ -1 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + from .trainer import PdartsTrainer diff --git a/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py b/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py index da31b3cc69..5862e9714b 100644 --- a/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py +++ b/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py @@ -1,8 +1,9 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + import copy import numpy as np -import torch -from torch import nn as nn from torch.nn import functional as F from nni.nas.pytorch.darts import DartsMutator @@ -11,24 +12,27 @@ class PdartsMutator(DartsMutator): - def __init__(self, pdarts_epoch_index, pdarts_num_to_drop, switches=None): + def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches={}): self.pdarts_epoch_index = pdarts_epoch_index self.pdarts_num_to_drop = pdarts_num_to_drop - self.switches = switches + if switches is None: + self.switches = {} + else: + self.switches = switches - super(PdartsMutator, self).__init__() + super(PdartsMutator, self).__init__(model) - def before_build(self): - self.choices = nn.ParameterDict() - if self.switches is None: - self.switches = {} + for mutable in self.mutables: + if isinstance(mutable, LayerChoice): + + switches = self.switches.get(mutable.key, [True for j in range(mutable.length)]) + + for index in range(len(switches)-1, -1, -1): + if switches[index] == False: + del(mutable.choices[index]) + mutable.length -= 1 - def named_mutables(self, model): - key2module = dict() - for name, module in model.named_modules(): - if isinstance(module, LayerChoice): - key2module[module.key] = module - yield name, module, True + self.switches[mutable.key] = switches def drop_paths(self): for key in self.switches: @@ -49,22 +53,6 @@ def drop_paths(self): switches[idxs[idx]] = False return self.switches - def on_init_layer_choice(self, mutable: LayerChoice): - switches = self.switches.get( - mutable.key, [True for j in range(mutable.length)]) - - for index in range(len(switches)-1, -1, -1): - if switches[index] == False: - del(mutable.choices[index]) - mutable.length -= 1 - - self.switches[mutable.key] = switches - - self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length)) - - def on_calc_layer_choice_mask(self, mutable: LayerChoice): - return F.softmax(self.choices[mutable.key], dim=-1) - def get_min_k(self, input_in, k): index = [] for _ in range(k): diff --git a/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py b/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py index d4ef2bbb8e..af31da08fc 100644 --- a/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py @@ -1,17 +1,25 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import logging +from nni.nas.pytorch.callbacks import LearningRateScheduler from nni.nas.pytorch.darts import DartsTrainer -from nni.nas.pytorch.trainer import Trainer +from nni.nas.pytorch.trainer import BaseTrainer from .mutator import PdartsMutator +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) -class PdartsTrainer(Trainer): - def __init__(self, model_creator, metrics, num_epochs, dataset_train, dataset_valid, - layers=5, n_nodes=4, pdarts_num_layers=[0, 6, 12], pdarts_num_to_drop=[3, 2, 2], - mutator=None, batch_size=64, workers=4, device=None, log_frequency=None): +class PdartsTrainer(BaseTrainer): + + def __init__(self, model_creator, layers, metrics, + num_epochs, dataset_train, dataset_valid, + pdarts_num_layers=[0, 6, 12], pdarts_num_to_drop=[3, 2, 2], + mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None): + super(PdartsTrainer, self).__init__() self.model_creator = model_creator self.layers = layers - self.n_nodes = n_nodes self.pdarts_num_layers = pdarts_num_layers self.pdarts_num_to_drop = pdarts_num_to_drop self.pdarts_epoch = len(pdarts_num_to_drop) @@ -25,29 +33,41 @@ def __init__(self, model_creator, metrics, num_epochs, dataset_train, dataset_va "device": device, "log_frequency": log_frequency } + self.callbacks = callbacks if callbacks is not None else [] def train(self): layers = self.layers - n_nodes = self.n_nodes switches = None for epoch in range(self.pdarts_epoch): layers = self.layers+self.pdarts_num_layers[epoch] - model, loss, model_optim, _ = self.model_creator( - layers, n_nodes) - mutator = PdartsMutator(model, epoch, self.pdarts_num_to_drop, switches) # pylint: disable=too-many-function-args + model, criterion, optim, lr_scheduler = self.model_creator(layers) + self.mutator = PdartsMutator(model, epoch, self.pdarts_num_to_drop, switches) + + for callback in self.callbacks: + callback.build(model, self.mutator, self) + callback.on_epoch_begin(epoch) + + darts_callbacks = [] + if lr_scheduler is not None: + darts_callbacks.append(LearningRateScheduler(lr_scheduler)) - self.trainer = DartsTrainer(model, loss=loss, optimizer=model_optim, - mutator=mutator, **self.darts_parameters) - print("start pdrats training %s..." % epoch) + self.trainer = DartsTrainer(model, mutator=self.mutator, loss=criterion, optimizer=optim, + callbacks=darts_callbacks, **self.darts_parameters) + logger.info("start pdarts training %s...", epoch) self.trainer.train() - # with open('log/parameters_%d.txt' % epoch, "w") as f: - # f.write(str(model.parameters)) + switches = self.mutator.drop_paths() - switches = mutator.drop_paths() + for callback in self.callbacks: + callback.on_epoch_end(epoch) + + def validate(self): + self.model.validate() def export(self): - if (self.trainer is not None) and hasattr(self.trainer, "export"): - self.trainer.export() + self.mutator.export() + + def checkpoint(self): + raise NotImplementedError("Not implemented yet") diff --git a/src/sdk/pynni/nni/nas/pytorch/trainer.py b/src/sdk/pynni/nni/nas/pytorch/trainer.py index a4954a0747..9195631a60 100644 --- a/src/sdk/pynni/nni/nas/pytorch/trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/trainer.py @@ -7,6 +7,7 @@ from .base_trainer import BaseTrainer _logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) class TorchTensorEncoder(json.JSONEncoder): @@ -59,12 +60,12 @@ def train(self, validate=True): callback.on_epoch_begin(epoch) # training - print("Epoch {} Training".format(epoch)) + _logger.info("Epoch %d Training", epoch) self.train_one_epoch(epoch) if validate: # validation - print("Epoch {} Validating".format(epoch)) + _logger.info("Epoch %d Validating", epoch) self.validate_one_epoch(epoch) for callback in self.callbacks: diff --git a/src/sdk/pynni/nni/nas/utils.py b/src/sdk/pynni/nni/nas/utils.py deleted file mode 100644 index 5000946e7e..0000000000 --- a/src/sdk/pynni/nni/nas/utils.py +++ /dev/null @@ -1,49 +0,0 @@ -from collections import OrderedDict - -_counter = 0 - - -def global_mutable_counting(): - global _counter - _counter += 1 - return _counter - - -class AverageMeterGroup(object): - - def __init__(self): - self.meters = OrderedDict() - - def update(self, data): - for k, v in data.items(): - if k not in self.meters: - self.meters[k] = AverageMeter(k, ":4f") - self.meters[k].update(v) - - def __str__(self): - return " ".join(str(v) for _, v in self.meters.items()) - - -class AverageMeter(object): - """Computes and stores the average and current value""" - - def __init__(self, name, fmt=':f'): - self.name = name - self.fmt = fmt - self.reset() - - def reset(self): - self.val = 0 - self.avg = 0 - self.sum = 0 - self.count = 0 - - def update(self, val, n=1): - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / self.count - - def __str__(self): - fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' - return fmtstr.format(**self.__dict__) From 73b2221b5eb4fd21802e6bf41e21d5df8ef9bf2c Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Fri, 22 Nov 2019 17:13:32 +0800 Subject: [PATCH 09/10] Update DARTS trainer and fix docstring issues (#1772) --- examples/nas/.gitignore | 1 + examples/nas/darts/model.py | 3 +- examples/nas/darts/ops.py | 39 ++--- examples/nas/darts/retrain.py | 22 ++- examples/nas/darts/search.py | 7 +- examples/nas/enas/macro.py | 2 +- examples/nas/enas/search.py | 4 +- src/sdk/pynni/nni/nas/pytorch/base_mutator.py | 27 ++-- src/sdk/pynni/nni/nas/pytorch/callbacks.py | 2 +- .../pynni/nni/nas/pytorch/darts/mutator.py | 4 +- .../pynni/nni/nas/pytorch/darts/trainer.py | 137 +++++++++++------- src/sdk/pynni/nni/nas/pytorch/enas/mutator.py | 3 +- src/sdk/pynni/nni/nas/pytorch/enas/trainer.py | 10 +- src/sdk/pynni/nni/nas/pytorch/fixed.py | 13 +- src/sdk/pynni/nni/nas/pytorch/mutables.py | 24 +-- src/sdk/pynni/nni/nas/pytorch/mutator.py | 24 +-- .../pynni/nni/nas/pytorch/pdarts/mutator.py | 2 +- .../pynni/nni/nas/pytorch/pdarts/trainer.py | 4 +- src/sdk/pynni/nni/nas/pytorch/trainer.py | 34 +++++ src/sdk/pynni/nni/nas/pytorch/utils.py | 18 ++- 20 files changed, 236 insertions(+), 144 deletions(-) diff --git a/examples/nas/.gitignore b/examples/nas/.gitignore index 9ba06a7ca3..8eeb0c2a3f 100644 --- a/examples/nas/.gitignore +++ b/examples/nas/.gitignore @@ -1,2 +1,3 @@ data checkpoints +runs diff --git a/examples/nas/darts/model.py b/examples/nas/darts/model.py index 6a9afe6ff3..a6a86f4a72 100644 --- a/examples/nas/darts/model.py +++ b/examples/nas/darts/model.py @@ -48,7 +48,7 @@ def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect): ops.SepConv(channels, channels, 3, stride, 1, affine=False), ops.SepConv(channels, channels, 5, stride, 2, affine=False), ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False), - ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False), + ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False) ], key=choice_keys[-1])) self.drop_path = ops.DropPath_() @@ -57,6 +57,7 @@ def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect): def forward(self, prev_nodes): assert len(self.ops) == len(prev_nodes) out = [op(node) for op, node in zip(self.ops, prev_nodes)] + out = [self.drop_path(o) if o is not None else None for o in out] return self.input_switch(out) diff --git a/examples/nas/darts/ops.py b/examples/nas/darts/ops.py index 2fef9fec19..9b74c346f9 100644 --- a/examples/nas/darts/ops.py +++ b/examples/nas/darts/ops.py @@ -4,9 +4,13 @@ class DropPath_(nn.Module): def __init__(self, p=0.): - """ [!] DropPath is inplace module - Args: - p: probability of an path to be zeroed. + """ + DropPath is inplace module. + + Parameters + ---------- + p : float + Probability of an path to be zeroed. """ super().__init__() self.p = p @@ -26,13 +30,9 @@ def forward(self, x): class PoolBN(nn.Module): """ - AvgPool or MaxPool - BN + AvgPool or MaxPool with BN. `pool_type` must be `max` or `avg`. """ def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True): - """ - Args: - pool_type: 'max' or 'avg' - """ super().__init__() if pool_type.lower() == 'max': self.pool = nn.MaxPool2d(kernel_size, stride, padding) @@ -50,8 +50,8 @@ def forward(self, x): class StdConv(nn.Module): - """ Standard conv - ReLU - Conv - BN + """ + Standard conv: ReLU - Conv - BN """ def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): super().__init__() @@ -66,8 +66,8 @@ def forward(self, x): class FacConv(nn.Module): - """ Factorized conv - ReLU - Conv(Kx1) - Conv(1xK) - BN + """ + Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN """ def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True): super().__init__() @@ -83,10 +83,10 @@ def forward(self, x): class DilConv(nn.Module): - """ (Dilated) depthwise separable conv - ReLU - (Dilated) depthwise separable - Pointwise - BN - If dilation == 2, 3x3 conv => 5x5 receptive field - 5x5 conv => 9x9 receptive field + """ + (Dilated) depthwise separable conv. + ReLU - (Dilated) depthwise separable - Pointwise - BN. + If dilation == 2, 3x3 conv => 5x5 receptive field, 5x5 conv => 9x9 receptive field. """ def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True): super().__init__() @@ -103,8 +103,9 @@ def forward(self, x): class SepConv(nn.Module): - """ Depthwise separable conv - DilConv(dilation=1) * 2 + """ + Depthwise separable conv. + DilConv(dilation=1) * 2. """ def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): super().__init__() @@ -119,7 +120,7 @@ def forward(self, x): class FactorizedReduce(nn.Module): """ - Reduce feature map size by factorized pointwise(stride=2). + Reduce feature map size by factorized pointwise (stride=2). """ def __init__(self, C_in, C_out, affine=True): super().__init__() diff --git a/examples/nas/darts/retrain.py b/examples/nas/darts/retrain.py index e3167376f9..904f3248fc 100644 --- a/examples/nas/darts/retrain.py +++ b/examples/nas/darts/retrain.py @@ -4,12 +4,13 @@ import torch import torch.nn as nn +from nni.nas.pytorch.fixed import apply_fixed_architecture +from nni.nas.pytorch.utils import AverageMeter +from torch.utils.tensorboard import SummaryWriter import datasets import utils from model import CNN -from nni.nas.pytorch.fixed import apply_fixed_architecture -from nni.nas.pytorch.utils import AverageMeter logger = logging.getLogger() @@ -23,6 +24,7 @@ logger.addHandler(std_out_info) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +writer = SummaryWriter() def train(config, train_loader, model, optimizer, criterion, epoch): @@ -33,6 +35,7 @@ def train(config, train_loader, model, optimizer, criterion, epoch): cur_step = epoch * len(train_loader) cur_lr = optimizer.param_groups[0]['lr'] logger.info("Epoch %d LR %.6f", epoch, cur_lr) + writer.add_scalar("lr", cur_lr, global_step=cur_step) model.train() @@ -54,6 +57,9 @@ def train(config, train_loader, model, optimizer, criterion, epoch): losses.update(loss.item(), bs) top1.update(accuracy["acc1"], bs) top5.update(accuracy["acc5"], bs) + writer.add_scalar("loss/train", loss.item(), global_step=cur_step) + writer.add_scalar("acc1/train", accuracy["acc1"], global_step=cur_step) + writer.add_scalar("acc5/train", accuracy["acc5"], global_step=cur_step) if step % config.log_frequency == 0 or step == len(train_loader) - 1: logger.info( @@ -77,15 +83,15 @@ def validate(config, valid_loader, model, criterion, epoch, cur_step): with torch.no_grad(): for step, (X, y) in enumerate(valid_loader): X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True) - N = X.size(0) + bs = X.size(0) logits = model(X) loss = criterion(logits, y) accuracy = utils.accuracy(logits, y, topk=(1, 5)) - losses.update(loss.item(), N) - top1.update(accuracy["acc1"], N) - top5.update(accuracy["acc5"], N) + losses.update(loss.item(), bs) + top1.update(accuracy["acc1"], bs) + top5.update(accuracy["acc5"], bs) if step % config.log_frequency == 0 or step == len(valid_loader) - 1: logger.info( @@ -94,6 +100,10 @@ def validate(config, valid_loader, model, criterion, epoch, cur_step): epoch + 1, config.epochs, step, len(valid_loader) - 1, losses=losses, top1=top1, top5=top5)) + writer.add_scalar("loss/test", losses.avg, global_step=cur_step) + writer.add_scalar("acc1/test", top1.avg, global_step=cur_step) + writer.add_scalar("acc5/test", top5.avg, global_step=cur_step) + logger.info("Valid: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, config.epochs, top1.avg)) return top1.avg diff --git a/examples/nas/darts/search.py b/examples/nas/darts/search.py index d9bdf0c7b5..f25db7c7e4 100644 --- a/examples/nas/darts/search.py +++ b/examples/nas/darts/search.py @@ -7,8 +7,7 @@ import datasets from model import CNN -from nni.nas.pytorch.callbacks import (ArchitectureCheckpoint, - LearningRateScheduler) +from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback from nni.nas.pytorch.darts import DartsTrainer from utils import accuracy @@ -29,6 +28,7 @@ parser.add_argument("--batch-size", default=64, type=int) parser.add_argument("--log-frequency", default=10, type=int) parser.add_argument("--epochs", default=50, type=int) + parser.add_argument("--unrolled", default=False, action="store_true") args = parser.parse_args() dataset_train, dataset_valid = datasets.get_dataset("cifar10") @@ -48,5 +48,6 @@ dataset_valid=dataset_valid, batch_size=args.batch_size, log_frequency=args.log_frequency, - callbacks=[LearningRateScheduler(lr_scheduler), ArchitectureCheckpoint("./checkpoints")]) + unrolled=args.unrolled, + callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")]) trainer.train() diff --git a/examples/nas/enas/macro.py b/examples/nas/enas/macro.py index a9309f9079..bfac9b17c9 100644 --- a/examples/nas/enas/macro.py +++ b/examples/nas/enas/macro.py @@ -19,7 +19,7 @@ def __init__(self, key, prev_labels, in_filters, out_filters): PoolBranch('max', in_filters, out_filters, 3, 1, 1) ]) if len(prev_labels) > 0: - self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None, reduction="sum") + self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None) else: self.skipconnect = None self.batch_norm = nn.BatchNorm2d(out_filters, affine=False) diff --git a/examples/nas/enas/search.py b/examples/nas/enas/search.py index 6fade75164..5188bfeb9f 100644 --- a/examples/nas/enas/search.py +++ b/examples/nas/enas/search.py @@ -9,7 +9,7 @@ from macro import GeneralNetwork from micro import MicroNetwork from nni.nas.pytorch import enas -from nni.nas.pytorch.callbacks import LearningRateScheduler, ArchitectureCheckpoint +from nni.nas.pytorch.callbacks import LRSchedulerCallback, ArchitectureCheckpoint from utils import accuracy, reward_accuracy logger = logging.getLogger() @@ -51,7 +51,7 @@ metrics=accuracy, reward_function=reward_accuracy, optimizer=optimizer, - callbacks=[LearningRateScheduler(lr_scheduler), ArchitectureCheckpoint("./checkpoints")], + callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")], batch_size=args.batch_size, num_epochs=num_epochs, dataset_train=dataset_train, diff --git a/src/sdk/pynni/nni/nas/pytorch/base_mutator.py b/src/sdk/pynni/nni/nas/pytorch/base_mutator.py index 550e449dfc..df45a869af 100644 --- a/src/sdk/pynni/nni/nas/pytorch/base_mutator.py +++ b/src/sdk/pynni/nni/nas/pytorch/base_mutator.py @@ -51,21 +51,22 @@ def _parse_search_space(self, module, root=None, prefix="", memo=None, nested_de def mutables(self): return self._structured_mutables - @property def forward(self, *inputs): raise RuntimeError("Forward is undefined for mutators.") + def __setattr__(self, name, value): + if name == "model": + raise AttributeError("Attribute `model` can be set at most once, and you shouldn't use `self.model = model` to " + "include you network, as it will include all parameters in model into the mutator.") + return super().__setattr__(name, value) + def enter_mutable_scope(self, mutable_scope): """ Callback when forward of a MutableScope is entered. Parameters ---------- - mutable_scope: MutableScope - - Returns - ------- - None + mutable_scope : MutableScope """ pass @@ -75,11 +76,7 @@ def exit_mutable_scope(self, mutable_scope): Parameters ---------- - mutable_scope: MutableScope - - Returns - ------- - None + mutable_scope : MutableScope """ pass @@ -89,8 +86,8 @@ def on_forward_layer_choice(self, mutable, *inputs): Parameters ---------- - mutable: LayerChoice - inputs: list of torch.Tensor + mutable : LayerChoice + inputs : list of torch.Tensor Returns ------- @@ -105,8 +102,8 @@ def on_forward_input_choice(self, mutable, tensor_list): Parameters ---------- - mutable: InputChoice - tensor_list: list of torch.Tensor + mutable : InputChoice + tensor_list : list of torch.Tensor Returns ------- diff --git a/src/sdk/pynni/nni/nas/pytorch/callbacks.py b/src/sdk/pynni/nni/nas/pytorch/callbacks.py index 83ae62cde0..817c1eb3f4 100644 --- a/src/sdk/pynni/nni/nas/pytorch/callbacks.py +++ b/src/sdk/pynni/nni/nas/pytorch/callbacks.py @@ -29,7 +29,7 @@ def on_batch_end(self, epoch): pass -class LearningRateScheduler(Callback): +class LRSchedulerCallback(Callback): def __init__(self, scheduler, mode="epoch"): super().__init__() assert mode == "epoch" diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py b/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py index 91d739c0a3..9674e2b954 100644 --- a/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py +++ b/src/sdk/pynni/nni/nas/pytorch/darts/mutator.py @@ -1,6 +1,6 @@ import torch -from torch import nn as nn -from torch.nn import functional as F +import torch.nn as nn +import torch.nn.functional as F from nni.nas.pytorch.mutator import Mutator from nni.nas.pytorch.mutables import LayerChoice, InputChoice diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py b/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py index 6392962111..772d455e72 100644 --- a/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py @@ -2,27 +2,27 @@ import logging import torch -from torch import nn as nn - +import torch.nn as nn from nni.nas.pytorch.trainer import Trainer from nni.nas.pytorch.utils import AverageMeterGroup from .mutator import DartsMutator logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) class DartsTrainer(Trainer): def __init__(self, model, loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid, mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, - callbacks=None): + callbacks=None, arc_learning_rate=3.0E-4, unrolled=True): super().__init__(model, mutator if mutator is not None else DartsMutator(model), loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid, batch_size, workers, device, log_frequency, callbacks) - self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), 3.0E-4, betas=(0.5, 0.999), + self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), arc_learning_rate, betas=(0.5, 0.999), weight_decay=1.0E-3) + self.unrolled = unrolled + n_train = len(self.dataset_train) split = n_train // 2 indices = list(range(n_train)) @@ -43,42 +43,32 @@ def __init__(self, model, loss, metrics, def train_one_epoch(self, epoch): self.model.train() self.mutator.train() - lr = self.optimizer.param_groups[0]["lr"] meters = AverageMeterGroup() for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)): trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device) val_X, val_y = val_X.to(self.device), val_y.to(self.device) - # backup model for hessian - backup_model = copy.deepcopy(self.model.state_dict()) - # cannot deepcopy model because it will break the reference + # phase 1. architecture step + self.ctrl_optim.zero_grad() + if self.unrolled: + self._unrolled_backward(trn_X, trn_y, val_X, val_y) + else: + self._backward(val_X, val_y) + self.ctrl_optim.step() - # phase 1. child network step + # phase 2: child network step self.optimizer.zero_grad() - self.mutator.reset() - logits = self.model(trn_X) - loss = self.loss(logits, trn_y) + logits, loss = self._logits_and_loss(trn_X, trn_y) loss.backward() - # gradient clipping - nn.utils.clip_grad_norm_(self.model.parameters(), 5.) + nn.utils.clip_grad_norm_(self.model.parameters(), 5.) # gradient clipping self.optimizer.step() - new_model = copy.deepcopy(self.model.state_dict()) - - # phase 2. architect step (alpha) - self.ctrl_optim.zero_grad() - # compute unrolled loss - self._unrolled_backward(trn_X, trn_y, val_X, val_y, backup_model, lr) - self.ctrl_optim.step() - - self.model.load_state_dict(new_model) - metrics = self.metrics(logits, trn_y) metrics["loss"] = loss.item() meters.update(metrics) if self.log_frequency is not None and step % self.log_frequency == 0: - logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch+1, - self.num_epochs, step+1, len(self.train_loader), meters) + logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1, + self.num_epochs, step + 1, len(self.train_loader), meters) def validate_one_epoch(self, epoch): self.model.eval() @@ -92,55 +82,92 @@ def validate_one_epoch(self, epoch): metrics = self.metrics(logits, y) meters.update(metrics) if self.log_frequency is not None and step % self.log_frequency == 0: - logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch+1, - self.num_epochs, step+1, len(self.test_loader), meters) + logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1, + self.num_epochs, step + 1, len(self.test_loader), meters) + + def _logits_and_loss(self, X, y): + self.mutator.reset() + logits = self.model(X) + loss = self.loss(logits, y) + return logits, loss - def _unrolled_backward(self, trn_X, trn_y, val_X, val_y, backup_model, lr): + def _backward(self, val_X, val_y): + """ + Simple backward with gradient descent + """ + _, loss = self._logits_and_loss(val_X, val_y) + loss.backward() + + def _unrolled_backward(self, trn_X, trn_y, val_X, val_y): """ Compute unrolled loss and backward its gradients - Parameters - ---------- - v_model: backup model before this step - lr: learning rate for virtual gradient step (same as net lr) """ - self.mutator.reset() - loss = self.loss(self.model(val_X), val_y) - w_model = tuple(self.model.parameters()) - w_ctrl = tuple(self.mutator.parameters()) + backup_params = copy.deepcopy(tuple(self.model.parameters())) + + # do virtual step on training data + lr = self.optimizer.param_groups[0]["lr"] + momentum = self.optimizer.param_groups[0]["momentum"] + weight_decay = self.optimizer.param_groups[0]["weight_decay"] + self._compute_virtual_model(trn_X, trn_y, lr, momentum, weight_decay) + + # calculate unrolled loss on validation data + # keep gradients for model here for compute hessian + _, loss = self._logits_and_loss(val_X, val_y) + w_model, w_ctrl = tuple(self.model.parameters()), tuple(self.mutator.parameters()) w_grads = torch.autograd.grad(loss, w_model + w_ctrl) - d_model = w_grads[:len(w_model)] - d_ctrl = w_grads[len(w_model):] + d_model, d_ctrl = w_grads[:len(w_model)], w_grads[len(w_model):] - hessian = self._compute_hessian(backup_model, d_model, trn_X, trn_y) + # compute hessian and final gradients + hessian = self._compute_hessian(backup_params, d_model, trn_X, trn_y) with torch.no_grad(): for param, d, h in zip(w_ctrl, d_ctrl, hessian): + # gradient = dalpha - lr * hessian param.grad = d - lr * h - def _compute_hessian(self, model, dw, trn_X, trn_y): + # restore weights + self._restore_weights(backup_params) + + def _compute_virtual_model(self, X, y, lr, momentum, weight_decay): """ - dw = dw` { L_val(w`, alpha) } - w+ = w + eps * dw - w- = w - eps * dw - hessian = (dalpha { L_trn(w+, alpha) } - dalpha { L_trn(w-, alpha) }) / (2*eps) - eps = 0.01 / ||dw|| + Compute unrolled weights w` """ - self.model.load_state_dict(model) + # don't need zero_grad, using autograd to calculate gradients + _, loss = self._logits_and_loss(X, y) + gradients = torch.autograd.grad(loss, self.model.parameters()) + with torch.no_grad(): + for w, g in zip(self.model.parameters(), gradients): + m = self.optimizer.state[w].get("momentum_buffer", 0.) + w = w - lr * (momentum * m + g + weight_decay * w) + + def _restore_weights(self, backup_params): + with torch.no_grad(): + for param, backup in zip(self.model.parameters(), backup_params): + param.copy_(backup) + def _compute_hessian(self, backup_params, dw, trn_X, trn_y): + """ + dw = dw` { L_val(w`, alpha) } + w+ = w + eps * dw + w- = w - eps * dw + hessian = (dalpha { L_trn(w+, alpha) } - dalpha { L_trn(w-, alpha) }) / (2*eps) + eps = 0.01 / ||dw|| + """ + self._restore_weights(backup_params) norm = torch.cat([w.view(-1) for w in dw]).norm() eps = 0.01 / norm + if norm < 1E-8: + logger.warning("In computing hessian, norm is smaller than 1E-8, cause eps to be %.6f.", norm.item()) + dalphas = [] for e in [eps, -2. * eps]: # w+ = w + eps*dw`, w- = w - eps*dw` with torch.no_grad(): for p, d in zip(self.model.parameters(), dw): - p += eps * d + p += e * d - self.mutator.reset() - loss = self.loss(self.model(trn_X), trn_y) - if e > 0: - dalpha_pos = torch.autograd.grad(loss, self.mutator.parameters()) # dalpha { L_trn(w+) } - elif e < 0: - dalpha_neg = torch.autograd.grad(loss, self.mutator.parameters()) # dalpha { L_trn(w-) } + _, loss = self._logits_and_loss(trn_X, trn_y) + dalphas.append(torch.autograd.grad(loss, self.mutator.parameters())) + dalpha_pos, dalpha_neg = dalphas # dalpha { L_trn(w+) }, # dalpha { L_trn(w-) } hessian = [(p - n) / 2. * eps for p, n in zip(dalpha_pos, dalpha_neg)] return hessian diff --git a/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py b/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py index 9d9a176352..7a1a6f80af 100644 --- a/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py +++ b/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py @@ -25,6 +25,7 @@ def forward(self, inputs, hidden): class EnasMutator(Mutator): + def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False, skip_target=0.4, branch_bias=0.25): super().__init__(model) @@ -51,7 +52,7 @@ def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, ce self.max_layer_choice = mutable.length assert self.max_layer_choice == mutable.length, \ "ENAS mutator requires all layer choice have the same number of candidates." - # NOTE(yuge): We might implement an interface later. Judging by key now. + # We are judging by keys and module types to add biases to layer choices. Needs refactor. if "reduce" in mutable.key: def is_conv(choice): return "conv" in str(type(choice)).lower() diff --git a/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py b/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py index 49052d6b08..d2074b94bc 100644 --- a/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py @@ -6,9 +6,7 @@ from nni.nas.pytorch.utils import AverageMeterGroup from .mutator import EnasMutator - logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) class EnasTrainer(Trainer): @@ -75,8 +73,8 @@ def train_one_epoch(self, epoch): meters.update(metrics) if self.log_frequency is not None and step % self.log_frequency == 0: - logger.info("Model Epoch [%s/%s] Step [%s/%s] %s", epoch, - self.num_epochs, step, len(self.train_loader), meters) + logger.info("Model Epoch [%s/%s] Step [%s/%s] %s", epoch + 1, + self.num_epochs, step + 1, len(self.train_loader), meters) # Train sampler (mutator) self.model.eval() @@ -114,8 +112,8 @@ def train_one_epoch(self, epoch): self.mutator_optim.zero_grad() if self.log_frequency is not None and step % self.log_frequency == 0: - logger.info("RL Epoch [%s/%s] Step [%s/%s] %s", epoch, self.num_epochs, - mutator_step // self.mutator_steps_aggregate, self.mutator_steps, meters) + logger.info("RL Epoch [%s/%s] Step [%s/%s] %s", epoch + 1, self.num_epochs, + mutator_step // self.mutator_steps_aggregate + 1, self.mutator_steps, meters) mutator_step += 1 if mutator_step >= total_mutator_steps: break diff --git a/src/sdk/pynni/nni/nas/pytorch/fixed.py b/src/sdk/pynni/nni/nas/pytorch/fixed.py index 6b83aa0800..d953e1cd5e 100644 --- a/src/sdk/pynni/nni/nas/pytorch/fixed.py +++ b/src/sdk/pynni/nni/nas/pytorch/fixed.py @@ -14,11 +14,11 @@ def __init__(self, model, fixed_arc, strict=True): Parameters ---------- - model: nn.Module + model : nn.Module A mutable network. - fixed_arc: str or dict + fixed_arc : str or dict Path to the architecture checkpoint (a string), or preloaded architecture object (a dict). - strict: bool + strict : bool Force everything that appears in `fixed_arc` to be used at least once. """ super().__init__(model) @@ -55,11 +55,11 @@ def apply_fixed_architecture(model, fixed_arc_path, device=None): Parameters ---------- - model: torch.nn.Module + model : torch.nn.Module Model with mutables. - fixed_arc_path: str + fixed_arc_path : str Path to the JSON that stores the architecture. - device: torch.device + device : torch.device Architecture weights will be transfered to `device`. Returns @@ -76,3 +76,4 @@ def apply_fixed_architecture(model, fixed_arc_path, device=None): architecture = FixedArchitecture(model, fixed_arc) architecture.to(device) architecture.reset() + return architecture diff --git a/src/sdk/pynni/nni/nas/pytorch/mutables.py b/src/sdk/pynni/nni/nas/pytorch/mutables.py index 4dbf514af8..115578a4ae 100644 --- a/src/sdk/pynni/nni/nas/pytorch/mutables.py +++ b/src/sdk/pynni/nni/nas/pytorch/mutables.py @@ -39,6 +39,9 @@ def __call__(self, *args, **kwargs): return super().__call__(*args, **kwargs) def set_mutator(self, mutator): + if "mutator" in self.__dict__: + raise RuntimeError("`set_mutator` is called more than once. Did you parse the search space multiple times? " + "Or did you apply multiple fixed architectures?") self.__dict__["mutator"] = mutator def forward(self, *inputs): @@ -68,9 +71,10 @@ def __repr__(self): class MutableScope(Mutable): """ - Mutable scope labels a subgraph/submodule to help mutators make better decisions. + Mutable scope marks a subgraph/submodule to help mutators make better decisions. Mutators get notified when a mutable scope is entered and exited. Mutators can override ``enter_mutable_scope`` and ``exit_mutable_scope`` to catch corresponding events, and do status dump or update. + MutableScope are also mutables that are listed in the mutables (search space). """ def __init__(self, key): @@ -86,7 +90,7 @@ def __call__(self, *args, **kwargs): class LayerChoice(Mutable): - def __init__(self, op_candidates, reduction="mean", return_mask=False, key=None): + def __init__(self, op_candidates, reduction="sum", return_mask=False, key=None): super().__init__(key=key) self.length = len(op_candidates) self.choices = nn.ModuleList(op_candidates) @@ -117,25 +121,25 @@ class InputChoice(Mutable): NO_KEY = "" def __init__(self, n_candidates=None, choose_from=None, n_chosen=None, - reduction="mean", return_mask=False, key=None): + reduction="sum", return_mask=False, key=None): """ Initialization. Parameters ---------- - n_candidates: int + n_candidates : int Number of inputs to choose from. - choose_from: list of str + choose_from : list of str List of source keys to choose from. At least of one of `choose_from` and `n_candidates` must be fulfilled. If `n_candidates` has a value but `choose_from` is None, it will be automatically treated as `n_candidates` number of empty string. - n_chosen: int + n_chosen : int Recommended inputs to choose. If None, mutator is instructed to select any. - reduction: str + reduction : str `mean`, `concat`, `sum` or `none`. - return_mask: bool + return_mask : bool If `return_mask`, return output tensor and a mask. Otherwise return tensor only. - key: str + key : str Key of the input choice. """ super().__init__(key=key) @@ -163,7 +167,7 @@ def forward(self, optional_inputs): Parameters ---------- - optional_inputs: list or dict + optional_inputs : list or dict Recommended to be a dict. As a dict, inputs will be converted to a list that follows the order of `choose_from` in initialization. As a list, inputs must follow the semantic order that is the same as `choose_from`. diff --git a/src/sdk/pynni/nni/nas/pytorch/mutator.py b/src/sdk/pynni/nni/nas/pytorch/mutator.py index 80608c6925..a0a22b2649 100644 --- a/src/sdk/pynni/nni/nas/pytorch/mutator.py +++ b/src/sdk/pynni/nni/nas/pytorch/mutator.py @@ -1,7 +1,11 @@ +import logging + import torch from nni.nas.pytorch.base_mutator import BaseMutator +logger = logging.getLogger(__name__) + class Mutator(BaseMutator): @@ -60,8 +64,8 @@ def on_forward_layer_choice(self, mutable, *inputs): Parameters ---------- - mutable: LayerChoice - inputs: list of torch.Tensor + mutable : LayerChoice + inputs : list of torch.Tensor Returns ------- @@ -85,9 +89,9 @@ def on_forward_input_choice(self, mutable, tensor_list): Parameters ---------- - mutable: InputChoice - tensor_list: list of torch.Tensor - tags: list of string + mutable : InputChoice + tensor_list : list of torch.Tensor + tags : list of string Returns ------- @@ -108,7 +112,7 @@ def _select_with_mask(self, map_fn, candidates, mask): return out def _tensor_reduction(self, reduction_type, tensor_list): - if tensor_list == "none": + if reduction_type == "none": return tensor_list if not tensor_list: return None # empty. return None for now @@ -129,12 +133,14 @@ def _get_decision(self, mutable): Parameters ---------- - mutable: Mutable + mutable : Mutable Returns ------- - any + object """ if mutable.key not in self._cache: raise ValueError("\"{}\" not found in decision cache.".format(mutable.key)) - return self._cache[mutable.key] + result = self._cache[mutable.key] + logger.debug("Decision %s: %s", mutable.key, result) + return result diff --git a/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py b/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py index 5862e9714b..8787b7ae40 100644 --- a/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py +++ b/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py @@ -4,7 +4,7 @@ import copy import numpy as np -from torch.nn import functional as F +import torch.nn.functional as F from nni.nas.pytorch.darts import DartsMutator from nni.nas.pytorch.mutables import LayerChoice diff --git a/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py b/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py index af31da08fc..850b79a413 100644 --- a/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. import logging -from nni.nas.pytorch.callbacks import LearningRateScheduler +from nni.nas.pytorch.callbacks import LRSchedulerCallback from nni.nas.pytorch.darts import DartsTrainer from nni.nas.pytorch.trainer import BaseTrainer @@ -50,7 +50,7 @@ def train(self): darts_callbacks = [] if lr_scheduler is not None: - darts_callbacks.append(LearningRateScheduler(lr_scheduler)) + darts_callbacks.append(LRSchedulerCallback(lr_scheduler)) self.trainer = DartsTrainer(model, mutator=self.mutator, loss=criterion, optimizer=optim, callbacks=darts_callbacks, **self.darts_parameters) diff --git a/src/sdk/pynni/nni/nas/pytorch/trainer.py b/src/sdk/pynni/nni/nas/pytorch/trainer.py index 9195631a60..879699f488 100644 --- a/src/sdk/pynni/nni/nas/pytorch/trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/trainer.py @@ -24,6 +24,40 @@ def default(self, o): # pylint: disable=method-hidden class Trainer(BaseTrainer): def __init__(self, model, mutator, loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid, batch_size, workers, device, log_frequency, callbacks): + """ + Trainer initialization. + + Parameters + ---------- + model : nn.Module + Model with mutables. + mutator : BaseMutator + A mutator object that has been initialized with the model. + loss : callable + Called with logits and targets. Returns a loss tensor. + metrics : callable + Returns a dict that maps metrics keys to metrics data. + optimizer : Optimizer + Optimizer that optimizes the model. + num_epochs : int + Number of epochs of training. + dataset_train : torch.utils.data.Dataset + Dataset of training. + dataset_valid : torch.utils.data.Dataset + Dataset of validation/testing. + batch_size : int + Batch size. + workers : int + Number of workers used in data preprocessing. + device : torch.device + Device object. Either `torch.device("cuda")` or torch.device("cpu")`. When `None`, trainer will + automatic detects GPU and selects GPU first. + log_frequency : int + Number of mini-batches to log metrics. + callbacks : list of Callback + Callbacks to plug into the trainer. See Callbacks. + """ + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device self.model = model self.mutator = mutator diff --git a/src/sdk/pynni/nni/nas/pytorch/utils.py b/src/sdk/pynni/nni/nas/pytorch/utils.py index d3a4292155..b6766b07cf 100644 --- a/src/sdk/pynni/nni/nas/pytorch/utils.py +++ b/src/sdk/pynni/nni/nas/pytorch/utils.py @@ -28,6 +28,16 @@ class AverageMeter: """Computes and stores the average and current value""" def __init__(self, name, fmt=':f'): + """ + Initialization of AverageMeter + + Parameters + ---------- + name : str + Name to display. + fmt : str + Format string to print the values. + """ self.name = name self.fmt = fmt self.reset() @@ -78,12 +88,12 @@ def traverse(self, order="pre", deduplicate=True, memo=None): Parameters ---------- - order: str + order : str pre or post. If pre, current mutable is yield before children. Otherwise after. - deduplicate: bool + deduplicate : bool If true, mutables with the same key will not appear after the first appearance. - memo: dict - An auxiliary variable to make deduplicate happen. + memo : dict + An auxiliary dict that memorize keys seen before, so that deduplication is possible. Returns ------- From 17ea5f0a05dee5886a0f98a69a834f57fc41e147 Mon Sep 17 00:00:00 2001 From: squirrelsc Date: Mon, 25 Nov 2019 09:24:21 +0800 Subject: [PATCH 10/10] fix bug on exporting --- src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py b/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py index 850b79a413..edc25fa360 100644 --- a/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. import logging + from nni.nas.pytorch.callbacks import LRSchedulerCallback from nni.nas.pytorch.darts import DartsTrainer from nni.nas.pytorch.trainer import BaseTrainer @@ -66,8 +67,5 @@ def train(self): def validate(self): self.model.validate() - def export(self): - self.mutator.export() - def checkpoint(self): raise NotImplementedError("Not implemented yet")