pso_trainer.py
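"""PSO-based trainer for LeNet on MNIST or CIFAR-10.

A PyTorch Lightning module that replaces gradient descent with Particle Swarm
Optimization: the Swarm holds candidate weight sets, training_step scores them
on each batch, and the PSO optimizer moves the particles. The description of
Swarm/PSO behaviour here is inferred from how they are called in this file,
not from the PSO module itself.
"""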
import argparse

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST, CIFAR10
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import TensorBoardLogger

from model import LeNet
from PSO import Swarm, PSO
class psoSystem(LightningModule):
    """Trains LeNet with Particle Swarm Optimization instead of backpropagation."""

    def __init__(self, hparams):
        super(psoSystem, self).__init__()
        self.hparams = hparams
        self.model = LeNet()
        self.criterion = nn.CrossEntropyLoss()
        # Each particle in the swarm is a candidate set of network weights,
        # scored with the shared cross-entropy criterion.
        self.swarm = Swarm(hparams.num_particles, self.criterion)

    def configure_optimizers(self):
        return PSO(self.swarm.particles,
                   self.hparams.cognitive_constant,
                   self.hparams.social_constant,
                   self.hparams.inertia)

    def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
                       second_order_closure, on_tpu, using_native_amp, using_lbfgs):
        # The PSO "optimizer" updates particle positions from the fitness
        # values the swarm computed in training_step.
        optimizer(self.swarm.fitness_list)
    def prepare_data(self):
        transform = transforms.Compose([transforms.ToTensor()])
        if self.hparams.dataset == 'mnist':
            self.train_dataset = MNIST(self.hparams.data_dir, train=True, download=True, transform=transform)
            self.test_dataset = MNIST(self.hparams.data_dir, train=False, download=True, transform=transform)
            self.train_dataset, self.valid_dataset = torch.utils.data.random_split(self.train_dataset, [50000, 10000])
        if self.hparams.dataset == 'cifar10':
            self.train_dataset = CIFAR10(self.hparams.data_dir, train=True, download=True, transform=transform)
            self.test_dataset = CIFAR10(self.hparams.data_dir, train=False, download=True, transform=transform)
            self.train_dataset, self.valid_dataset = torch.utils.data.random_split(self.train_dataset, [40000, 10000])

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=32, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.valid_dataset, batch_size=32, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=32, num_workers=4)
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        y_pred = torch.argmax(y_hat, dim=-1)
        # Fraction of correct predictions in this batch.
        correct = torch.mean(1.0 * (y_pred == y))
        return {'val_loss': loss, 'num_correct': correct, 'batch_size': torch.tensor(x.size(0))}

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        y_pred = torch.argmax(y_hat, dim=-1)
        correct = torch.mean(1.0 * (y_pred == y))
        return {'test_loss': loss, 'num_correct': correct, 'batch_size': torch.tensor(x.size(0))}
    def training_step(self, batch, batch_idx):
        # Score every particle on the current batch; the PSO optimizer then
        # moves the particles toward the best-performing positions.
        self.swarm.evaulate(batch)
    def validation_epoch_end(self, outputs):
        val_loss = torch.mean(torch.stack([output['val_loss'] for output in outputs]))
        acc = torch.mean(torch.stack([output['num_correct'] for output in outputs]))
        # Total number of validation samples (kept from the original; currently unused).
        batch_size = torch.sum(torch.stack([output['batch_size'] for output in outputs]))
        return {'progress_bar': {'accuracy': acc, 'val_loss': val_loss},
                'log': {'accuracy': acc, 'val_loss': val_loss}}

    def test_epoch_end(self, outputs):
        test_loss = torch.mean(torch.stack([output['test_loss'] for output in outputs]))
        test_acc = torch.mean(torch.stack([output['num_correct'] for output in outputs]))
        batch_size = torch.sum(torch.stack([output['batch_size'] for output in outputs]))
        return {'progress_bar': {'accuracy': test_acc, 'test_loss': test_loss},
                'log': {'accuracy': test_acc, 'test_loss': test_loss}}
    def forward(self, x):
        # The original line passed self.mask, which is never defined in this
        # module; LeNet is assumed here to take only the input tensor.
        return self.model(x)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    parser.add_argument('--data_dir', type=str, help='dir to store train data')
    parser.add_argument('--dataset', type=str, help='cifar10|mnist')
    # PSO hyperparameters required by psoSystem; the defaults below are
    # illustrative placeholders, not values taken from the PSO module.
    parser.add_argument('--num_particles', type=int, default=10)
    parser.add_argument('--cognitive_constant', type=float, default=2.0)
    parser.add_argument('--social_constant', type=float, default=2.0)
    parser.add_argument('--inertia', type=float, default=0.5)
    args = parser.parse_args()
    logger = TensorBoardLogger('tb_logs', name='Cifar_run')
    trainer = Trainer.from_argparse_args(args)
    trainer.logger = logger
    system = psoSystem(args)
    trainer.fit(system)
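# Example invocation (a sketch; assumes an older pytorch-lightning release whose
# Trainer.add_argparse_args exposes flags such as --max_epochs):
#   python pso_trainer.py --data_dir ./data --dataset mnist --max_epochs 10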