Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
Increase SPOS accuracy (#2902)
Browse files Browse the repository at this point in the history
* 🚧 fix spos always affine=False bug; change workers from to 4 to avoid dali breaking out; add dump_checkpoint to save the result searched by scratch.py

* 🐛 fix bug of import os

Co-authored-by: limingyao <limingyao@ainirobot.com>
  • Loading branch information
CuriousCat-7 and limingyao authored Sep 20, 2020
1 parent 8c71813 commit 4cf67f5
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 16 deletions.
15 changes: 8 additions & 7 deletions examples/nas/spos/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class ShuffleNetBlock(nn.Module):
When stride = 1, the block receives input with 2 * inp channels. Otherwise inp channels.
"""

def __init__(self, inp, oup, mid_channels, ksize, stride, sequence="pdp"):
def __init__(self, inp, oup, mid_channels, ksize, stride, sequence="pdp", affine=True):
super().__init__()
assert stride in [1, 2]
assert ksize in [3, 5, 7]
Expand All @@ -22,6 +22,7 @@ def __init__(self, inp, oup, mid_channels, ksize, stride, sequence="pdp"):
self.stride = stride
self.pad = ksize // 2
self.oup_main = oup - self.channels
self._affine = affine
assert self.oup_main > 0

self.branch_main = nn.Sequential(*self._decode_point_depth_conv(sequence))
Expand All @@ -31,10 +32,10 @@ def __init__(self, inp, oup, mid_channels, ksize, stride, sequence="pdp"):
# dw
nn.Conv2d(self.channels, self.channels, ksize, stride, self.pad,
groups=self.channels, bias=False),
nn.BatchNorm2d(self.channels, affine=False),
nn.BatchNorm2d(self.channels, affine=affine),
# pw-linear
nn.Conv2d(self.channels, self.channels, 1, 1, 0, bias=False),
nn.BatchNorm2d(self.channels, affine=False),
nn.BatchNorm2d(self.channels, affine=affine),
nn.ReLU(inplace=True)
)

Expand All @@ -61,12 +62,12 @@ def _decode_point_depth_conv(self, sequence):
assert pc == c, "Depth-wise conv must not change channels."
result.append(nn.Conv2d(pc, c, self.ksize, self.stride if first_depth else 1, self.pad,
groups=c, bias=False))
result.append(nn.BatchNorm2d(c, affine=False))
result.append(nn.BatchNorm2d(c, affine=self._affine))
first_depth = False
elif token == "p":
# point-wise conv
result.append(nn.Conv2d(pc, c, 1, 1, 0, bias=False))
result.append(nn.BatchNorm2d(c, affine=False))
result.append(nn.BatchNorm2d(c, affine=self._affine))
result.append(nn.ReLU(inplace=True))
first_point = False
else:
Expand All @@ -85,5 +86,5 @@ def _channel_shuffle(self, x):

class ShuffleXceptionBlock(ShuffleNetBlock):

def __init__(self, inp, oup, mid_channels, stride):
super().__init__(inp, oup, mid_channels, 3, stride, "dpdpdp")
def __init__(self, inp, oup, mid_channels, stride, affine=True):
super().__init__(inp, oup, mid_channels, 3, stride, "dpdpdp", affine)
15 changes: 8 additions & 7 deletions examples/nas/spos/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class ShuffleNetV2OneShot(nn.Module):
]

def __init__(self, input_size=224, first_conv_channels=16, last_conv_channels=1024, n_classes=1000,
op_flops_path="./data/op_flops_dict.pkl"):
op_flops_path="./data/op_flops_dict.pkl", affine=False):
super().__init__()

assert input_size % 32 == 0
Expand All @@ -36,11 +36,12 @@ def __init__(self, input_size=224, first_conv_channels=16, last_conv_channels=10
self._first_conv_channels = first_conv_channels
self._last_conv_channels = last_conv_channels
self._n_classes = n_classes
self._affine = affine

# building first layer
self.first_conv = nn.Sequential(
nn.Conv2d(3, first_conv_channels, 3, 2, 1, bias=False),
nn.BatchNorm2d(first_conv_channels, affine=False),
nn.BatchNorm2d(first_conv_channels, affine=affine),
nn.ReLU(inplace=True),
)
self._feature_map_size //= 2
Expand All @@ -54,7 +55,7 @@ def __init__(self, input_size=224, first_conv_channels=16, last_conv_channels=10

self.conv_last = nn.Sequential(
nn.Conv2d(p_channels, last_conv_channels, 1, 1, 0, bias=False),
nn.BatchNorm2d(last_conv_channels, affine=False),
nn.BatchNorm2d(last_conv_channels, affine=affine),
nn.ReLU(inplace=True),
)
self.globalpool = nn.AvgPool2d(self._feature_map_size)
Expand All @@ -75,10 +76,10 @@ def _make_blocks(self, blocks, in_channels, channels):
base_mid_channels = channels // 2
mid_channels = int(base_mid_channels) # prepare for scale
choice_block = mutables.LayerChoice([
ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=3, stride=stride),
ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=5, stride=stride),
ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=7, stride=stride),
ShuffleXceptionBlock(inp, oup, mid_channels=mid_channels, stride=stride)
ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=3, stride=stride, affine=self._affine),
ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=5, stride=stride, affine=self._affine),
ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=7, stride=stride, affine=self._affine),
ShuffleXceptionBlock(inp, oup, mid_channels=mid_channels, stride=stride, affine=self._affine)
])
result.append(choice_block)

Expand Down
18 changes: 16 additions & 2 deletions examples/nas/spos/scratch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import argparse
import logging
import random
Expand Down Expand Up @@ -70,12 +71,24 @@ def validate(epoch, model, criterion, loader, writer, args):
logger.info("Epoch %d validation: top1 = %f, top5 = %f", epoch + 1, meters.acc1.avg, meters.acc5.avg)


def dump_checkpoint(model, epoch, checkpoint_dir):
if isinstance(model, nn.DataParallel):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
dest_path = os.path.join(checkpoint_dir, "epoch_{}.pth.tar".format(epoch))
logger.info("Saving model to %s", dest_path)
torch.save(state_dict, dest_path)


if __name__ == "__main__":
parser = argparse.ArgumentParser("SPOS Training From Scratch")
parser.add_argument("--imagenet-dir", type=str, default="./data/imagenet")
parser.add_argument("--tb-dir", type=str, default="runs")
parser.add_argument("--architecture", type=str, default="architecture_final.json")
parser.add_argument("--workers", type=int, default=12)
parser.add_argument("--workers", type=int, default=4)
parser.add_argument("--batch-size", type=int, default=1024)
parser.add_argument("--epochs", type=int, default=240)
parser.add_argument("--learning-rate", type=float, default=0.5)
Expand All @@ -96,7 +109,7 @@ def validate(epoch, model, criterion, loader, writer, args):
random.seed(args.seed)
torch.backends.cudnn.deterministic = True

model = ShuffleNetV2OneShot()
model = ShuffleNetV2OneShot(affine=True)
model.cuda()
apply_fixed_architecture(model, args.architecture)
if torch.cuda.device_count() > 1: # exclude last gpu, saving for data preprocessing on gpu
Expand Down Expand Up @@ -124,5 +137,6 @@ def validate(epoch, model, criterion, loader, writer, args):
train(epoch, model, criterion, optimizer, train_loader, writer, args)
validate(epoch, model, criterion, val_loader, writer, args)
scheduler.step()
dump_checkpoint(model, epoch, "scratch_checkpoints")

writer.close()

0 comments on commit 4cf67f5

Please sign in to comment.