# Some modules are from https://github.com/maxiaoba/GRAPE
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from models.gen_imp import get_gen_imp
from models.imputaion_model import LinearHead, LLMHead
from utils import produce_NA, get_main_device, compute_LLM_generation_metrics
import time
import matplotlib.pyplot as plt
from data_loader import load_data
from transformers import get_linear_schedule_with_warmup
from torch.utils.data import Dataset, DataLoader
import copy
from tqdm import tqdm
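# Container for a mini-batch of hypergraph chunks. from_data_list concatenates the
# per-chunk tensors and offsets each chunk's hyperedge indices by the number of
# hyperedges seen so far, so the combined node-edge affiliation stays consistent.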
class HyperBatch:
def __init__(self, train_hyper_node, hyperedge, train_ve_affiliation, train_labels, batch, train_tokens_emb):
self.train_hyper_node = train_hyper_node
self.hyperedge = hyperedge
self.train_ve_affiliation = train_ve_affiliation
self.train_labels = train_labels
self.batch = batch
self.train_tokens_emb = train_tokens_emb
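# Build one batched hypergraph from the chunks listed in `ids`: only the first half of
# each chunk's hyper nodes is taken (the second half is the mirrored, undirected copy),
# and the mirror is re-created once after concatenation.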
@staticmethod
def from_data_list(ids, train_hyper_node_all, hyperedge_all, train_ve_affiliation_all, train_labels_all, train_tokens_emb_all):
batch_train_hyper_node = []
batch_hyperedge = []
batch_train_ve_affiliation = []
batch_train_labels = []
batch_indicator = []
batch_train_tokens_emb = []
cumulative_edge = 0
for i in range(len(ids)):
num_edge = hyperedge_all[ids[i]].size(0)
num_node = train_hyper_node_all[ids[i]].size(0)
# hyper nodes: keep only the first half; the mirrored (undirected) copy is rebuilt below
batch_train_hyper_node.append(train_hyper_node_all[ids[i]][:int(num_node/2)])
# batch_train_hyper_node.append(train_hyper_node_all[ids[i]])
# token embeddings (left empty for LinearHead, which does not use train_tokens_emb)
if len(train_tokens_emb_all) > 0:
batch_train_tokens_emb.append(train_tokens_emb_all[ids[i]])
# hyperedges
batch_hyperedge.append(hyperedge_all[ids[i]])
train_ve_affiliation = train_ve_affiliation_all[ids[i]][:, :int(num_node/2)] + cumulative_edge
# train_ve_affiliation = train_ve_affiliation_all[ids[i]]+ cumulative_edge
batch_train_ve_affiliation.append(train_ve_affiliation)
batch_train_labels.append(train_labels_all[ids[i]])
batch_indicator.append(torch.full((num_edge,), i, dtype=torch.long))
cumulative_edge += num_edge
train_hyper_node = torch.cat(batch_train_hyper_node, dim=0)
hyperedge = torch.cat(batch_hyperedge, dim=0)
train_ve_affiliation = torch.cat(batch_train_ve_affiliation, dim=1)
train_labels = torch.cat(batch_train_labels, dim=0)
batch = torch.cat(batch_indicator)
if len(batch_train_tokens_emb) > 0:
train_tokens_emb = torch.cat(batch_train_tokens_emb, dim=0)
else:
train_tokens_emb = []
# make the affiliation undirected: mirror the (node, edge) pairs and duplicate the hyper nodes
train_ve_affiliation_reverse = train_ve_affiliation[[1, 0], :]
train_ve_affiliation = torch.cat([train_ve_affiliation, train_ve_affiliation_reverse], dim=1)
train_hyper_node = torch.cat([train_hyper_node, train_hyper_node], dim=0)
return HyperBatch(train_hyper_node, hyperedge, train_ve_affiliation, train_labels, batch, train_tokens_emb)
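# Move the batched tensors to `device`; train_tokens_emb is not moved here and is moved
# separately by the caller (see the LLM branch of train_model).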
def to(self, device):
self.train_hyper_node = self.train_hyper_node.to(device)
self.hyperedge = self.hyperedge.to(device)
self.train_ve_affiliation = self.train_ve_affiliation.to(device)
self.train_labels = self.train_labels.to(device)
self.batch = self.batch.to(device)
return self
# Generate imputations auto-regressively (greedy decoding)
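# At each decoding step the frozen LM encodes the current prefix, the imputation head maps
# its hidden states to vocabulary logits conditioned on the two hyperedge embeddings, and
# the arg-max token is appended until every sequence emits EOS or max_new_tokens is reached.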
def generate_impute(args, embedding, impute_model, test_ve_affiliation, lm_model, tokenizer, x_text_test, max_new_tokens=16):
impute_model.eval()
lm_model.eval()
batch_size = len(x_text_test)
if batch_size == 0:
return [], []
inputs = tokenizer(x_text_test, padding=True, truncation=True, return_tensors="pt")
device = get_main_device(lm_model)
inputs = {k: v.to(device) for k, v in inputs.items()}
new_text = []
with torch.no_grad():
generated = inputs['input_ids']
attention_mask = inputs['attention_mask']
for _ in range(max_new_tokens):
generated = generated.to(device)
attention_mask = attention_mask.to(device)
outputs = lm_model.model(
input_ids=generated,
attention_mask=attention_mask,
return_dict=True
)
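# Move the LM hidden states onto the device that holds the GNN embeddings and the
# imputation head (assumed here to be the CUDA device indexed by args.device).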
hidden_states = outputs.last_hidden_state.to(f'cuda:{args.device}')
logits = impute_model([embedding[test_ve_affiliation[0]], embedding[test_ve_affiliation[1]]], hidden_states)
next_token_logits = logits[:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1)
generated = torch.cat([generated, next_token.unsqueeze(-1).to(device)], dim=-1)
new_text.append(next_token.unsqueeze(-1).to(device))
attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1), device=device)], dim=-1)
if (next_token == tokenizer.eos_token_id).all():
break
# Concatenate new_text tensors along the sequence dimension
new_text_tensor = torch.cat(new_text, dim=1)
return [tokenizer.decode(gen.squeeze(), skip_special_tokens=True) for gen in new_text_tensor], [tokenizer.decode(gen, skip_special_tokens=True) for gen in generated]
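# Plot a metric curve and annotate its minimum and final values; figures are saved under
# ./figures/, which is assumed to exist.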
def plot(data, figure_name):
plt.figure(figsize=(10, 6))
plt.plot(data)
min_value = min(data)
last_value = data[-1]
min_index = data.index(min_value)
last_index = len(data) - 1
plt.annotate(f'Min: {min_value}',
xy=(min_index, min_value),
xytext=(0.05, 0.95),
textcoords='axes fraction',
arrowprops=dict(facecolor='blue', shrink=0.05))
plt.annotate(f'Last: {last_value}',
xy=(last_index, last_value),
xytext=(0.95, 0.95),
textcoords='axes fraction',
ha='right',
arrowprops=dict(facecolor='red', shrink=0.05))
plt.title(figure_name)
plt.xlabel('Index')
plt.ylabel('Value')
plt.savefig("./figures/"+figure_name+".png")
plt.close()
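# End-to-end training loop:
#   1. load_data returns the chunked hypergraph tensors plus the LM, tokenizer, token
#      embeddings, and test-node text needed for the LLM head.
#   2. get_gen_imp builds the GNN that produces hyperedge embeddings; the imputation head
#      is either a LinearHead (numeric values, Huber loss) or an LLMHead (vocabulary
#      logits, cross-entropy against the label tokens).
#   3. Every args.eval_epoch_gap epochs the models are checkpointed to ./saved_models/ and
#      evaluated: RMSE/MAE for the Linear head, text-generation metrics (BLEU, ROUGE, etc.)
#      for the LLM head.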
def train_model(args, device=torch.device('cpu')):
lm_model, tokenizer, hyperedge, train_hyper_node, train_ve_affiliation, train_labels, test_hyper_node, test_ve_affiliation, test_labels, dataset, chunk_map, train_tokens_emb_LLM, test_node_text = load_data(args)
model = get_gen_imp(hyperedge[0].shape[1], train_hyper_node[0].shape[1], args).to(device)
if args.header_type == "Linear":
# impute_hiddens = list(map(int,args.impute_hiddens.split('_')))
impute_hiddens = hyperedge[0].shape[1]
input_dim = args.hyperedge_dim_hidden * 2
output_dim = 1
impute_model = LinearHead(input_dim, output_dim,
hidden_layer_sizes=impute_hiddens,
hidden_activation=args.impute_activation,
dropout=args.dropout).to(device)
test_labels_all = [item.clone().detach() for item in test_labels]
if args.load_model_name != "None":
model.load_state_dict(torch.load(f"./saved_models/llm_gnn_model_{args.load_model_name}.pth"))
impute_model.load_state_dict(torch.load(f"./saved_models/llm_impute_model_{args.load_model_name}.pth"))
elif args.header_type == "LLM":
impute_hiddens = hyperedge[0].shape[1]
input_dim = args.hyperedge_dim_hidden * 2
output_dim = args.vocab_size # vocab_size for LlamaLite
impute_model = LLMHead(input_dim, output_dim,
hidden_layer_sizes=impute_hiddens,
hidden_activation=args.impute_activation,
dropout=args.dropout,
relation_type=args.relation_type)
impute_model.lm_head.weight.data = lm_model.lm_head.weight.data.clone()
if lm_model.lm_head.bias is not None:
impute_model.lm_head.bias.data = lm_model.lm_head.bias.data.clone()
else:
impute_model.lm_head.bias.data = torch.zeros_like(impute_model.lm_head.bias.data)
impute_model = impute_model.to(device)
test_labels_all = [copy.deepcopy(item) for item in test_labels]
if args.load_model_name != "None":
model.load_state_dict(torch.load(f"./saved_models/llm_gnn_model_{args.load_model_name}.pth"))
impute_model.load_state_dict(torch.load(f"./saved_models/llm_impute_model_{args.load_model_name}.pth"))
trainable_parameters = list(model.parameters()) \
+ list(impute_model.parameters())
print(model)
print(impute_model)
trainable_params = [p for p in model.parameters() if p.requires_grad]
trainable_params_impute = [p for p in impute_model.parameters() if p.requires_grad]
print('total trainable params in GNN model:', sum(p.numel() for p in trainable_params))
print('total trainable params in impute model:', sum(p.numel() for p in trainable_params_impute))
filter_fn = filter(lambda p : p.requires_grad, trainable_parameters)
optimizer = torch.optim.AdamW(filter_fn, lr=args.lr, weight_decay=args.weight_decay)
train_hyper_node_all = [item.clone().detach() for item in train_hyper_node]
hyperedge_all = [item.clone().detach() for item in hyperedge]
train_ve_affiliation_all = [item.clone().detach() for item in train_ve_affiliation]
train_labels_all = [item.clone().detach() for item in train_labels]
test_hyper_node_all = [item.clone().detach() for item in test_hyper_node]
test_ve_affiliation_all = [item.clone().detach() for item in test_ve_affiliation]
train_tokens_emb_all = [item.clone().detach() for item in train_tokens_emb_LLM]
start_time = time.time()
loss_all = [[] for i in range(len(dataset))]
rmse_all = [[] for i in range(len(dataset))]
mae_all = [[] for i in range(len(dataset))]
# print(chunk_map)
train_ids = list(range(len(chunk_map)))
train_loader = DataLoader(train_ids, batch_size=args.chunk_batch, shuffle=False)
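# Chunk ids are batched by the DataLoader; each step rebuilds a HyperBatch from the cached
# per-chunk tensors (shuffle=False keeps the chunk order fixed).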
train_loss = 0
for epoch in tqdm(range(args.epochs)):
# for epoch in range(args.epochs):
for ids in train_loader:
batch = HyperBatch.from_data_list(ids,
train_hyper_node_all, hyperedge_all,
train_ve_affiliation_all, train_labels_all, train_tokens_emb_all)
train_hyper_node = batch.train_hyper_node.to(device)
hyperedge = batch.hyperedge.to(device)
train_ve_affiliation = batch.train_ve_affiliation.to(device)
train_labels = batch.train_labels.to(device)
model.train()
impute_model.train()
optimizer.zero_grad()
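# Randomly keep roughly args.known of the observed hyper nodes (produce_NA drops the rest);
# the mask is duplicated so both directions of each (node, edge) pair are masked consistently.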
known_mask = produce_NA(train_hyper_node[:int(train_hyper_node.shape[0]/2)], p_miss=1-args.known, mecha="Random")
known_mask_dup = torch.cat((known_mask, known_mask), dim=0)
known_hyper_node = train_hyper_node.clone().detach()
known_ve_affiliation = train_ve_affiliation.clone().detach()
known_hyper_node = known_hyper_node[known_mask_dup]
known_ve_affiliation = known_ve_affiliation[:,known_mask_dup]
embedding, hyper_node = model(hyperedge, known_hyper_node, known_ve_affiliation)
if args.header_type == "Linear":
train_tokens_emb = batch.train_tokens_emb
pred = impute_model([embedding[train_ve_affiliation[0, :int(train_hyper_node.shape[0]/2)]], embedding[train_ve_affiliation[1, :int(train_hyper_node.shape[0]/2)]]], train_tokens_emb)
pred_train = pred[:int(train_hyper_node.shape[0] / 2),0]
label_train = train_labels
huber_loss = torch.nn.HuberLoss(delta=1)
loss = huber_loss(pred_train, label_train)
elif args.header_type == "LLM":
train_tokens_emb = batch.train_tokens_emb.to(device)
# print(f"the shape of train_tokens_emb is : {train_tokens_emb.shape}, the shape of embedding is : {embedding.shape}, the shape of train_ve_affiliation is : {train_ve_affiliation.shape}, the shape of train_labels is : {train_labels.shape}")
pred = impute_model([embedding[train_ve_affiliation[0, :int(train_hyper_node.shape[0]/2)]], embedding[train_ve_affiliation[1, :int(train_hyper_node.shape[0]/2)]]], train_tokens_emb)
pred_train = pred[:int(train_hyper_node.shape[0] / 2)]
label_train = train_labels
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(pred_train.view(-1, pred_train.size(-1)), label_train.view(-1))
loss.backward()
optimizer.step()
train_loss = loss.item()
if (epoch + 1) % args.eval_epoch_gap == 0:
model.eval()
impute_model.eval()
torch.save(model.state_dict(), f"./saved_models/llm_gnn_model_{args.header_type}_Epoch{epoch}_{args.save_name}.pth")
torch.save(impute_model.state_dict(), f"./saved_models/llm_impute_model_{args.header_type}_Epoch{epoch}_{args.save_name}.pth")
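# Evaluation: for every dataset k, embeddings are recomputed per chunk from the full
# (unmasked) training hyper nodes, then the head imputes the held-out test entries.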
with torch.no_grad():
for k in range(len(dataset)):
dataset_chunk = [item==k for item in chunk_map]
pred_test_all = []
label_test_all = []
query_test_all = []
for i in range(len(chunk_map)):
if not dataset_chunk[i]:
continue
train_hyper_node = train_hyper_node_all[i].to(device)
hyperedge = hyperedge_all[i].to(device)
train_ve_affiliation = train_ve_affiliation_all[i].to(device)
test_hyper_node = test_hyper_node_all[i].to(device)
test_ve_affiliation = test_ve_affiliation_all[i].to(device)
test_labels = test_labels_all[i]
embedding, hyper_node = model(hyperedge, train_hyper_node, train_ve_affiliation)
if args.header_type == "Linear":
pred = impute_model([embedding[test_ve_affiliation[0], :], embedding[test_ve_affiliation[1], :]], token_emb=[])
pred_test_all.append(pred[:int(test_hyper_node.shape[0] / 2),0])
label_test_all.append(test_labels.to(device))
elif args.header_type == "LLM":
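# Only the first half of the test queries is kept; the second half appears to be the
# mirrored (undirected) duplicate of the same entries, so generating for it would
# repeat each query.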
x_text_test = test_node_text[i]
half_length = len(x_text_test) // 2
x_text_test = x_text_test[:half_length]
test_ve_affiliation = test_ve_affiliation[:, :half_length]
# print(f"the x_text_test is : {x_text_test}")
# for n in range(len(x_text_test)):
# print(f"the query is : {x_text_test[n]}")
new_results, full_results = generate_impute(args, embedding, impute_model, test_ve_affiliation,
lm_model, tokenizer, x_text_test, max_new_tokens=5)
pred_test_all.append(new_results)
# pred_test_all.append(x_text_test)
label_test_all.append(test_labels[:half_length])
query_test_all.append(x_text_test)
if args.header_type == "Linear":
pred_test = torch.cat(pred_test_all)
label_test = torch.cat(label_test_all)
mse = F.mse_loss(pred_test, label_test)
test_rmse = np.sqrt(mse.item())
l1 = F.l1_loss(pred_test, label_test)
test_l1 = l1.item()
print(f"=== {dataset[k]}, pred_test shape: {pred_test.shape}, label_test shape: {label_test.shape} ===")
print('epoch: ', epoch)
print('loss: ', train_loss)
print('test rmse: ', test_rmse)
print('test l1: ', test_l1)
print(f"training time: {time.time()-start_time:.4g}s")
loss_all[k].append(train_loss)
rmse_all[k].append(test_rmse)
mae_all[k].append(test_l1)
elif args.header_type == "LLM":
print(f"=== In the LLM, epoch: {epoch}, dataset: {dataset[k]} ===")
print(f"training time: {time.time()-start_time:.4g}s")
# print('loss: ', train_loss)
# print(pred_test_all[0][0:5], label_test_all[0][0:5])
# for m in range(len(pred_test_all)):
# for n in range(len(pred_test_all[m])):
# print("============")
# print(f"m is {m}, n is {n}, the length of pred_test_all is : {len(pred_test_all)}, the length of label_test_all is : {len(label_test_all)}, the length of test_node_text is : {len(test_node_text)}, the len(chunk_map) is : {len(chunk_map)}, the length of pred_test_all[m] is : {len(pred_test_all[m])}, the length of test_node_text[m] is : {len(test_node_text[m])}, the length of label_test_all[m] is : {len(label_test_all[m])}")
# print(f"the total query is : {query_test_all[m]}")
# print(f"the query is : {query_test_all[m][n]}")
# print(f"the pred is : {pred_test_all[m][n]}")
# print(f"the label is : {label_test_all[m][n]}")
avg_bleu, avg_rouge_1, avg_rouge_l, avg_rouge_lsum, avg_rouge_w, avg_rouge_s, avg_jaccard, avg_levenshtein, avg_cosine, avg_cosine_tf, avg_cosine_tfidf, avg_cosine_word_embeddings = compute_LLM_generation_metrics(pred_test_all, label_test_all)
print(f"Average BLEU Score: {avg_bleu:.4f}")
print(f"Average ROUGE-1 Score: {avg_rouge_1:.4f}")
print(f"Average ROUGE-L Score: {avg_rouge_l:.4f}")
print(f"Average ROUGE-Lsum Score: {avg_rouge_lsum:.4f}")
print(f"Average ROUGE-W Score: {avg_rouge_w:.4f}")
print(f"Average ROUGE-S Score: {avg_rouge_s:.4f}")
print(f"Average Jaccard Similarity: {avg_jaccard:.4f}")
print(f"Average Levenshtein Distance: {avg_levenshtein:.4f}")
print(f"Average Cosine Similarity: {avg_cosine:.4f}")
print(f"Average Cosine Similarity (TF): {avg_cosine_tf:.4f}")
print(f"Average Cosine Similarity (TF-IDF): {avg_cosine_tfidf:.4f}")
print(f"Average Cosine Similarity (Word Embeddings): {avg_cosine_word_embeddings:.4f}")
if args.header_type == "Linear":
print(f"==== Dataset: {dataset} =====")
print(f"Training time taken: {time.time()-start_time:.4g}s")
print(f"Missing Mechanism: {args.missing_mechanism}, miss_rate: {args.missing_ratio}, RMSE: {rmse_all[-1][-1]}, MAE: {mae_all[-1][-1]}")
for k in range(len(dataset)):
plot(loss_all[k], dataset[k]+"_"+args.plot_name+"_loss_all")
plot(rmse_all[k], dataset[k]+"_"+args.plot_name+"_rmse_all")
plot(mae_all[k], dataset[k]+"_"+args.plot_name+"_mae_all")