-
Notifications
You must be signed in to change notification settings - Fork 360
/
Copy pathutils.py
66 lines (48 loc) · 1.65 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import tensorflow as tf
import yaml
# Short alias for the TF-Slim library used below (model-variable lookup and
# checkpoint restore). `tf.contrib` only exists in TensorFlow 1.x.
slim = tf.contrib.slim
def _get_init_fn(FLAGS):
    """Build the function that warm-starts the model from a checkpoint.

    Adapted from TF-Slim. Returns a function run by the chief worker to
    warm-start the training. Note that the init_fn is only run when
    initializing the model during the very first global step.

    Args:
        FLAGS: Config object exposing ``loss_model_file`` (checkpoint path to
            restore from) and ``checkpoint_exclude_scopes`` (a comma-separated
            string of variable-name prefixes to skip, or a falsy value to
            restore every model variable).

    Returns:
        An init function run by the supervisor, as built by
        ``slim.assign_from_checkpoint_fn``.
    """
    tf.logging.info('Use pretrained model %s', FLAGS.loss_model_file)

    exclusions = ()
    if FLAGS.checkpoint_exclude_scopes:
        # Drop empty entries: a trailing or doubled comma (e.g. "a,b,") would
        # otherwise produce '' and, since every name starts with '', exclude
        # every variable from the restore.
        exclusions = tuple(
            scope.strip()
            for scope in FLAGS.checkpoint_exclude_scopes.split(',')
            if scope.strip())

    # TODO(sguada) variables.filter_variables()
    # Exclusion is by name prefix, so "conv1" also excludes "conv10/...".
    # str.startswith accepts a tuple; an empty tuple never matches.
    variables_to_restore = [
        var for var in slim.get_model_variables()
        if not var.op.name.startswith(exclusions)]

    # ignore_missing_vars=True: variables not present in the checkpoint are
    # skipped rather than causing the restore to fail.
    return slim.assign_from_checkpoint_fn(
        FLAGS.loss_model_file,
        variables_to_restore,
        ignore_missing_vars=True)
class Flag(object):
    """Attribute-style namespace built from keyword arguments.

    Every keyword passed to the constructor becomes an instance attribute,
    e.g. ``Flag(lr=0.1).lr == 0.1``. ``read_conf_file`` uses it to expose
    parsed config keys as ``FLAGS.<key>``.
    """

    def __init__(self, **attrs):
        # vars(self) is the instance __dict__; bulk-assign all attributes.
        vars(self).update(attrs)
def read_conf_file(conf_file):
    """Load a YAML config file into a ``Flag`` namespace.

    Args:
        conf_file: Path to a YAML file whose top-level node is a mapping.

    Returns:
        A ``Flag`` whose attributes are the document's top-level keys. An
        empty file yields a ``Flag`` with no attributes.

    Raises:
        TypeError: If the top-level YAML node is not a mapping.
    """
    with open(conf_file) as f:
        # safe_load only builds plain YAML types (dict, list, str, numbers...).
        # Bare yaml.load(f) can construct arbitrary Python objects from tagged
        # input, and on PyYAML >= 6 it raises TypeError because the Loader
        # argument is mandatory.
        conf = yaml.safe_load(f)
    if conf is None:
        # An empty document parses to None; treat it as an empty config.
        conf = {}
    if not isinstance(conf, dict):
        raise TypeError('Config file %r must contain a YAML mapping at the '
                        'top level, got %s' % (conf_file, type(conf).__name__))
    return Flag(**conf)
def mean_image_subtraction(image, means):
    """Subtract a per-channel mean from an image tensor.

    Args:
        image: Tensor split into 3 slices along axis 2 (presumably a
            [height, width, 3] image -- TODO confirm with callers).
        means: Indexable sequence of at least 3 numbers; ``means[i]`` is
            subtracted from channel ``i``.

    Returns:
        A float32 tensor with the same shape as ``image``, means removed.
    """
    n_channels = 3
    # Cast before splitting so the subtraction happens in floating point.
    per_channel = tf.split(tf.to_float(image), n_channels, 2)
    centered = [per_channel[idx] - means[idx] for idx in range(n_channels)]
    return tf.concat(centered, 2)
if __name__ == '__main__':
    # Manual smoke check: parse the sample config and print the
    # `loss_model_file` path it configures.
    print(read_conf_file('conf/mosaic.yml').loss_model_file)