This repository has been archived by the owner on Sep 18, 2024. It is now read-only.
Commit
Update APIs and add preliminary support for ENAS macro space (#1714)

* add enas macro
* refactor example directory structure
* update docstring
Yuge Zhang authored Nov 8, 2019
1 parent e238d34 · commit bb797e1
Showing 17 changed files with 854 additions and 91 deletions.
@@ -0,0 +1 @@
data
File renamed without changes.
@@ -0,0 +1,43 @@
from argparse import ArgumentParser

import datasets
import torch
import torch.nn as nn

from model import SearchCNN
from nni.nas.pytorch.darts import DartsTrainer
from utils import accuracy


if __name__ == "__main__":
    parser = ArgumentParser("darts")
    parser.add_argument("--layers", default=4, type=int)
    parser.add_argument("--nodes", default=2, type=int)
    parser.add_argument("--batch-size", default=128, type=int)
    parser.add_argument("--log-frequency", default=1, type=int)
    args = parser.parse_args()

    dataset_train, dataset_valid = datasets.get_dataset("cifar10")

    model = SearchCNN(3, 16, 10, args.layers, n_nodes=args.nodes)
    criterion = nn.CrossEntropyLoss()

    optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
    n_epochs = 50
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, n_epochs, eta_min=0.001)

    trainer = DartsTrainer(model,
                           loss=criterion,
                           metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                           model_optim=optim,
                           lr_scheduler=lr_scheduler,
                           num_epochs=n_epochs,
                           dataset_train=dataset_train,
                           dataset_valid=dataset_valid,
                           batch_size=args.batch_size,
                           log_frequency=args.log_frequency)
    trainer.train()
    trainer.export()

    # augment step
    # ...
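As a quick side check (a sketch, not part of the commit), the cosine schedule configured above can be stepped in isolation to confirm the learning rate decays from 0.025 toward eta_min=0.001 over the 50 search epochs:

import torch

# standalone check of the schedule used in the search script above
p = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.SGD(p, lr=0.025, momentum=0.9)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=50, eta_min=0.001)
lrs = []
for _ in range(50):
    opt.step()
    sched.step()
    lrs.append(opt.param_groups[0]["lr"])
print(round(lrs[0], 4), round(lrs[-1], 4))  # ~0.025 at the start, 0.001 at the end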
@@ -0,0 +1,18 @@
def accuracy(output, target, topk=(1,)):
    """ Computes the precision@k for the specified values of k """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # one-hot case
    if target.ndimension() > 1:
        target = target.max(1)[1]

    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = dict()
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
    return res
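A minimal sanity check for accuracy (a sketch, not part of the commit; the shapes are illustrative and the function is assumed to be in scope):

import torch

# 4 samples, 10 classes; one-hot logits whose argmax equals the target everywhere
logits = torch.eye(10)[[1, 3, 5, 7]]
target = torch.tensor([1, 3, 5, 7])
print(accuracy(logits, target, topk=(1, 5)))  # {'acc1': 1.0, 'acc5': 1.0}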
@@ -0,0 +1,25 @@
from torchvision import transforms
from torchvision.datasets import CIFAR10


def get_dataset(cls):
    MEAN = [0.49139968, 0.48215827, 0.44653124]
    STD = [0.24703233, 0.24348505, 0.26158768]
    transf = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip()
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD)
    ]

    train_transform = transforms.Compose(transf + normalize)
    valid_transform = transforms.Compose(normalize)

    if cls == "cifar10":
        dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
        dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform)
    else:
        raise NotImplementedError
    return dataset_train, dataset_valid
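For a quick look at what get_dataset returns, here is a hedged sketch (not part of the commit) that wraps the datasets in standard DataLoaders; note the NNI trainers in these examples consume the datasets directly, so the loaders are only for inspection:

import torch

train_set, valid_set = get_dataset("cifar10")
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([128, 3, 32, 32])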
@@ -0,0 +1,80 @@
import torch
import torch.nn as nn


class StdConv(nn.Module):
    def __init__(self, C_in, C_out):
        super(StdConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=False),
            nn.ReLU()
        )

    def forward(self, x):
        return self.conv(x)


class PoolBranch(nn.Module):
    def __init__(self, pool_type, C_in, C_out, kernel_size, stride, padding, affine=False):
        super().__init__()
        self.preproc = StdConv(C_in, C_out)
        if pool_type.lower() == 'max':
            self.pool = nn.MaxPool2d(kernel_size, stride, padding)
        elif pool_type.lower() == 'avg':
            self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
        else:
            raise ValueError("pool_type must be 'max' or 'avg', got {!r}".format(pool_type))
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x):
        out = self.preproc(x)
        out = self.pool(out)
        out = self.bn(out)
        return out


class SeparableConv(nn.Module):
    def __init__(self, C_in, C_out, kernel_size, stride, padding):
        super(SeparableConv, self).__init__()
        # depthwise (per-channel) convolution followed by a 1x1 pointwise convolution
        self.depthwise = nn.Conv2d(C_in, C_in, kernel_size=kernel_size, padding=padding, stride=stride,
                                   groups=C_in, bias=False)
        self.pointwise = nn.Conv2d(C_in, C_out, kernel_size=1, bias=False)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out


class ConvBranch(nn.Module):
    def __init__(self, C_in, C_out, kernel_size, stride, padding, separable):
        super(ConvBranch, self).__init__()
        self.preproc = StdConv(C_in, C_out)
        if separable:
            self.conv = SeparableConv(C_out, C_out, kernel_size, stride, padding)
        else:
            self.conv = nn.Conv2d(C_out, C_out, kernel_size, stride=stride, padding=padding)
        self.postproc = nn.Sequential(
            nn.BatchNorm2d(C_out, affine=False),
            nn.ReLU()
        )

    def forward(self, x):
        out = self.preproc(x)
        out = self.conv(out)
        out = self.postproc(out)
        return out


class FactorizedReduce(nn.Module):
    def __init__(self, C_in, C_out, affine=False):
        super().__init__()
        # two stride-2 1x1 convs, each producing half of the output channels
        self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x):
        # the second conv sees the input shifted by one pixel, so the two
        # stride-2 paths sample complementary spatial locations
        out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
        out = self.bn(out)
        return out
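A shape sanity check for the ops above (a sketch, not part of the commit): with the kernel/stride/padding combinations used by the ENAS layer, every branch preserves spatial size, while FactorizedReduce halves it:

import torch

x = torch.randn(2, 24, 32, 32)
print(ConvBranch(24, 24, 3, 1, 1, separable=True)(x).shape)  # torch.Size([2, 24, 32, 32])
print(PoolBranch('max', 24, 24, 3, 1, 1)(x).shape)           # torch.Size([2, 24, 32, 32])
print(FactorizedReduce(24, 24)(x).shape)                     # torch.Size([2, 24, 16, 16])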
@@ -0,0 +1,142 @@
from argparse import ArgumentParser
import torch
import torch.nn as nn

import datasets
from ops import FactorizedReduce, ConvBranch, PoolBranch
from nni.nas.pytorch import mutables, enas


class ENASLayer(nn.Module):

    def __init__(self, layer_id, in_filters, out_filters):
        super().__init__()
        self.in_filters = in_filters
        self.out_filters = out_filters
        # the six candidate operations of the ENAS macro search space
        self.mutable = mutables.LayerChoice([
            ConvBranch(in_filters, out_filters, 3, 1, 1, separable=False),
            ConvBranch(in_filters, out_filters, 3, 1, 1, separable=True),
            ConvBranch(in_filters, out_filters, 5, 1, 2, separable=False),
            ConvBranch(in_filters, out_filters, 5, 1, 2, separable=True),
            PoolBranch('avg', in_filters, out_filters, 3, 1, 1),
            PoolBranch('max', in_filters, out_filters, 3, 1, 1)
        ])
        if layer_id > 0:
            # choose any subset of earlier layers as skip connections, summed together
            self.skipconnect = mutables.InputChoice(layer_id, n_selected=None, reduction="sum")
        else:
            self.skipconnect = None
        self.batch_norm = nn.BatchNorm2d(out_filters, affine=False)
        self.mutable_scope = mutables.MutableScope("layer_{}".format(layer_id))

    def forward(self, prev_layers):
        with self.mutable_scope:
            out = self.mutable(prev_layers[-1])
            if self.skipconnect is not None:
                connection = self.skipconnect(prev_layers[:-1],
                                              ["layer_{}".format(i) for i in range(len(prev_layers) - 1)])
                if connection is not None:
                    out += connection
            return self.batch_norm(out)


class GeneralNetwork(nn.Module):
    def __init__(self, num_layers=12, out_filters=24, in_channels=3, num_classes=10,
                 dropout_rate=0.0):
        super().__init__()
        self.num_layers = num_layers
        self.num_classes = num_classes
        self.out_filters = out_filters

        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, out_filters, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_filters)
        )

        # downsample twice, at one third and two thirds of the depth
        pool_distance = self.num_layers // 3
        self.pool_layers_idx = [pool_distance - 1, 2 * pool_distance - 1]
        self.dropout_rate = dropout_rate
        self.dropout = nn.Dropout(self.dropout_rate)

        self.layers = nn.ModuleList()
        self.pool_layers = nn.ModuleList()
        for layer_id in range(self.num_layers):
            if layer_id in self.pool_layers_idx:
                self.pool_layers.append(FactorizedReduce(self.out_filters, self.out_filters))
            self.layers.append(ENASLayer(layer_id, self.out_filters, self.out_filters))

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.dense = nn.Linear(self.out_filters, self.num_classes)

    def forward(self, x):
        bs = x.size(0)
        cur = self.stem(x)

        layers = [cur]

        for layer_id in range(self.num_layers):
            cur = self.layers[layer_id](layers)
            layers.append(cur)
            if layer_id in self.pool_layers_idx:
                # downsample every stored feature map so later skip connections still match
                for i, layer in enumerate(layers):
                    layers[i] = self.pool_layers[self.pool_layers_idx.index(layer_id)](layer)
                cur = layers[-1]

        cur = self.gap(cur).view(bs, -1)
        cur = self.dropout(cur)
        logits = self.dense(cur)
        return logits


def accuracy(output, target, topk=(1,)):
    """ Computes the precision@k for the specified values of k """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # one-hot case
    if target.ndimension() > 1:
        target = target.max(1)[1]

    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = dict()
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
    return res


def reward_accuracy(output, target, topk=(1,)):
    # top-1 accuracy used as the reinforcement-learning reward for the controller
    batch_size = target.size(0)
    _, predicted = torch.max(output.data, 1)
    return (predicted == target).sum().item() / batch_size


if __name__ == "__main__":
    parser = ArgumentParser("enas")
    parser.add_argument("--batch-size", default=3, type=int)
    parser.add_argument("--log-frequency", default=1, type=int)
    args = parser.parse_args()

    dataset_train, dataset_valid = datasets.get_dataset("cifar10")

    model = GeneralNetwork()
    criterion = nn.CrossEntropyLoss()

    n_epochs = 310
    optim = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=n_epochs, eta_min=0.001)

    trainer = enas.EnasTrainer(model,
                               loss=criterion,
                               metrics=accuracy,
                               reward_function=reward_accuracy,
                               optimizer=optim,
                               lr_scheduler=lr_scheduler,
                               batch_size=args.batch_size,
                               num_epochs=n_epochs,
                               dataset_train=dataset_train,
                               dataset_valid=dataset_valid,
                               log_frequency=args.log_frequency)
    trainer.train()
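To make the skip-connection wiring concrete, here is a toy sketch in plain PyTorch (not the NNI API) of what an InputChoice with reduction="sum" amounts to once the controller has picked a subset; the mask below is a hypothetical selection:

import torch

# three candidate feature maps from earlier layers
candidates = [torch.randn(2, 24, 16, 16) for _ in range(3)]
mask = [True, False, True]  # hypothetical subset picked by the ENAS controller
chosen = [t for t, m in zip(candidates, mask) if m]
# reduction="sum": the selected inputs are summed; None when nothing is chosen
connection = torch.stack(chosen).sum(dim=0) if chosen else None
print(connection.shape)  # torch.Size([2, 24, 16, 16])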