This repository has been archived by the owner on Sep 18, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1.8k
/
Copy pathtest_oneshot.py
151 lines (127 loc) · 5.87 KB
/
test_oneshot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import pytest
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data.sampler import RandomSampler
from nni.retiarii.evaluator.pytorch.lightning import Classification, DataLoader
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice
from nni.retiarii.oneshot.pytorch import (ConcatenateTrainValDataLoader,
DartsModule, EnasModule, SNASModule,
InterleavedTrainValDataLoader,
ProxylessModule, RandomSampleModule)
class DepthwiseSeparableConv(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.depthwise = nn.Conv2d(in_ch, in_ch, kernel_size=3, groups=in_ch)
self.pointwise = nn.Conv2d(in_ch, out_ch, kernel_size=1)
def forward(self, x):
return self.pointwise(self.depthwise(x))
class Net(pl.LightningModule):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = LayerChoice([
nn.Conv2d(32, 64, 3, 1),
DepthwiseSeparableConv(32, 64)
])
self.dropout1 = nn.Dropout(.25)
self.dropout2 = nn.Dropout(0.5)
self.dropout_choice = InputChoice(2, 1)
self.fc = LayerChoice([
nn.Sequential(
nn.Linear(9216, 64),
nn.ReLU(),
nn.Linear(64, 10),
),
nn.Sequential(
nn.Linear(9216, 128),
nn.ReLU(),
nn.Linear(128, 10),
),
nn.Sequential(
nn.Linear(9216, 256),
nn.ReLU(),
nn.Linear(256, 10),
)
])
self.rpfc = nn.Linear(10, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(self.conv2(x), 2)
x1 = torch.flatten(self.dropout1(x), 1)
x2 = torch.flatten(self.dropout2(x), 1)
x = self.dropout_choice([x1, x2])
x = self.fc(x)
x = self.rpfc(x)
output = F.log_softmax(x, dim=1)
return output
@pytest.mark.skipif(pl.__version__< '1.0', reason='Incompatible APIs')
def prepare_model_data():
base_model = Net()
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = MNIST('data/mnist', train = True, download=True, transform=transform)
train_random_sampler = RandomSampler(train_dataset, True, int(len(train_dataset) / 10))
train_loader = DataLoader(train_dataset, 64, sampler = train_random_sampler)
valid_dataset = MNIST('data/mnist', train = False, download=True, transform=transform)
valid_random_sampler = RandomSampler(valid_dataset, True, int(len(valid_dataset) / 10))
valid_loader = DataLoader(valid_dataset, 64, sampler = valid_random_sampler)
trainer_kwargs = {
'max_epochs' : 1
}
return base_model, train_loader, valid_loader, trainer_kwargs
@pytest.mark.skipif(pl.__version__< '1.0', reason='Incompatible APIs')
def test_darts():
base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data()
cls = Classification(train_dataloader=train_loader, val_dataloaders = valid_loader, **trainer_kwargs)
cls.module.set_model(base_model)
darts_model = DartsModule(cls.module)
para_loader = InterleavedTrainValDataLoader(cls.train_dataloader, cls.val_dataloaders)
cls.trainer.fit(darts_model, para_loader)
@pytest.mark.skipif(pl.__version__< '1.0', reason='Incompatible APIs')
def test_proxyless():
base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data()
cls = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **trainer_kwargs)
cls.module.set_model(base_model)
proxyless_model = ProxylessModule(cls.module)
para_loader = InterleavedTrainValDataLoader(cls.train_dataloader, cls.val_dataloaders)
cls.trainer.fit(proxyless_model, para_loader)
@pytest.mark.skipif(pl.__version__< '1.0', reason='Incompatible APIs')
def test_enas():
base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data()
cls = Classification(train_dataloader = train_loader, val_dataloaders=valid_loader, **trainer_kwargs)
cls.module.set_model(base_model)
enas_model = EnasModule(cls.module)
concat_loader = ConcatenateTrainValDataLoader(cls.train_dataloader, cls.val_dataloaders)
cls.trainer.fit(enas_model, concat_loader)
@pytest.mark.skipif(pl.__version__< '1.0', reason='Incompatible APIs')
def test_random():
base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data()
cls = Classification(train_dataloader = train_loader, val_dataloaders=valid_loader , **trainer_kwargs)
cls.module.set_model(base_model)
random_model = RandomSampleModule(cls.module)
cls.trainer.fit(random_model, cls.train_dataloader, cls.val_dataloaders)
@pytest.mark.skipif(pl.__version__< '1.0', reason='Incompatible APIs')
def test_snas():
base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data()
cls = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **trainer_kwargs)
cls.module.set_model(base_model)
proxyless_model = SNASModule(cls.module, 1, use_temp_anneal=True)
para_loader = InterleavedTrainValDataLoader(cls.train_dataloader, cls.val_dataloaders)
cls.trainer.fit(proxyless_model, para_loader)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--exp', type=str, default='all', metavar='E',
help='exp to run, default = all' )
args = parser.parse_args()
if args.exp == 'all':
test_darts()
test_proxyless()
test_enas()
test_random()
test_snas()
else:
globals()[f'test_{args.exp}']()