"""
Train the Coattention Network for Query Answering
"""
import os
import json
import torch
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
import ref_networks as N
from data_util.vocab import get_glove
from data_util.data_batcher import get_batch_generator
from data_util.evaluate import exact_match_score, f1_score
from preprocessing.download_wordvecs import maybe_download
from config import Config
from data_utils import get_data
config = Config()
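# Assumption: Config (defined in config.py) exposes every hyperparameter and path referenced
# below (embedding_dim, vectors_cache, data_dir, hidden_size, learning_rate, batch_size, mode, ...).
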
# Load the GloVe embeddings together with the word2index / index2word mappings
glove_path = os.path.join(config.vectors_cache, "glove.6B.{}d.txt".format(config.embedding_dim))
if not os.path.exists(glove_path):
    print("\nDownloading wordvecs to {}".format(config.vectors_cache))
    if not os.path.exists(config.vectors_cache):
        os.makedirs(config.vectors_cache)
    maybe_download(config.glove_base_url, config.glove_filename, config.vectors_cache, 862182613)
emb_matrix, word2index, index2word = get_glove(glove_path, config.embedding_dim)
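# get_glove is expected to return the GloVe embedding matrix (num_words x embedding_dim) along
# with the word2index / index2word vocabulary mappings used by the batcher and the model.
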
train_context_path = os.path.join(config.data_dir, "train.context")
train_qn_path = os.path.join(config.data_dir, "train.question")
train_ans_path = os.path.join(config.data_dir, "train.span")
dev_context_path = os.path.join(config.data_dir, "dev.context")
dev_qn_path = os.path.join(config.data_dir, "dev.question")
dev_ans_path = os.path.join(config.data_dir, "dev.span")
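# Assumption: each *.context / *.question file holds one tokenized example per line, and the
# matching *.span file holds the answer start/end token indices consumed by get_batch_generator.
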
# @timeit
def step(model, optimizer, batch, params):
    """
    Run one batch of training.
    :return: summed batch loss (scalar tensor)
    """
    # Unpack one batch of (question, context, answer-span) tensors
    q_seq, q_mask, d_seq, d_mask, target_span = get_data(batch, config.mode.lower() == 'train')
    model.zero_grad()
    # The model returns a per-example loss for each (question, context, answer) triple;
    # sum it into a scalar before backpropagating
    loss, _, _ = model(q_seq, q_mask, d_seq, d_mask, target_span)
    loss = torch.sum(loss)
    loss.backward(retain_graph=True)
    # Clip the global gradient norm to stabilize training
    clip_grad_norm_(params, config.max_grad_norm)
    optimizer.step()
    return loss
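# A minimal usage sketch (hypothetical standalone call; train() below does exactly this for
# every batch yielded by get_batch_generator):
#     loss = step(model, optimizer, batch, params)
#     print("batch loss:", loss.item())
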
# @timeit
def evaluate(model, batch):
    """
    Compute the average F1 score of the model's predicted answer spans on one batch.
    :return: batch-averaged F1
    """
    # Unpack one batch of (question, context, answer-span) tensors
    q_seq, q_mask, d_seq, d_mask, target_span = get_data(batch, config.mode.lower() == 'train')
    with torch.no_grad():
        # The model returns a per-example loss plus predicted start/end positions
        loss, start_pos_pred, end_pos_pred = model(q_seq, q_mask, d_seq, d_mask, target_span)
    start_pos_pred = start_pos_pred.tolist()
    end_pos_pred = end_pos_pred.tolist()
    f1 = 0
    for i, (pred_ans_start, pred_ans_end, true_ans_tokens) in enumerate(
            zip(start_pos_pred, end_pos_pred, batch.ans_tokens)):
        pred_ans_tokens = batch.context_tokens[i][pred_ans_start: pred_ans_end + 1]
        prediction = " ".join(pred_ans_tokens)
        ground_truth = " ".join(true_ans_tokens)
        f1 += f1_score(prediction, ground_truth)
    f1 = f1 / (i + 1)
    return f1
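# Note: f1_score (from data_util.evaluate) compares the predicted answer string against the
# ground-truth answer string by token overlap; averaged over one batch it is a quick sanity
# metric rather than a full dev-set evaluation.
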
def train(context_path, qn_path, ans_path):
    """ Train the network """
    model = N.CoattentionModel(
        hidden_dim=config.hidden_size,
        maxout_pool_size=config.max_pool_size,
        emb_matrix=emb_matrix,
        max_dec_steps=config.max_dec_steps,
        dropout_ratio=config.fusion_dropout_rate
    )
    # Select the parameters which require grad / backpropagation
    params = list(filter(lambda p: p.requires_grad, model.parameters()))
    optimizer = optim.SGD(params, lr=config.learning_rate, weight_decay=config.l2_norm)
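    # weight_decay in torch.optim.SGD applies L2 regularization with coefficient config.l2_norm
    # to every parameter in `params`, so no explicit regularization term is added to the loss.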
    # Set up directories for this experiment
    if not os.path.exists(config.experiments_root_dir):
        os.makedirs(config.experiments_root_dir)
    serial_number = len(os.listdir(config.experiments_root_dir))
    if config.restore:
        serial_number -= 1  # Point at the most recent experiment instead of creating a new one
    experiment_dir = os.path.join(config.experiments_root_dir, 'experiment_{}'.format(serial_number))
    if not os.path.exists(experiment_dir):
        os.makedirs(experiment_dir)
    model_dir = os.path.join(experiment_dir, 'model')
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    bestmodel_dir = os.path.join(experiment_dir, 'bestmodel')
    if not os.path.exists(bestmodel_dir):
        os.makedirs(bestmodel_dir)
    # Save the configuration as config.json
    with open(os.path.join(experiment_dir, "config.json"), 'w') as fout:
        json.dump(vars(config), fout)
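    # Note: vars(config) assumes every attribute of Config is JSON-serializable; the saved
    # snapshot lets the run be reconstructed from its experiment directory.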
    iteration = 0
    if config.restore:
        saved_models = os.listdir(model_dir)
        if len(saved_models):
            print(saved_models)
            # Checkpoint names end in the iteration number, so the largest suffix is the latest
            saved_models = [int(name.split('-')[-1]) for name in saved_models]
            latest_iter = max(saved_models)
            checkpoint_name = "checkpoint-embed{}-iter-{}".format(config.embedding_dim, latest_iter)
            checkpoint_name = os.path.join(model_dir, checkpoint_name)
            state = torch.load(checkpoint_name)
            model.load_state_dict(state['model'])
            optimizer.load_state_dict(state['optimizer'])
            iteration = state['iter']
            print("Model restored from ", checkpoint_name)
        else:
            print("Training with fresh parameters")
    for batch in get_batch_generator(word2index, context_path, qn_path, ans_path,
                                     config.batch_size, config.context_len,
                                     config.question_len, discard_long=True):
        # Ignore partially filled batches
        if batch.batch_size < config.batch_size:
            continue
        # Take one training step
        loss = step(model, optimizer, batch, params)
        # Display results: print the loss on print_every iterations (f1 shown as -1 as a
        # placeholder) and run a batch evaluation on evaluate_every iterations
        if iteration % config.print_every == 0 and iteration % config.evaluate_every != 0:
            print("Iter {}\t\tloss : {}\tf1 : {}".format(iteration, "%.2f" % loss, "%.2f" % -1))
        if iteration % config.evaluate_every == 0:
            f1 = evaluate(model, batch)
            print("Iter {}\t\tloss : {}\tf1 : {}".format(iteration, "%.2f" % loss, "%.2f" % f1))
            # Maybe you want to do random evaluations as well for a sanity check
        # Save the model
        if iteration % config.save_every == 0:
            state = {
                'iter': iteration,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss': loss
            }
            checkpoint_name = "checkpoint-embed{}-iter-{}".format(config.embedding_dim, iteration)
            fname = os.path.join(model_dir, checkpoint_name)
            torch.save(state, fname)
        iteration += 1
if __name__ == '__main__':
    if config.mode == 'train':
        context_path = train_context_path
        qn_path = train_qn_path
        ans_path = train_ans_path
    else:
        context_path = dev_context_path
        qn_path = dev_qn_path
        ans_path = dev_ans_path
    train(context_path, qn_path, ans_path)
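# Run as a script, e.g. `python ref_train.py`; config.mode (set in config.py) selects the train
# or dev data files, but note that train() is invoked in both cases.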