-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
192 lines (175 loc) · 11.8 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import os
import argparse
import logging
import torch
from warnings import simplefilter
from scipy.sparse import SparseEfficiencyWarning
from datetime import datetime
from subgraph_extraction.datasets import SubgraphDataset, generate_subgraph_datasets, generate_subgraph_datasets_test
from utils.initialization_utils import initialize_experiment, initialize_model
from utils.graph_utils import collate_dgl, move_batch_to_device_dgl, collate_dgl2
from model.graph_model import GraphClassifier as dgl_model
from manager.evaluator import Evaluator
from manager.trainer import Trainer
def main(params):
    """Run the full training pipeline.

    Builds (or reuses) the cached subgraph database for the configured
    dataset, loads the four train/valid dataset splits ('one', 'one2',
    'many', 'many2'), initializes the graph classifier, and hands
    everything to the Trainer. Total wall-clock time is logged.

    Args:
        params: argparse.Namespace populated by the CLI below, plus the
            attributes attached after parsing (device, file_paths,
            collate_fn, move_batch_to_device).
    """
    startime = datetime.now()
    logging.info("startime:" + str(startime))

    # These warnings fire constantly during sparse subgraph extraction and
    # drown out real log output, so silence them up front.
    simplefilter(action='ignore', category=UserWarning)
    simplefilter(action='ignore', category=SparseEfficiencyWarning)

    params.db_path = os.path.join(params.main_dir,
                                  f'data/{params.dataset}/subgraphs_en_{params.enclosing_sub_graph}_neg_{params.num_neg_samples_per_link}_hop_{params.hop}')
    # The subgraph DB is expensive to build; only regenerate when missing.
    if not os.path.isdir(params.db_path):
        generate_subgraph_datasets(params)

    def _load_split(db_prefix, split, file_name):
        # All splits share identical construction parameters; only the DB
        # key prefix (one/one2/many/many2), the split (train/valid) and the
        # source triplet file differ.
        return SubgraphDataset(params.db_path,
                               f'{db_prefix}_{split}_pos',
                               f'{db_prefix}_{split}_neg',
                               params.file_paths,
                               add_traspose_rels=params.add_traspose_rels,
                               num_neg_samples_per_link=params.num_neg_samples_per_link,
                               use_kge_embeddings=params.use_kge_embeddings,
                               dataset=params.dataset,
                               kge_model=params.kge_model,
                               file_name=file_name)

    prefixes = ['one', 'one2', 'many', 'many2']
    # BUG FIX: the original train list was
    # [train_one, train_one2, train_many, train_many, train_many2] —
    # 'train_many' appeared twice, yielding 5 train datasets against only
    # 4 valid datasets. Each split now appears exactly once, in the same
    # order for train and valid.
    train = [_load_split(p, 'train', params.train_file) for p in prefixes]
    valid = [_load_split(p, 'valid', params.valid_file) for p in prefixes]

    # Model hyper-parameters derived from the first train split; all splits
    # come from the same DB, so these values are shared.
    params.num_rels = train[0].num_rels
    params.aug_num_rels = train[0].aug_num_rels
    params.inp_dim = train[0].n_feat_dim
    params.max_label_value = train[0].max_n_label

    graph_classifier = initialize_model(params, dgl_model, params.load_model)

    logging.info(f"Device: {params.device}")
    logging.info(
        f"Input dim : {params.inp_dim}, # Relations : {params.num_rels}, # Augmented relations : {params.aug_num_rels}")

    trainer = Trainer(params, graph_classifier, train, valid)
    logging.info('Starting training with full batch...')
    trainer()

    endtime = datetime.now()
    logging.info("endtime:" + str(endtime))
    runningtime = endtime - startime
    logging.info("runningtime:" + str(runningtime))
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    def _str2bool(value):
        """Parse a command-line boolean robustly.

        BUG FIX: the original arguments used ``type=bool``, but Python's
        ``bool()`` treats ANY non-empty string as True, so e.g.
        ``--add_traspose_rels False`` silently enabled the flag. This
        parser accepts the usual true/false spellings (case-insensitive)
        and rejects anything else with a clear error.
        """
        if isinstance(value, bool):
            return value
        lowered = value.lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")

    parser = argparse.ArgumentParser(description='No model')

    # Experiment setup params
    parser.add_argument("--experiment_name", "-e", type=str, default="WN18RR_v1",
                        help="A folder with this name would be created to dump saved models and log files")
    parser.add_argument("--dataset", "-d", type=str, default="WN18RR_v1",
                        help="Dataset string")
    parser.add_argument("--gpu", type=int, default=0,
                        help="Which GPU to use?")
    parser.add_argument('--disable_cuda', action='store_true',
                        help='Disable CUDA')
    parser.add_argument('--load_model', action='store_true',
                        help='Load existing model?')
    parser.add_argument("--train_file", "-tf", type=str, default="train",
                        help="Name of file containing training triplets")
    parser.add_argument("--valid_file", "-vf", type=str, default="valid",
                        help="Name of file containing validation triplets")

    # Training regime params
    # NOTE: the original help text for --num_epochs was a copy-paste of the
    # learning-rate help; corrected here.
    parser.add_argument("--num_epochs", "-ne", type=int, default=70,
                        help="Number of training epochs")
    parser.add_argument("--eval_every", type=int, default=3,
                        help="Interval of epochs to evaluate the model?")
    parser.add_argument("--eval_every_iter", type=int, default=4,
                        help="Interval of iterations to evaluate the model?")
    parser.add_argument("--save_every", type=int, default=1,
                        help="Interval of epochs to save a checkpoint of the model?")
    parser.add_argument("--early_stop", type=int, default=100,
                        help="Early stopping patience")
    parser.add_argument("--optimizer", type=str, default="Adam",
                        help="Which optimizer to use?")
    parser.add_argument("--lr", type=float, default=0.01,
                        help="Learning rate of the optimizer")
    parser.add_argument("--update_lr", type=float, default=0.01,
                        help="Learning rate of the optimizer")
    parser.add_argument("--clip", type=int, default=1000,
                        help="Maximum gradient norm allowed")
    parser.add_argument("--l2", type=float, default=5e-4,
                        help="Regularization constant for GNN weights")
    parser.add_argument("--margin", type=float, default=10,
                        help="The margin between positive and negative samples in the max-margin loss")

    # Data processing pipeline params
    parser.add_argument("--max_links", type=int, default=1000000,
                        help="Set maximum number of train links (to fit into memory)")
    parser.add_argument("--hop", type=int, default=3,
                        help="Enclosing subgraph hop number")
    parser.add_argument("--max_nodes_per_hop", "-max_h", type=int, default=None,
                        help="if > 0, upper bound the # nodes per hop by subsampling")
    parser.add_argument("--use_kge_embeddings", "-kge", type=_str2bool, default=False,
                        help='whether to use pretrained KGE embeddings')
    parser.add_argument("--kge_model", type=str, default=None,
                        help="Which KGE model to load entity embeddings from")
    parser.add_argument('--model_type', '-m', type=str, choices=['ssp', 'dgl'], default='dgl',
                        help='what format to store subgraphs in for model')
    parser.add_argument('--constrained_neg_prob', '-cn', type=float, default=0.0,
                        help='with what probability to sample constrained heads/tails while neg sampling')
    parser.add_argument("--batch_size", type=int, default=16,
                        help="Batch size")
    parser.add_argument("--num_neg_samples_per_link", '-neg', type=int, default=1,
                        help="Number of negative examples to sample per positive link")
    parser.add_argument("--num_workers", type=int, default=1,
                        help="Number of dataloading processes")
    parser.add_argument('--add_traspose_rels', '-tr', type=_str2bool, default=False,
                        help='whether to append adj matrix list with symmetric relations')
    parser.add_argument('--enclosing_sub_graph', '-en', type=_str2bool, default=True,
                        help='whether to only consider enclosing subgraph')

    # Model params
    parser.add_argument("--rel_emb_dim", "-r_dim", type=int, default=32,
                        help="Relation embedding size")
    parser.add_argument("--attn_rel_emb_dim", "-ar_dim", type=int, default=32,
                        help="Relation embedding size for attention")
    parser.add_argument("--emb_dim", "-dim", type=int, default=32,
                        help="Entity embedding size")
    parser.add_argument("--num_gcn_layers", "-l", type=int, default=3,
                        help="Number of GCN layers")
    parser.add_argument("--num_bases", "-b", type=int, default=4,
                        help="Number of basis functions to use for GCN weights")
    parser.add_argument("--dropout", type=float, default=0,
                        help="Dropout rate in GNN layers")
    parser.add_argument("--edge_dropout", type=float, default=0.5,
                        help="Dropout rate in edges of the subgraphs")
    parser.add_argument('--gnn_agg_type', '-a', type=str, choices=['sum', 'mlp', 'gru'], default='sum',
                        help='what type of aggregation to do in gnn msg passing')
    parser.add_argument('--add_ht_emb', '-ht', type=_str2bool, default=True,
                        help='whether to concatenate head/tail embedding with pooled graph representation')
    parser.add_argument('--has_attn', '-attn', type=_str2bool, default=True,
                        help='whether to have attn in model or not')

    params = parser.parse_args()
    initialize_experiment(params, __file__)

    # Raw triplet files; params.main_dir is attached by initialize_experiment.
    params.file_paths = {
        'train': os.path.join(params.main_dir, 'data/{}/{}.txt'.format(params.dataset, params.train_file)),
        'valid': os.path.join(params.main_dir, 'data/{}/{}.txt'.format(params.dataset, params.valid_file))
    }

    # Prefer the requested GPU unless CUDA is disabled or unavailable.
    if not params.disable_cuda and torch.cuda.is_available():
        params.device = torch.device('cuda:%d' % params.gpu)
    else:
        params.device = torch.device('cpu')

    params.collate_fn = collate_dgl2
    params.move_batch_to_device = move_batch_to_device_dgl

    main(params)