# MemN2N: https://arxiv.org/abs/1503.08895
import MemN2N as mn
import babi_process as bp
import numpy as np
import tensorflow as tf  # 1.4
import os

data_path = './tasks_1-20_v1-2.tar/tasks_1-20_v1-2/en-10k/'  # bAbI version 1.2
print("data_path", data_path)
saver_path = "./saver/"

hop = 3
memory_capacity = 50  # paper section 4.2 (Training Details): memory is restricted to the most recent 50 sentences
vali_ratio = 0.1  # hold out 10% of the training set for validation (paper section 4.2)
embedding_size = 50  # 50 for joint training, 20 for independent training
lr = 0.01
def data_read_and_preprocess(data_path, memory_capacity, vali_ratio):
    # read raw data for all 20 bAbI tasks
    train, test = [], []
    for i in range(1, 21):
        train.append(bp.data_get(data_path, data_num=i, dataset='train', memory_capacity=memory_capacity))
        test.append(bp.data_get(data_path, data_num=i, dataset='test', memory_capacity=memory_capacity))
    # split off a validation set from the training data
    train, vali = bp.train_vali_split(train, vali_ratio)
    # build the vocabulary and find the maximum sentence length
    word_dict, rev_word_dict, maximum_word_in_sentence = bp.get_word_dict_and_maximum_word_in_sentence(train + vali + test)
    # vectorize: map words to indices
    train = bp.data_to_vector(train, word_dict, maximum_word_in_sentence, memory_capacity)
    vali = bp.data_to_vector(vali, word_dict, maximum_word_in_sentence, memory_capacity)
    test = bp.data_to_vector(test, word_dict, maximum_word_in_sentence, memory_capacity)
    return train, vali, test, word_dict, rev_word_dict, maximum_word_in_sentence
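
# Each vectorized example is a (story, question, answer) triple of index
# arrays; presumably story has shape [memory_capacity, maximum_word_in_sentence],
# question has shape [maximum_word_in_sentence], and answer is a single word
# index, as implied by the batch unpacking in train()/validation()/test() below.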
def merge_tasks(data):
    # concatenate the per-task example lists into one flat array
    return np.array(bp.merge_tasks(data))

def toNumpy(data, dtype=np.int32):
    # `data` is a column of an object array; tolist() unwraps it so that
    # np.array can build a dense array of the requested dtype
    return np.array(data.tolist(), dtype)
def train(model, data):
    batch_size = 32
    loss = 0
    np.random.shuffle(data)
    for i in range(int(np.ceil(len(data) / batch_size))):
        batch = data[batch_size * i: batch_size * (i + 1)]
        story = toNumpy(batch[:, 0], np.int32)
        question = toNumpy(batch[:, 1], np.int32)
        answer = toNumpy(batch[:, 2], np.int64).flatten()
        train_loss, _ = sess.run([model.cost, model.minimize],
                                 {model.story: story, model.question: question, model.answer: answer})
        loss += train_loss
    return loss / len(data)
def validation(model, data):
    batch_size = 128
    loss = 0
    for i in range(int(np.ceil(len(data) / batch_size))):
        batch = data[batch_size * i: batch_size * (i + 1)]
        story = toNumpy(batch[:, 0], np.int32)
        question = toNumpy(batch[:, 1], np.int32)
        answer = toNumpy(batch[:, 2], np.int64).flatten()
        vali_loss = sess.run(model.cost, {model.story: story, model.question: question, model.answer: answer})
        loss += vali_loss
    return loss / len(data)
def test(model, data):
    batch_size = 128
    correct = 0
    for i in range(int(np.ceil(len(data) / batch_size))):
        batch = data[batch_size * i: batch_size * (i + 1)]
        story = toNumpy(batch[:, 0], np.int32)
        question = toNumpy(batch[:, 1], np.int32)
        answer = toNumpy(batch[:, 2], np.int64).flatten()
        check = sess.run(model.correct_check, {model.story: story, model.question: question, model.answer: answer})
        correct += check
    return correct / len(data)
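
# A minimal sketch of decoding a single prediction back to a word, assuming
# the model exposes its output distribution as `model.pred` (not guaranteed
# by this file); rev_word_dict maps word indices back to words:
#
#   pred = sess.run(tf.argmax(model.pred, axis=1),
#                   {model.story: story, model.question: question})
#   print('target:', rev_word_dict[answer[0]], '\tpred:', rev_word_dict[pred[0]])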
def run(model, merge_train, merge_vali, merge_test, task_test, restore=0):
    if not os.path.exists(saver_path):
        os.makedirs(saver_path)
    # optionally resume from a saved checkpoint (restore = epoch number)
    if restore != 0:
        model.saver.restore(sess, saver_path + str(restore) + ".ckpt")
    for epoch in range(restore + 1, 2000 + 1):
        # learning-rate annealing: halve the rate every 5 epochs during the
        # first 20 (assumes MemN2N picks up model.lr each step, e.g. via a feed)
        if epoch <= 20 and epoch % 5 == 0:
            model.lr /= 2
        # train, validate, and test on the merged tasks
        train_loss = train(model, merge_train)
        vali_loss = validation(model, merge_vali)
        accuracy = test(model, merge_test)
        print("epoch:", epoch, "\ttrain_loss:", train_loss, "\tvali_loss:", vali_loss, "\taccuracy:", accuracy)
        # per-task test accuracy
        for index, task in enumerate(task_test):
            accuracy = test(model, np.array(task))
            print(index + 1, accuracy)
        # save the weights for this epoch
        model.saver.save(sess, saver_path + str(epoch) + '.ckpt')
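
# To resume training from a saved checkpoint, pass the epoch number whose
# weights were saved; e.g. run(model, merge_train, merge_vali, merge_test,
# task_test, restore=100) restores ./saver/100.ckpt and continues from epoch 101.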
(task_train, task_vali, task_test,
 word_dict, rev_word_dict, maximum_word_in_sentence) = data_read_and_preprocess(data_path, memory_capacity, vali_ratio)

# merge the 20 tasks into single train/vali/test sets (joint training) and
# free the per-task lists that are no longer needed
merge_train = merge_tasks(task_train)
del task_train
merge_vali = merge_tasks(task_vali)
del task_vali
merge_test = merge_tasks(task_test)  # task_test itself is kept for per-task evaluation

# the session is passed into MemN2N, which presumably builds the graph and
# initializes its variables internally
sess = tf.Session()
model = mn.MemN2N(sess, hop, maximum_word_in_sentence, len(word_dict), embedding_size, memory_capacity, lr=lr)
run(model, merge_train, merge_vali, merge_test, task_test)