-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathutils.py
125 lines (108 loc) · 4.15 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""
Utility classes used by the translation code: a remote training monitor and a console progress bar.
"""
import numpy as np
import time
import sys
import json
class Monitor(object):
    '''
    Publishes per-epoch logs to a remote monitoring server over HTTP.

    @param root: base URL of the monitoring server
    '''

    def __init__(self, root='http://localhost:9000'):
        self.root = root

    def display(self, batch, logs=None):
        '''
        POST the logs, JSON-encoded, to ``<root>/publish/epoch/end/``.

        This is best-effort: a failure while serialising or sending the
        payload is reported with a printed warning instead of being raised.

        @param batch: value sent under the 'epoch' key
        @param logs: optional mapping of additional entries to send; an entry
                     named 'epoch' overrides ``batch``
        '''
        # Imported lazily so the module can be used without `requests`
        # as long as display() is never called.
        import requests

        send = {'epoch': batch}
        if logs:
            send.update(logs)
        try:
            requests.post(self.root + '/publish/epoch/end/',
                          {'data': json.dumps(send)})
        except Exception:
            # `Exception` rather than a bare `except:` so that Ctrl-C and
            # SystemExit still propagate out of a training loop.
            print('Warning: could not reach RemoteMonitor '
                  'root server at ' + str(self.root))
class Progbar(object):
    '''
    Console progress bar that also reports running averages of named values.
    '''

    def __init__(self, target, width=30, verbose=1, with_history=True):
        '''
        @param target: total number of steps expected
        @param width: number of character cells used for the bar itself
        @param verbose: 1 redraws the bar on every update; 2 prints a single
                        summary line once the target is reached; any other
                        value prints nothing
        @param with_history: when False, the accumulated values are discarded
                             at the start of every update
        '''
        self.width = width
        self.target = target
        self.sum_values = {}      # name -> [weighted sum, number of steps]
        self.unique_values = []   # names in first-seen order
        self.start = time.time()
        self.total_width = 0      # characters written by the last redraw
        self.seen_so_far = 0
        self.verbose = verbose
        self.with_history = with_history

    def _average(self, name):
        '''Return the step-weighted average accumulated for ``name``.'''
        total, steps = self.sum_values[name]
        return total / max(1, steps)

    def update(self, current, values=None):
        '''
        @param current: index of current step
        @param values: list of tuples (name, value_for_last_step).
                       The progress bar will display averages for these values.
        '''
        if values is None:
            values = []
        if not self.with_history:
            self.sum_values = {}
            self.unique_values = []

        # Each value is weighted by the number of steps since the last update.
        step = current - self.seen_so_far
        for k, v in values:
            if k not in self.sum_values:
                self.sum_values[k] = [v * step, step]
                self.unique_values.append(k)
            else:
                self.sum_values[k][0] += v * step
                self.sum_values[k][1] += step
        self.seen_so_far = current

        now = time.time()
        if self.verbose == 1:
            # Move the cursor back to the start of the previously drawn line.
            prev_total_width = self.total_width
            sys.stdout.write("\b" * prev_total_width)
            sys.stdout.write("\r")

            # Digit count taken from the decimal representation; unlike
            # log10 this does not fail for a target of zero or below.
            numdigits = len(str(max(int(self.target), 0)))
            barstr = '%%%dd/%%%dd [' % (numdigits, numdigits)
            bar = barstr % (current, self.target)

            # A non-positive target is treated as already complete
            # (avoids a division by zero).
            if self.target > 0:
                prog = float(current) / self.target
            else:
                prog = 1.0
            # Clamp so that overshooting the target cannot widen the bar.
            prog_width = min(self.width, int(self.width * prog))
            if prog_width > 0:
                bar += ('.' * (prog_width - 1))
                if current < self.target:
                    bar += '(-w-)'
                else:
                    bar += '(-v-)!!'
            bar += ('~' * (self.width - prog_width))
            bar += ']'
            sys.stdout.write(bar)
            self.total_width = len(bar)

            if current:
                time_per_unit = (now - self.start) / current
            else:
                time_per_unit = 0
            eta = time_per_unit * (self.target - current)

            info = ''
            if current < self.target:
                info += ' - ETA: %ds' % eta
            else:
                info += ' - %ds' % (now - self.start)
            for k in self.unique_values:
                avg = self._average(k)
                if k == 'perplexity' or k == 'PPL':
                    # These entries are displayed as exp(average).
                    avg = np.exp(avg)
                info += ' - %s: %.4f' % (k, avg)

            self.total_width += len(info)
            if prev_total_width > self.total_width:
                # Pad with spaces to overwrite leftovers of a longer line.
                info += ((prev_total_width - self.total_width) * " ")

            sys.stdout.write(info)
            sys.stdout.flush()

            if current >= self.target:
                sys.stdout.write("\n")

        if self.verbose == 2:
            if current >= self.target:
                info = '%ds' % (now - self.start)
                # NOTE(review): unlike verbose == 1, 'perplexity'/'PPL' are
                # printed here as the plain average, without exp() - confirm
                # whether that is intended.
                for k in self.unique_values:
                    info += ' - %s: %.4f' % (k, self._average(k))
                sys.stdout.write(info + "\n")

    def add(self, n, values=None):
        '''
        Advance the bar by ``n`` steps relative to the last seen step.

        @param n: number of steps to advance
        @param values: same as for update()
        '''
        self.update(self.seen_so_far + n, values)

    def clear(self):
        '''
        Reset the accumulated values, the step counter and the drawn width.
        The start time is left unchanged.
        '''
        self.sum_values = {}
        self.unique_values = []
        self.total_width = 0
        self.seen_so_far = 0