-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path early_stopping.py
40 lines (34 loc) · 1.31 KB
/
early_stopping.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class EarlyStopping:
    """Signal when a minimised metric has stopped improving for ``patience`` epochs.

    Lower ``fitness`` values are treated as better, and a value equal to the best
    seen so far also counts as an improvement. This suits losses rather than
    higher-is-better scores.
    """

    def __init__(self, patience=100):
        """Create a stopper.

        Args:
            patience: Number of epochs without improvement after which
                ``__call__`` returns True. A falsy value (0 or None) disables
                early stopping.
        """
        self.best_fitness = float('inf')  # lowest fitness recorded so far
        self.best_epoch = 0  # epoch at which best_fitness was recorded
        # `or` maps 0/None to infinity, so the stop condition can never be met.
        self.patience = patience or float('inf')  # epochs to wait after fitness stops improving to stop

    @staticmethod
    def mean(x):
        """Return the arithmetic mean of the sized iterable ``x``.

        Raises:
            ZeroDivisionError: if ``x`` is empty.
        """
        return sum(x) / len(x)

    @staticmethod
    def log(path, line):
        """Append ``line`` followed by a newline to the text file at ``path``."""
        # `with` guarantees the handle is closed even if the write raises.
        with open(path, "a") as logger:
            logger.write(line + '\n')

    def __call__(self, epoch, fitness):
        """Record ``fitness`` for ``epoch`` and report whether training should stop.

        Args:
            epoch: Current epoch index.
            fitness: Metric value for this epoch; lower is better.

        Returns:
            True once ``epoch - best_epoch`` reaches ``patience``, else False.
            A stop message is printed whenever True is returned.
        """
        if fitness <= self.best_fitness:  # `<=`: ties also reset the patience counter
            self.best_epoch = epoch
            self.best_fitness = fitness
        delta = epoch - self.best_epoch  # epochs without improvement
        stop = delta >= self.patience  # stop training if patience exceeded
        if stop:
            print(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
                  f'Best results observed at epoch {self.best_epoch}')
        return stop
if __name__ == '__main__':
    # Replay a recorded validation-loss curve through EarlyStopping and report
    # the first epoch at which training would have been stopped.
    patience = 10
    early_stopper = EarlyStopping(patience)
    path = 'training/3/val_log.txt'
    with open(path, 'r') as file:
        # NOTE(review): as written, split(' ')[0] yields a token containing no
        # spaces, so the second split(' ') returns a one-element list and [2]
        # raises IndexError on every line. The outer delimiter was probably lost
        # when this file was whitespace-mangled (e.g. a literal tab) -- confirm
        # against the real val_log.txt format.
        # [1:-1] drops one leading and one trailing character (presumably
        # brackets around the number) -- confirm.
        losses = [float(line.split(' ')[0].split(' ')[2][1:-1])
                  for line in file if line.strip()]
    for epoch, loss in enumerate(losses):
        if early_stopper(epoch, loss):
            print(epoch)
            break  # training halts here; later epochs would never run