-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathclient.py
72 lines (61 loc) · 2.65 KB
/
client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import copy
import torch
from torch import optim, nn
from collections import defaultdict
from torch.utils.data import DataLoader
from utils.utils import HardNegativeMining, MeanReduction
class Client:
    """A federated-learning client that trains and evaluates a model on its own local dataset."""

    def __init__(self, args, dataset, model, test_client=False):
        """
        :param args: run configuration; this class reads ``bs``, ``hnm``, ``model``, ``num_epochs`` and the
            optimizer settings used in ``train``
        :param dataset: local dataset; must expose a ``client_name`` attribute
        :param model: model trained / evaluated in place by this client
        :param test_client: if True, no training loader is built (the client is evaluation-only)
        """
        self.args = args
        self.dataset = dataset
        self.name = self.dataset.client_name
        self.model = model
        # Evaluation-only clients never train, so they get no shuffled training loader.
        self.train_loader = DataLoader(self.dataset, batch_size=self.args.bs, shuffle=True, drop_last=True) \
            if not test_client else None
        self.test_loader = DataLoader(self.dataset, batch_size=1, shuffle=False)
        # Per-element losses (reduction='none'), ignoring label 255; the final reduction to a scalar is
        # done by self.reduction in run_epoch.
        self.criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='none')
        self.reduction = HardNegativeMining() if self.args.hnm else MeanReduction()

    def __str__(self):
        return self.name

    @staticmethod
    def update_metric(metric, outputs, labels):
        """
        Feed the arg-max predictions of ``outputs`` (over dim 1) and the ground-truth ``labels`` to ``metric``.

        :param metric: object exposing ``update(labels, prediction)`` (both as numpy arrays)
        :param outputs: model scores with classes along dim 1
        :param labels: ground-truth tensor
        """
        _, prediction = outputs.max(dim=1)
        labels = labels.cpu().numpy()
        prediction = prediction.cpu().numpy()
        metric.update(labels, prediction)

    def _device(self):
        """Return the device holding the model's parameters (CPU if the model has no parameters)."""
        try:
            return next(self.model.parameters()).device
        except StopIteration:
            return torch.device('cpu')

    def _get_outputs(self, images):
        """
        Run the forward pass and return the raw score tensor for the configured architecture.

        :raises NotImplementedError: if ``args.model`` is not a supported architecture
        """
        if self.args.model == 'deeplabv3_mobilenetv2':
            # This model returns a dict; the prediction is stored under 'out'.
            return self.model(images)['out']
        if self.args.model == 'resnet18':
            return self.model(images)
        raise NotImplementedError(f'Unsupported model: {self.args.model}')

    def run_epoch(self, cur_epoch, optimizer):
        """
        This method locally trains the model with the dataset of the client. It handles the training at mini-batch level
        :param cur_epoch: current epoch of training (not used by the update itself; kept for callers/logging)
        :param optimizer: optimizer used for the local training
        :return: mean of the per-batch training losses of this epoch (0.0 if the loader yielded no batch)
        """
        device = self._device()
        self.model.train()
        total_loss = 0.0
        num_steps = 0
        for images, labels in self.train_loader:
            images = images.to(device)
            # CrossEntropyLoss needs integer class targets.
            labels = labels.to(device).long()
            optimizer.zero_grad()
            outputs = self._get_outputs(images)
            # NOTE(review): assumes self.reduction is called as reduction(per_element_loss, labels) so it can
            # drop ignored (255) positions — confirm against utils.utils.
            loss = self.reduction(self.criterion(outputs, labels), labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_steps += 1
        return total_loss / num_steps if num_steps else 0.0

    def train(self):
        """
        This method locally trains the model with the dataset of the client. It handles the training at epochs level
        (by calling the run_epoch method for each local epoch of training)
        :return: length of the local dataset, copy of the model parameters
        :raises RuntimeError: if the client was built with ``test_client=True`` (no training loader)
        """
        if self.train_loader is None:
            raise RuntimeError(f'Client {self.name} is a test client and has no training loader')
        # NOTE(review): optimizer hyper-parameter names (lr, m, wd) are assumed; confirm against the arg parser.
        optimizer = optim.SGD(self.model.parameters(), lr=self.args.lr,
                              momentum=getattr(self.args, 'm', 0.0),
                              weight_decay=getattr(self.args, 'wd', 0.0))
        for epoch in range(self.args.num_epochs):
            self.run_epoch(epoch, optimizer)
        # Deep copy: the returned parameters are a snapshot, unaffected by later in-place updates to self.model.
        return len(self.dataset), copy.deepcopy(self.model.state_dict())

    def test(self, metric):
        """
        This method tests the model on the local dataset of the client.
        :param metric: StreamMetric object
        """
        device = self._device()
        self.model.eval()
        with torch.no_grad():
            for images, labels in self.test_loader:
                outputs = self._get_outputs(images.to(device))
                self.update_metric(metric, outputs, labels)