This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Update APIs and add preliminary support for ENAS macro space (#1714)
* add enas macro

* refactor example directory structure

* update docstring
Yuge Zhang authored Nov 8, 2019
1 parent e238d34 commit bb797e1
Showing 17 changed files with 854 additions and 91 deletions.
1 change: 1 addition & 0 deletions examples/nas/.gitignore
@@ -0,0 +1 @@
data
63 changes: 3 additions & 60 deletions examples/nas/darts/main.py → examples/nas/darts/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
-from argparse import ArgumentParser
-
-import datasets
-import image_ops as ops
-import nni.nas.pytorch as nas
 import torch
 import torch.nn as nn
-from nni.nas.pytorch.darts import DartsTrainer
+
+import ops
+from nni.nas import pytorch as nas
 
 
 class SearchCell(nn.Module):
@@ -142,57 +139,3 @@ def forward(self, x):
         out = out.view(out.size(0), -1)  # flatten
         logits = self.linear(out)
         return logits
-
-
-def accuracy(output, target, topk=(1,)):
-    """ Computes the precision@k for the specified values of k """
-    maxk = max(topk)
-    batch_size = target.size(0)
-
-    _, pred = output.topk(maxk, 1, True, True)
-    pred = pred.t()
-    # one-hot case
-    if target.ndimension() > 1:
-        target = target.max(1)[1]
-
-    correct = pred.eq(target.view(1, -1).expand_as(pred))
-
-    res = dict()
-    for k in topk:
-        correct_k = correct[:k].view(-1).float().sum(0)
-        res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
-    return res
-
-
-if __name__ == "__main__":
-    parser = ArgumentParser("darts")
-    parser.add_argument("--layers", default=4, type=int)
-    parser.add_argument("--nodes", default=2, type=int)
-    parser.add_argument("--batch-size", default=3, type=int)
-    parser.add_argument("--log-frequency", default=1, type=int)
-    args = parser.parse_args()
-
-    dataset_train, dataset_valid = datasets.get_dataset("cifar10")
-
-    model = SearchCNN(3, 16, 10, args.layers, n_nodes=args.nodes)
-    criterion = nn.CrossEntropyLoss()
-
-    optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
-    n_epochs = 50
-    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, n_epochs, eta_min=0.001)
-
-    trainer = DartsTrainer(model,
-                           loss=criterion,
-                           metrics=lambda output, target: accuracy(output, target, topk=(1,)),
-                           model_optim=optim,
-                           lr_scheduler=lr_scheduler,
-                           num_epochs=50,
-                           dataset_train=dataset_train,
-                           dataset_valid=dataset_valid,
-                           batch_size=args.batch_size,
-                           log_frequency=args.log_frequency)
-    trainer.train()
-    trainer.finalize()
-
-    # augment step
-    # ...
File renamed without changes.
43 changes: 43 additions & 0 deletions examples/nas/darts/search.py
@@ -0,0 +1,43 @@
from argparse import ArgumentParser

import datasets
import torch
import torch.nn as nn

from model import SearchCNN
from nni.nas.pytorch.darts import DartsTrainer
from utils import accuracy


if __name__ == "__main__":
    parser = ArgumentParser("darts")
    parser.add_argument("--layers", default=4, type=int)
    parser.add_argument("--nodes", default=2, type=int)
    parser.add_argument("--batch-size", default=128, type=int)
    parser.add_argument("--log-frequency", default=1, type=int)
    args = parser.parse_args()

    dataset_train, dataset_valid = datasets.get_dataset("cifar10")

    model = SearchCNN(3, 16, 10, args.layers, n_nodes=args.nodes)
    criterion = nn.CrossEntropyLoss()

    optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
    n_epochs = 50
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, n_epochs, eta_min=0.001)

    trainer = DartsTrainer(model,
                           loss=criterion,
                           metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                           model_optim=optim,
                           lr_scheduler=lr_scheduler,
                           num_epochs=50,
                           dataset_train=dataset_train,
                           dataset_valid=dataset_valid,
                           batch_size=args.batch_size,
                           log_frequency=args.log_frequency)
    trainer.train()
    trainer.export()

    # augment step
    # ...
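The metrics argument above is just a callable mapping (output, target) to a dict of named values; a standalone illustration of what that lambda returns (not part of the commit, and it assumes utils.py from this same change is importable):

import torch
from utils import accuracy

metrics = lambda output, target: accuracy(output, target, topk=(1,))
output = torch.eye(4)                # 4 samples, 4 classes, each maximally confident in class i
target = torch.tensor([0, 1, 2, 3])
print(metrics(output, target))       # {'acc1': 1.0}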
18 changes: 18 additions & 0 deletions examples/nas/darts/utils.py
@@ -0,0 +1,18 @@
def accuracy(output, target, topk=(1,)):
    """ Computes the precision@k for the specified values of k """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # one-hot case
    if target.ndimension() > 1:
        target = target.max(1)[1]

    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = dict()
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
    return res
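A quick sanity check of the helper, including the one-hot branch it special-cases (not part of the commit; assumes utils.py is on the import path):

import torch
from utils import accuracy

logits = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
onehot = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])  # one-hot targets, ndimension() > 1
# one-hot targets are reduced to class indices [0, 1, 1]; predictions are [0, 1, 0]
print(accuracy(logits, onehot))  # {'acc1': 0.6666...}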
25 changes: 25 additions & 0 deletions examples/nas/enas/datasets.py
@@ -0,0 +1,25 @@
from torchvision import transforms
from torchvision.datasets import CIFAR10


def get_dataset(cls):
    MEAN = [0.49139968, 0.48215827, 0.44653124]
    STD = [0.24703233, 0.24348505, 0.26158768]
    transf = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip()
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD)
    ]

    train_transform = transforms.Compose(transf + normalize)
    valid_transform = transforms.Compose(normalize)

    if cls == "cifar10":
        dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
        dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform)
    else:
        raise NotImplementedError
    return dataset_train, dataset_valid
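The loader can be exercised on its own; a minimal sketch (not part of the commit; the first call downloads CIFAR-10 into ./data, which the .gitignore entry added above keeps out of version control):

import datasets

dataset_train, dataset_valid = datasets.get_dataset("cifar10")
print(len(dataset_train), len(dataset_valid))  # 50000 10000
img, label = dataset_train[0]
print(img.shape)                               # torch.Size([3, 32, 32]), normalized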
80 changes: 80 additions & 0 deletions examples/nas/enas/enas_ops.py
@@ -0,0 +1,80 @@
import torch
import torch.nn as nn


class StdConv(nn.Module):
    def __init__(self, C_in, C_out):
        super(StdConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=False),
            nn.ReLU()
        )

    def forward(self, x):
        return self.conv(x)


class PoolBranch(nn.Module):
    def __init__(self, pool_type, C_in, C_out, kernel_size, stride, padding, affine=False):
        super().__init__()
        self.preproc = StdConv(C_in, C_out)
        if pool_type.lower() == 'max':
            self.pool = nn.MaxPool2d(kernel_size, stride, padding)
        elif pool_type.lower() == 'avg':
            self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
        else:
            raise ValueError()
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x):
        out = self.preproc(x)
        out = self.pool(out)
        out = self.bn(out)
        return out


class SeparableConv(nn.Module):
    def __init__(self, C_in, C_out, kernel_size, stride, padding):
        super(SeparableConv, self).__init__()
        self.depthwise = nn.Conv2d(C_in, C_in, kernel_size=kernel_size, padding=padding, stride=stride,
                                   groups=C_in, bias=False)
        self.pointwise = nn.Conv2d(C_in, C_out, kernel_size=1, bias=False)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out


class ConvBranch(nn.Module):
    def __init__(self, C_in, C_out, kernel_size, stride, padding, separable):
        super(ConvBranch, self).__init__()
        self.preproc = StdConv(C_in, C_out)
        if separable:
            self.conv = SeparableConv(C_out, C_out, kernel_size, stride, padding)
        else:
            self.conv = nn.Conv2d(C_out, C_out, kernel_size, stride=stride, padding=padding)
        self.postproc = nn.Sequential(
            nn.BatchNorm2d(C_out, affine=False),
            nn.ReLU()
        )

    def forward(self, x):
        out = self.preproc(x)
        out = self.conv(out)
        out = self.postproc(out)
        return out


class FactorizedReduce(nn.Module):
    def __init__(self, C_in, C_out, affine=False):
        super().__init__()
        self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x):
        out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
        out = self.bn(out)
        return out
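A quick shape check of the branch ops, showing why they are interchangeable candidates inside a LayerChoice (a sketch, not part of the commit; it assumes the enas_ops module defined above is importable):

import torch
from enas_ops import ConvBranch, PoolBranch, FactorizedReduce

x = torch.randn(2, 24, 32, 32)  # (batch, channels, height, width)
# every searchable branch maps C_in -> C_out at the same spatial resolution
print(ConvBranch(24, 24, 3, 1, 1, separable=True)(x).shape)  # torch.Size([2, 24, 32, 32])
print(PoolBranch('avg', 24, 24, 3, 1, 1)(x).shape)           # torch.Size([2, 24, 32, 32])
# FactorizedReduce halves height and width while keeping the channel count
print(FactorizedReduce(24, 24)(x).shape)                     # torch.Size([2, 24, 16, 16])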
142 changes: 142 additions & 0 deletions examples/nas/enas/macro.py
@@ -0,0 +1,142 @@
from argparse import ArgumentParser
import torch
import torch.nn as nn

import datasets
from enas_ops import FactorizedReduce, ConvBranch, PoolBranch
from nni.nas.pytorch import mutables, enas


class ENASLayer(nn.Module):

    def __init__(self, layer_id, in_filters, out_filters):
        super().__init__()
        self.in_filters = in_filters
        self.out_filters = out_filters
        self.mutable = mutables.LayerChoice([
            ConvBranch(in_filters, out_filters, 3, 1, 1, separable=False),
            ConvBranch(in_filters, out_filters, 3, 1, 1, separable=True),
            ConvBranch(in_filters, out_filters, 5, 1, 2, separable=False),
            ConvBranch(in_filters, out_filters, 5, 1, 2, separable=True),
            PoolBranch('avg', in_filters, out_filters, 3, 1, 1),
            PoolBranch('max', in_filters, out_filters, 3, 1, 1)
        ])
        if layer_id > 0:
            self.skipconnect = mutables.InputChoice(layer_id, n_selected=None, reduction="sum")
        else:
            self.skipconnect = None
        self.batch_norm = nn.BatchNorm2d(out_filters, affine=False)
        self.mutable_scope = mutables.MutableScope("layer_{}".format(layer_id))

    def forward(self, prev_layers):
        with self.mutable_scope:
            out = self.mutable(prev_layers[-1])
            if self.skipconnect is not None:
                connection = self.skipconnect(prev_layers[:-1],
                                              ["layer_{}".format(i) for i in range(len(prev_layers) - 1)])
                if connection is not None:
                    out += connection
            return self.batch_norm(out)


class GeneralNetwork(nn.Module):
    def __init__(self, num_layers=12, out_filters=24, in_channels=3, num_classes=10,
                 dropout_rate=0.0):
        super().__init__()
        self.num_layers = num_layers
        self.num_classes = num_classes
        self.out_filters = out_filters

        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, out_filters, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_filters)
        )

        pool_distance = self.num_layers // 3
        self.pool_layers_idx = [pool_distance - 1, 2 * pool_distance - 1]
        self.dropout_rate = dropout_rate
        self.dropout = nn.Dropout(self.dropout_rate)

        self.layers = nn.ModuleList()
        self.pool_layers = nn.ModuleList()
        for layer_id in range(self.num_layers):
            if layer_id in self.pool_layers_idx:
                self.pool_layers.append(FactorizedReduce(self.out_filters, self.out_filters))
            self.layers.append(ENASLayer(layer_id, self.out_filters, self.out_filters))

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.dense = nn.Linear(self.out_filters, self.num_classes)

    def forward(self, x):
        bs = x.size(0)
        cur = self.stem(x)

        layers = [cur]

        for layer_id in range(self.num_layers):
            cur = self.layers[layer_id](layers)
            layers.append(cur)
            if layer_id in self.pool_layers_idx:
                for i, layer in enumerate(layers):
                    layers[i] = self.pool_layers[self.pool_layers_idx.index(layer_id)](layer)
                cur = layers[-1]

        cur = self.gap(cur).view(bs, -1)
        cur = self.dropout(cur)
        logits = self.dense(cur)
        return logits


def accuracy(output, target, topk=(1,)):
    """ Computes the precision@k for the specified values of k """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # one-hot case
    if target.ndimension() > 1:
        target = target.max(1)[1]

    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = dict()
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
    return res


def reward_accuracy(output, target, topk=(1,)):
    batch_size = target.size(0)
    _, predicted = torch.max(output.data, 1)
    return (predicted == target).sum().item() / batch_size


if __name__ == "__main__":
    parser = ArgumentParser("enas")
    parser.add_argument("--batch-size", default=3, type=int)
    parser.add_argument("--log-frequency", default=1, type=int)
    args = parser.parse_args()

    dataset_train, dataset_valid = datasets.get_dataset("cifar10")

    model = GeneralNetwork()
    criterion = nn.CrossEntropyLoss()

    n_epochs = 310
    optim = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=n_epochs, eta_min=0.001)

    trainer = enas.EnasTrainer(model,
                               loss=criterion,
                               metrics=accuracy,
                               reward_function=reward_accuracy,
                               optimizer=optim,
                               lr_scheduler=lr_scheduler,
                               batch_size=args.batch_size,
                               num_epochs=n_epochs,
                               dataset_train=dataset_train,
                               dataset_valid=dataset_valid,
                               log_frequency=args.log_frequency)
    trainer.train()
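A small worked example of the reward function the controller maximizes (not part of the commit; the import is safe because the training code above is guarded by __main__):

import torch
from macro import reward_accuracy

output = torch.tensor([[2.0, 0.1, 0.1],
                       [0.1, 2.0, 0.1],
                       [0.1, 0.1, 2.0],
                       [2.0, 0.1, 0.1]])
target = torch.tensor([0, 1, 2, 1])  # the last sample is misclassified
# row-wise argmax gives [0, 1, 2, 0]; 3 of 4 match, so the reward is a plain float
print(reward_accuracy(output, target))  # 0.75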