-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathlogz.py
115 lines (96 loc) · 3.51 KB
/
logz.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""Logging functionality, extending logging from https://github.com/berkeleydeeprlcourse/homework."""
import json
import os.path as osp, time, atexit, os
# Colour name -> SGR parameter used by colorize() in its "\x1b[<codes>m" escape.
# colorize(..., highlight=True) adds 10 to the value (the background variant in
# the standard ANSI palette).
# NOTE(review): 38 is not a plain 8-colour code in ECMA-48 (it normally starts an
# extended-colour sequence), so "crimson" may render inconsistently — confirm.
color2num = {
    'gray': 30,
    'red': 31,
    'green': 32,
    'yellow': 33,
    'blue': 34,
    'magenta': 35,
    'cyan': 36,
    'white': 37,
    'crimson': 38,
}
def colorize(string, color, bold=False, highlight=False):
    """Wrap *string* in an ANSI SGR escape sequence followed by a reset.

    Args:
        string: Value to colour; it is rendered with ``%s``.
        color: Key into ``color2num``; an unknown name raises ``KeyError``.
        bold: Also emit SGR parameter ``1``.
        highlight: Add 10 to the colour code (its background variant).

    Returns:
        ``'\\x1b[<codes>m' + str(string) + '\\x1b[0m'``.
    """
    sgr_codes = [str(color2num[color] + (10 if highlight else 0))]
    if bold:
        sgr_codes.append('1')
    return '\x1b[%sm%s\x1b[0m' % (';'.join(sgr_codes), string)
class G:
    """Module-wide logger state.

    The functions in this module read and mutate these as class attributes
    (``G.output_dir``, ``G.first_row``, ...); no instance is created here.
    """

    # Locations/handles populated by configure_output_dir().
    output_dir = None
    output_weights = None
    output_file = None

    # Tabular-logging bookkeeping for log_tabular()/dump_tabular().
    first_row = True
    log_headers = []
    log_current_row = {}

    # Index of the next weights checkpoint written by save_tf_weights().
    save_it = 0
def configure_output_dir(d=None, force=False):
    """Set the logging directory and open a fresh ``log.txt`` inside it.

    The directory defaults to ``experiments_data/temp/<unix time>``. A
    ``weights`` subdirectory is created for ``save_tf_weights`` checkpoints,
    and all tabular-logging state in ``G`` is reset.

    Args:
        d: Target directory; when falsy, the timestamped default is used.
        force: When False, refuse to use a directory that already exists.
            When True, an existing directory (including its ``weights``
            subdirectory) is reused and its ``log.txt`` is truncated.

    Raises:
        AssertionError: if ``force`` is False and the directory already exists.
    """
    G.output_dir = d or "experiments_data/temp/{}".format(int(time.time()))
    if not force and osp.exists(G.output_dir):
        # Explicit raise instead of ``assert`` so the overwrite guard still runs
        # under ``python -O``; AssertionError is kept for existing callers.
        raise AssertionError(
            "Log dir %s already exists! Delete it first or use a different dir" % G.output_dir)
    G.output_weights = "{}/weights".format(G.output_dir)
    # With force=True the weights directory may survive from an earlier run;
    # without exist_ok this raised FileExistsError. With force=False the dir was
    # just checked to be absent, so keep failing loudly if it appears anyway.
    os.makedirs(G.output_weights, exist_ok=force)
    # Release the handle from a previous configure call now rather than keeping
    # it open until interpreter exit (close() is idempotent, so the atexit hook
    # registered for it stays harmless).
    if G.output_file is not None:
        G.output_file.close()
    G.output_file = open(osp.join(G.output_dir, "log.txt"), 'w')
    atexit.register(G.output_file.close)
    G.first_row = True
    G.save_it = 0
    G.log_headers.clear()
    G.log_current_row.clear()
    print(colorize("Logging data to %s" % G.output_file.name, 'green', bold=True))
