-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
42 lines (34 loc) · 961 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
import numpy as np
import random
class AvgMeter(object):
    """Track a running weighted average and a moving average of recent values.

    Public attributes:
        num: window size used by ``show``.
        val: the most recent value passed to ``update``.
        sum: running total of ``val * n`` over all updates.
        count: running total of ``n``.
        avg: ``sum / count`` (0 while nothing has been counted).
        losses: every value passed to ``update`` since the last ``reset``.
    """

    def __init__(self, num=40):
        """Create an empty meter.

        Args:
            num: how many of the most recent values ``show`` averages over.
        """
        self.num = num
        self.reset()

    def reset(self):
        """Clear all accumulated statistics and the recorded history."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        self.losses = []

    def update(self, val, n=1):
        """Record a new value.

        Args:
            val: the value to record (a tensor or a plain number).
            n: weight of ``val`` in the running average, e.g. a batch size.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard the division so an update with zero total weight does not
        # raise; the average then stays at its reset value.
        self.avg = self.sum / self.count if self.count else 0
        # NOTE: the history is only cleared by reset(), so it grows with
        # every call to update().
        self.losses.append(val)

    def show(self):
        """Return the mean of the last ``num`` recorded values as a tensor.

        Returns a zero tensor when there is nothing to average, mirroring
        the reset state of ``avg``. Plain numbers in the history are
        converted to tensors so they can be stacked.
        """
        start = max(len(self.losses) - self.num, 0)
        window = self.losses[start:]
        if not window:
            return torch.tensor(0.0)
        stacked = torch.stack([
            item if torch.is_tensor(item)
            else torch.as_tensor(item, dtype=torch.get_default_dtype())
            for item in window
        ])
        return torch.mean(stacked)
def set_seed(option):
    """Seed Python, NumPy and PyTorch RNGs and select the CUDA device.

    Args:
        option: mapping with a required ``'seed'`` entry and an optional
            ``'device'`` entry naming the CUDA device to make current.

    The device is only selected when one is given and CUDA is available,
    so the function also works on CPU-only hosts.
    """
    seed = option['seed']
    random.seed(seed)
    np.random.seed(seed)

    device = option.get('device')
    if device is not None and torch.cuda.is_available():
        torch.cuda.set_device(device)

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # The torch.backends.cudnn flags (deterministic / enabled / benchmark)
    # are deliberately left at their current values; only the RNGs are seeded.