main.py
import argparse
import json
import logging
import sys
import time
from functools import partial
from pathlib import Path
import tensorflow as tf
from inputs import *
from model_fns import *
from predict_fns import *
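# Note: this script uses TensorFlow 1.x APIs (tf.logging, tf.contrib, tf.ConfigProto)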
# This program was designed to work with multiple kinds of models, but currently only GPT2 is supported.
# The first element in the tuple is the model function, the second is the function called when predicting.
models = {
"GPT2": (gpt2_model, gpt2_predict)
}
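# To register another model, map its name to a (model_fn, predict_fn) pair above,
# e.g. "MyModel": (my_model_fn, my_predict_fn) (hypothetical names).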
inputs = {
"openwebtext": openwebtext, # Standard OpenWebtext input
"openwebtext_longbiased": openwebtext_longbiased, # OpenWebtext with a bias towards showing more long (>512 tokens) examples
"openwebtext_long": openwebtext_long, # Openwebtext that only shows long examples
}
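# The --model argument points to a JSON file of parameters. A rough sketch of the keys
# this script reads (values are illustrative placeholders, not the repo's defaults):
# {
#     "model": "GPT2",
#     "input": "openwebtext",
#     "model_path": "gs://your-bucket/model",
#     "encoder_path": "encoder",
#     "train_batch_size": 8,
#     "eval_batch_size": 8,
#     "predict_batch_size": 1,
#     "train_steps": 1000,
#     "eval_steps": 10,
#     "max_steps": 100000,
#     "iterations": 100,
#     "precision": "float32"
# }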
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tpu", type=str)  # Name of TPU to train on, if any
    parser.add_argument("--model", type=str)  # JSON file that contains model parameters
    parser.add_argument("--predict_file", type=str)  # File to take as input for predict
    parser.add_argument("--predict_text", type=str)  # Take string directly from args
    parser.add_argument("--top_k", type=int)  # Top K truncation parameter for text generation
    args = parser.parse_args()
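    # Example invocations (file names and TPU name are illustrative):
    #   python main.py --model your_model.json --tpu your-tpu-name
    #   python main.py --model your_model.json --predict_text "Hello there" --top_k 40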
    # Get prediction text
    predict_mode = False
    if args.predict_file is not None and args.predict_text is not None:
        # Check the conflicting case first, otherwise it can never trigger
        print("ERROR: Specify exactly one of --predict_file and --predict_text!")
        sys.exit()
    elif args.predict_file is not None:
        predict_mode = True
        with open(args.predict_file) as f:
            text = f.read()
    elif args.predict_text is not None:
        predict_mode = True
        text = args.predict_text
    # Setup logging
    Path("logs").mkdir(exist_ok=True)
    tf.logging.set_verbosity(logging.INFO)
    handlers = [
        logging.FileHandler("logs/{}.log".format(args.model)),
        logging.StreamHandler(sys.stdout)
    ]
    logger = logging.getLogger("tensorflow")
    logger.handlers = handlers
    # Read params of model
    with open(args.model, "r") as f:
        params = json.load(f)
    params["use_tpu"] = args.tpu is not None
    if args.top_k is not None:
        params["top_k"] = args.top_k
    if "precision" not in params:
        params["precision"] = "float32"  # float32 is the default anyway; the only other recognized dtype is "bfloat16"
    if "iterations" not in params:
        params["iterations"] = 1  # Because this controls how many samples are prefetched
    logger.info(params)
    model_fn, predict_fn = models[params["model"]]
    input_fn = inputs[params["input"]]
    if params["use_tpu"] and not predict_mode:
        # Resolve TPU cluster and runconfig
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(args.tpu)
        run_config = tf.contrib.tpu.RunConfig(
            model_dir=params["model_path"],
            cluster=tpu_cluster_resolver,
            save_checkpoints_secs=60 * 30,
            session_config=tf.ConfigProto(
                # allow_soft_placement=True,
                # log_device_placement=True
            ),
            tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=params["iterations"])
        )
        # Set up network
        network = tf.contrib.tpu.TPUEstimator(
            model_fn=model_fn,
            use_tpu=True,
            train_batch_size=params["train_batch_size"],  # These are the global sizes, must be divisible by replicas
            eval_batch_size=params["eval_batch_size"],
            predict_batch_size=params["predict_batch_size"],
            config=run_config,
            params=params)
    else:
        # Non TPU setup
        if not predict_mode:
            params["batch_size"] = params["train_batch_size"]
        else:
            params["batch_size"] = params["predict_batch_size"]
            from models.gpt2 import encoder
            enc = encoder.get_encoder(params["encoder_path"])
            tokens = enc.encode(text)
            params["text_len"] = len(tokens)
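            # GPT-2 models have a 1024-token context window, so clamp the prompt length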
            if params["text_len"] > 1024:
                params["text_len"] = 1024
        run_config = tf.estimator.RunConfig(
            model_dir=params["model_path"],
            session_config=tf.ConfigProto(
                # log_device_placement=True,
                # allow_soft_placement=True
            ),
        )
        network = tf.estimator.Estimator(
            model_fn=model_fn,
            config=run_config,
            params=params)
    if predict_mode:
        logger.info("Generating predictions...")
        predict_fn(network, text, params)
        sys.exit()

    # Train eval loop
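    # Trains for train_steps, evaluates for eval_steps, and repeats until
    # global_step exceeds max_steps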
    while True:
        start = time.time()
        network.train(
            input_fn=partial(input_fn, eval=False),
            steps=params["train_steps"])
        end = time.time()
        logger.info("\nTrain loop took {:.2f}s\n".format(end - start))

        eval_result = network.evaluate(
            input_fn=partial(input_fn, eval=True),
            steps=params["eval_steps"])
        logger.info("\nEval Results: {}\n".format(str(eval_result)))

        if network.get_variable_value("global_step") > params["max_steps"]:
            logger.info("Done!")
            break