This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[Retiarii] Weight-sharing trainers #3137

Merged · 9 commits · Dec 5, 2020
49 changes: 34 additions & 15 deletions examples/nas/darts/search.py
@@ -11,9 +11,9 @@
import datasets
from model import CNN
from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback
from nni.algorithms.nas.pytorch.darts import DartsTrainer
from utils import accuracy


logger = logging.getLogger('nni')

if __name__ == "__main__":
@@ -25,6 +25,7 @@
parser.add_argument("--channels", default=16, type=int)
parser.add_argument("--unrolled", default=False, action="store_true")
parser.add_argument("--visualization", default=False, action="store_true")
parser.add_argument("--v1", default=False, action="store_true")
args = parser.parse_args()

dataset_train, dataset_valid = datasets.get_dataset("cifar10")
@@ -35,17 +36,35 @@
optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001)

trainer = DartsTrainer(model,
loss=criterion,
metrics=lambda output, target: accuracy(output, target, topk=(1,)),
optimizer=optim,
num_epochs=args.epochs,
dataset_train=dataset_train,
dataset_valid=dataset_valid,
batch_size=args.batch_size,
log_frequency=args.log_frequency,
unrolled=args.unrolled,
callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")])
if args.visualization:
trainer.enable_visualization()
trainer.train()
if args.v1:
from nni.algorithms.nas.pytorch.darts import DartsTrainer
trainer = DartsTrainer(model,
loss=criterion,
metrics=lambda output, target: accuracy(output, target, topk=(1,)),
optimizer=optim,
num_epochs=args.epochs,
dataset_train=dataset_train,
dataset_valid=dataset_valid,
batch_size=args.batch_size,
log_frequency=args.log_frequency,
unrolled=args.unrolled,
callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")])
if args.visualization:
trainer.enable_visualization()

trainer.train()
else:
from nni.retiarii.trainer.pytorch import DartsTrainer
trainer = DartsTrainer(
model=model,
loss=criterion,
metrics=lambda output, target: accuracy(output, target, topk=(1,)),
optimizer=optim,
num_epochs=args.epochs,
dataset=dataset_train,
batch_size=args.batch_size,
log_frequency=args.log_frequency,
unrolled=args.unrolled
)
trainer.fit()
print('Final architecture:', trainer.export())
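For reference, a minimal sketch of how the exported architecture could be persisted, assuming `export()` returns a JSON-serializable dict of choice keys to decisions (the file name and serialization are illustrative, not part of this diff):

```python
import json

# Hypothetical follow-up: save the searched architecture for later retraining.
final_architecture = trainer.export()
with open('final_architecture.json', 'w') as f:
    json.dump(final_architecture, f)
```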
20 changes: 14 additions & 6 deletions examples/nas/enas/micro.py
@@ -48,9 +48,17 @@ def __init__(self, cell_name, prev_labels, channels):
], key=cell_name + "_op")

def forward(self, prev_layers):
chosen_input, chosen_mask = self.input_choice(prev_layers)
cell_out = self.op_choice(chosen_input)
return cell_out, chosen_mask
from nni.retiarii.trainer.pytorch.random import PathSamplingInputChoice
out = self.input_choice(prev_layers)
if isinstance(self.input_choice, PathSamplingInputChoice):
# Retiarii pattern
sampled = self.input_choice.sampled
return out, torch.tensor([i == sampled or (isinstance(sampled, list) and i in sampled)
for i in range(len(self.input_choice))], dtype=torch.bool)
Contributor: this logic is complicated, why?
else:
chosen_input, chosen_mask = out
cell_out = self.op_choice(chosen_input)
return cell_out, chosen_mask
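Regarding the reviewer's question above: the comprehension builds a boolean mask marking which candidate inputs were sampled, covering both the single-index and multi-index cases of `sampled`. A standalone sketch of the same computation (function name is illustrative):

```python
import torch

def sampled_to_mask(sampled, num_candidates):
    # True at every sampled index; `sampled` is an int or a list of ints.
    chosen = set(sampled) if isinstance(sampled, list) else {sampled}
    return torch.tensor([i in chosen for i in range(num_candidates)], dtype=torch.bool)

# sampled_to_mask(2, 5)      -> tensor([False, False,  True, False, False])
# sampled_to_mask([0, 3], 5) -> tensor([ True, False, False,  True, False])
```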


class Node(mutables.MutableScope):
@@ -71,7 +79,7 @@ def __init__(self, in_channels, out_channels):
self.process = None
if in_channels != out_channels:
self.process = StdConv(in_channels, out_channels)

def forward(self, x):
if self.process is None:
return x
@@ -83,7 +91,7 @@ def __init__(self, in_channels_pp, in_channels_p, out_channels):
super().__init__()
self.reduce0 = FactorizedReduce(in_channels_pp, out_channels, affine=False)
self.reduce1 = FactorizedReduce(in_channels_p, out_channels, affine=False)

def forward(self, pprev, prev):
return self.reduce0(pprev), self.reduce1(prev)

@@ -109,7 +117,7 @@ def reset_parameters(self):
nn.init.kaiming_normal_(self.final_conv_w)

def forward(self, pprev, prev):
pprev_, prev_ = self.preproc0(pprev), self.preproc1(prev)

prev_nodes_out = [pprev_, prev_]
nodes_used_mask = torch.zeros(self.num_nodes + 2, dtype=torch.bool, device=prev.device)
55 changes: 37 additions & 18 deletions examples/nas/enas/search.py
@@ -26,36 +26,55 @@
parser.add_argument("--search-for", choices=["macro", "micro"], default="macro")
parser.add_argument("--epochs", default=None, type=int, help="Number of epochs (default: macro 310, micro 150)")
parser.add_argument("--visualization", default=False, action="store_true")
parser.add_argument("--v1", default=False, action="store_true")
args = parser.parse_args()

dataset_train, dataset_valid = datasets.get_dataset("cifar10")
mutator = None
ctrl_kwargs = {}
if args.search_for == "macro":
model = GeneralNetwork()
num_epochs = args.epochs or 310
mutator = None
elif args.search_for == "micro":
model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1, use_aux_heads=True)
model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1, use_aux_heads=False)
num_epochs = args.epochs or 150
mutator = enas.EnasMutator(model, tanh_constant=1.1, cell_exit_extra_step=True)
if args.v1:
Contributor: suggest to directly abandon v1
mutator = enas.EnasMutator(model, tanh_constant=1.1, cell_exit_extra_step=True)
else:
ctrl_kwargs = {"tanh_constant": 1.1}
else:
raise AssertionError

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0.001)

trainer = enas.EnasTrainer(model,
loss=criterion,
metrics=accuracy,
reward_function=reward_accuracy,
optimizer=optimizer,
callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")],
batch_size=args.batch_size,
num_epochs=num_epochs,
dataset_train=dataset_train,
dataset_valid=dataset_valid,
log_frequency=args.log_frequency,
mutator=mutator)
if args.visualization:
trainer.enable_visualization()
trainer.train()
if args.v1:
trainer = enas.EnasTrainer(model,
loss=criterion,
metrics=accuracy,
reward_function=reward_accuracy,
optimizer=optimizer,
callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")],
batch_size=args.batch_size,
num_epochs=num_epochs,
dataset_train=dataset_train,
dataset_valid=dataset_valid,
log_frequency=args.log_frequency,
mutator=mutator)
if args.visualization:
trainer.enable_visualization()
trainer.train()
else:
from nni.retiarii.trainer.pytorch.enas import EnasTrainer
trainer = EnasTrainer(model,
loss=criterion,
metrics=accuracy,
reward_function=reward_accuracy,
optimizer=optimizer,
batch_size=args.batch_size,
num_epochs=num_epochs,
dataset=dataset_train,
log_frequency=args.log_frequency,
ctrl_kwargs=ctrl_kwargs)
trainer.fit()
2 changes: 1 addition & 1 deletion examples/nas/naive/train.py
@@ -6,7 +6,7 @@
import torchvision.transforms as transforms

from nni.nas.pytorch.mutables import LayerChoice, InputChoice
from nni.nas.pytorch.darts import DartsTrainer
from nni.algorithms.nas.pytorch.darts import DartsTrainer


class Net(nn.Module):
31 changes: 27 additions & 4 deletions examples/nas/proxylessnas/main.py
@@ -1,13 +1,16 @@
import logging
import os
import sys
import logging
from argparse import ArgumentParser

import torch
import datasets
from nni.algorithms.nas.pytorch.proxylessnas import ProxylessNasTrainer
from torchvision import transforms

from putils import get_parameters
import datasets
from model import SearchMobileNet
from nni.algorithms.nas.pytorch.proxylessnas import ProxylessNasTrainer
from putils import LabelSmoothingLoss, accuracy, get_parameters
from retrain import Retrain

logger = logging.getLogger('nni_proxylessnas')
@@ -30,7 +33,7 @@
parser.add_argument("--resize_scale", default=0.08, type=float)
parser.add_argument("--distort_color", default='normal', type=str, choices=['normal', 'strong', 'None'])
# configurations for training mode
parser.add_argument("--train_mode", default='search', type=str, choices=['search', 'retrain'])
parser.add_argument("--train_mode", default='search', type=str, choices=['search_v1', 'search', 'retrain'])
# configurations for search
parser.add_argument("--checkpoint_path", default='./search_mobile_net.pt', type=str)
parser.add_argument("--arch_path", default='./arch_path.pt', type=str)
@@ -80,6 +83,26 @@
optimizer = torch.optim.SGD(get_parameters(model), lr=0.05, momentum=momentum, nesterov=nesterov, weight_decay=4e-5)

if args.train_mode == 'search':
from nni.retiarii.trainer.pytorch import ProxylessTrainer
from torchvision.datasets import ImageNet
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
dataset = ImageNet(args.data_path, transform=transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
trainer = ProxylessTrainer(model,
loss=LabelSmoothingLoss(),
dataset=dataset,
optimizer=optimizer,
metrics=lambda output, target: accuracy(output, target, topk=(1, 5,)),
num_epochs=120,
log_frequency=10)
trainer.fit()
print('Final architecture:', trainer.export())
elif args.train_mode == 'search_v1':
# this is architecture search
logger.info('Creating ProxylessNasTrainer...')
trainer = ProxylessNasTrainer(model,
2 changes: 0 additions & 2 deletions examples/nas/proxylessnas/model.py
@@ -58,11 +58,9 @@ def __init__(self,
# if it is not the first one
op_candidates += [ops.OPS['Zero'](input_channel, width, stride)]
conv_op = nas.mutables.LayerChoice(op_candidates,
return_mask=True,
key="s{}_c{}".format(stage_cnt, i))
else:
conv_op = nas.mutables.LayerChoice(op_candidates,
return_mask=True,
key="s{}_c{}".format(stage_cnt, i))
# shortcut
if stride == 1 and input_channel == width:
20 changes: 7 additions & 13 deletions examples/nas/proxylessnas/ops.py
@@ -39,19 +39,13 @@ def __init__(self, mobile_inverted_conv, shortcut, op_candidates_list):
self.op_candidates_list = op_candidates_list

def forward(self, x):
out, idx = self.mobile_inverted_conv(x)
# TODO: unify idx format
if not isinstance(idx, int):
idx = (idx == 1).nonzero()
if self.op_candidates_list[idx].is_zero_layer():
res = x
elif self.shortcut is None:
res = out
else:
conv_x = out
skip_x = self.shortcut(x)
res = skip_x + conv_x
return res
out = self.mobile_inverted_conv(x)
if torch.sum(torch.abs(out)).item() == 0 and x.size() == out.size():
# is zero layer
return x
if self.shortcut is None:
return out
return out + self.shortcut(x)


class ShuffleLayer(nn.Module):
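The rewritten `forward` above no longer receives a choice index from `mobile_inverted_conv`, so the zero layer is detected by value: an all-zero output of the same shape as the input means the block should act as identity. A minimal sketch of that check, with a hypothetical `Zero` module standing in for `ops.OPS['Zero']`:

```python
import torch
import torch.nn as nn

class Zero(nn.Module):
    # Hypothetical stand-in: always outputs zeros of the input's shape.
    def forward(self, x):
        return torch.zeros_like(x)

x = torch.randn(2, 8, 4, 4)
out = Zero()(x)
# Same check as in MobileInvertedResidualBlock.forward:
is_zero = torch.sum(torch.abs(out)).item() == 0 and x.size() == out.size()
assert is_zero
```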
40 changes: 40 additions & 0 deletions examples/nas/proxylessnas/putils.py
@@ -1,5 +1,7 @@
import torch
import torch.nn as nn


def get_parameters(model, keys=None, mode='include'):
if keys is None:
for name, param in model.named_parameters():
@@ -36,6 +38,7 @@ def get_same_padding(kernel_size):
assert kernel_size % 2 > 0, 'kernel size should be odd number'
return kernel_size // 2


def build_activation(act_func, inplace=True):
if act_func == 'relu':
return nn.ReLU(inplace=inplace)
@@ -65,3 +68,40 @@ def make_divisible(v, divisor, min_val=None):
if new_v < 0.9 * v:
new_v += divisor
return new_v


def accuracy(output, target, topk=(1,)):
""" Computes the precision@k for the specified values of k """
maxk = max(topk)
batch_size = target.size(0)

_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
# one-hot case
if target.ndimension() > 1:
target = target.max(1)[1]

correct = pred.eq(target.view(1, -1).expand_as(pred))

res = dict()
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
return res


class LabelSmoothingLoss(nn.Module):
def __init__(self, smoothing=0.1, dim=-1):
super(LabelSmoothingLoss, self).__init__()
self.confidence = 1.0 - smoothing
self.smoothing = smoothing
self.dim = dim

def forward(self, pred, target):
pred = pred.log_softmax(dim=self.dim)
num_classes = pred.size(self.dim)
with torch.no_grad():
true_dist = torch.zeros_like(pred)
true_dist.fill_(self.smoothing / (num_classes - 1))
true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
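A quick usage sketch of the two helpers added above (shapes and values are illustrative):

```python
import torch

logits = torch.randn(4, 10)           # batch of 4, 10 classes
target = torch.randint(0, 10, (4,))   # integer class labels

criterion = LabelSmoothingLoss(smoothing=0.1)
loss = criterion(logits, target)      # scalar label-smoothed cross-entropy

metrics = accuracy(logits, target, topk=(1, 5))
# e.g. {'acc1': 0.25, 'acc5': 0.75}
```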
6 changes: 6 additions & 0 deletions nni/nas/pytorch/mutables.py
@@ -109,7 +109,13 @@ class MutableScope(Mutable):
def __init__(self, key):
super().__init__(key=key)

def _check_built(self):
return True # bypass the test because it's deprecated

def __call__(self, *args, **kwargs):
if not hasattr(self, 'mutator'):
return super().__call__(*args, **kwargs)
warnings.warn("`MutableScope` is deprecated in Retiarii.", DeprecationWarning)
try:
self._check_built()
self.mutator.enter_mutable_scope(self)
13 changes: 13 additions & 0 deletions nni/retiarii/trainer/interface.py
@@ -1,4 +1,5 @@
import abc
from typing import *


class BaseTrainer(abc.ABC):
@@ -20,3 +21,15 @@ class BaseTrainer(abc.ABC):
@abc.abstractmethod
def fit(self) -> None:
pass


class BaseOneShotTrainer(BaseTrainer):
"""
Build many (possibly all) architectures into a full graph, search (while training) and export the best one.

It has an extra ``export`` function that exports an object representing the final searched architecture.
"""

@abc.abstractmethod
def export(self) -> Any:
pass
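A minimal sketch of a conforming subclass, assuming `fit` and `export` are the only abstract methods and that `Any` is in scope as in the module above (the class and its behavior are illustrative only):

```python
class ConstantTrainer(BaseOneShotTrainer):
    """Toy one-shot trainer: trains nothing and exports a fixed architecture."""

    def __init__(self, architecture):
        self.architecture = architecture

    def fit(self) -> None:
        pass  # a real trainer would run the weight-sharing search loop here

    def export(self) -> Any:
        return self.architecture
```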
5 changes: 5 additions & 0 deletions nni/retiarii/trainer/pytorch/__init__.py
@@ -0,0 +1,5 @@
from .base import PyTorchImageClassificationTrainer
from .darts import DartsTrainer
from .enas import EnasTrainer
from .proxyless import ProxylessTrainer
from .random import RandomTrainer, SinglePathTrainer