def log_tabular(key, value):
    """Record *value* for diagnostic *key* in the current row.

    Call this once for each diagnostic quantity, each iteration, then call
    ``dump_tabular()`` to emit the row. While on the first row, every new key
    becomes a column header; on later rows only those keys are accepted.

    Raises:
        AssertionError: if *key* was already set this iteration, or if a key
            that was not logged on the first row is introduced later.
    """
    # Reject duplicates before touching G.log_headers, so a rejected call on the
    # first row no longer leaves a duplicate column header behind.
    assert key not in G.log_current_row, "You already set %s this iteration. Maybe you forgot to call dump_tabular()" % key
    if G.first_row:
        G.log_headers.append(key)
    else:
        assert key in G.log_headers, "Trying to introduce a new key %s that you didn't include in the first iteration" % key
    G.log_current_row[key] = value
def dump_tabular():
    """Print the current row as a table and append it to the log file.

    Every column registered by ``log_tabular`` is emitted; columns not set this
    iteration appear as an empty string. On the console, values that define
    ``__float__`` are formatted with ``%8.3g``; the log file (if configured)
    receives tab-separated ``str(value)`` entries, preceded by a header line on
    the first dump. The current row is cleared afterwards so the next iteration
    can log again. Does nothing if no diagnostic has been logged yet.
    """
    if not G.log_headers:
        # Nothing logged yet (max() below used to raise ValueError here). Stay
        # in first-row mode so the next log_tabular() calls still define columns.
        return
    vals = []
    # Key column is at least 15 characters wide.
    max_key_len = max(15, max(len(key) for key in G.log_headers))
    fmt = "| %" + str(max_key_len) + "s | %15s |"
    # Row width: "| " + key column + " | " + 15-wide value column + " |".
    rule = "-" * (22 + max_key_len)
    print(rule)
    for key in G.log_headers:
        val = G.log_current_row.get(key, "")
        valstr = "%8.3g" % val if hasattr(val, "__float__") else val
        print(fmt % (key, valstr))
        vals.append(val)
    print(rule)
    if G.output_file is not None:
        if G.first_row:
            G.output_file.write("\t".join(G.log_headers))
            G.output_file.write("\n")
        G.output_file.write("\t".join(map(str, vals)))
        G.output_file.write("\n")
        G.output_file.flush()
    # Leave first-row mode even when no log file is configured; previously this
    # only happened inside the branch above, so without a file every later
    # log_tabular() call re-appended its key to G.log_headers.
    G.first_row = False
    G.log_current_row.clear()
def save_params(params):
    """Write *params* as JSON to ``params.json`` in the logging directory.

    The output uses two-space indentation and sorted keys. ``G.output_dir``
    must already be set (``configure_output_dir`` does this).
    """
    params_path = osp.join(G.output_dir, 'params.json')
    with open(params_path, 'w') as fh:
        serialized = json.dumps(params, indent=2, separators=(',', ': '), sort_keys=True)
        fh.write(serialized)
def load_params(dir, filename="params.json"):
    """Read and decode the JSON file at ``dir/filename``.

    ``save_params`` writes ``params.json`` into the configured output
    directory, which is the default *filename* here.

    Args:
        dir: Directory containing the file. (The name shadows the ``dir``
            builtin but is kept for keyword callers.)
        filename: Name of the JSON file inside *dir*.

    Returns:
        The decoded JSON value.
    """
    params_path = osp.join(dir, filename)
    with open(params_path, 'r') as fh:
        return json.load(fh)
def save_tf_weights(model):
    """Save *model*'s weights as the next numbered ``.h5`` checkpoint.

    The file is ``<G.output_weights>/<G.save_it>.h5``; the counter is advanced
    afterwards so successive calls write to new files.

    Args:
        model: Object exposing ``save_weights(path)`` (presumably a Keras /
            tf.keras model — not confirmed by this module).
    """
    checkpoint_name = '{}.h5'.format(G.save_it)
    weights_path = osp.join(G.output_weights, checkpoint_name)
    print(colorize("Saving model weights to %s" % weights_path, 'green'))
    model.save_weights(weights_path)
    # Only reached if save_weights() did not raise.
    G.save_it += 1