-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathgpt2_generator.py
122 lines (91 loc) · 2.87 KB
/
gpt2_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent.parent))
import time
import os
import numpy as np
from os.path import join
import sys
import pickle
import subprocess
from ghn3.graph import Graph_LLM, GraphBatch
from ppuda.utils.utils import capacity, set_seed
from transformers import AutoTokenizer, AutoConfig
from transformers import GPT2Config, GPT2LMHeadModel, AutoTokenizer
# --- Script setup: CLI args, environment info, output path, tokenizer -----
# Usage: python gpt2_generator.py <N> <data_dir>
try:
    split = 'train'
    N = int(sys.argv[1])    # number of architectures (graphs) to generate
    data_dir = sys.argv[2]  # output directory for the pickled graph list
except Exception as e:
    print('\nExample of usage: python gpt2_generator.py 10000 ./data\n', e)
    raise

# Record the current git commit for reproducibility (best-effort: keep going
# if this is not a git checkout or git is unavailable).
try:
    gitcommit = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode('ascii').strip()
    print('gitcommit:', gitcommit, flush=True)
except Exception as e:
    print(e, flush=True)

device = 'cpu'
max_params = 50 * 10 ** 6  # reject sampled models above 50M parameters
print(split, N, data_dir, device, flush=True)

# makedirs instead of the original exists()+mkdir: creates intermediate
# directories for nested paths and avoids the check-then-create race.
os.makedirs(data_dir, exist_ok=True)

set_seed(1)  # fixed seed so the sampled architectures are reproducible

dset_name = 'wikitext-2'
# NOTE(review): despite the 'h5' name, this is a pickle file, not HDF5.
# Keeping the variable name because later code in this file uses it.
h5_file = join(data_dir, 'GPT21K_%s.pkl' % split)

graphs = []  # collected Graph_LLM objects
params = []  # parameter counts (in millions) of the accepted models

# GPT-2 tokenizer; GPT-2 defines no pad token, so reuse EOS for padding.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
def get_var():
    """Run a fixed probe sentence through the model and return its logits.

    NOTE(review): currently unused — the ``model.get_var = get_var``
    assignment further down is commented out. Relies on the module-level
    globals ``model`` and ``tokenizer`` existing at call time.
    """
    probe_inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
    return model(**probe_inputs).logits
# Sample random GPT-2 architectures until N graphs are collected.
# Deeper models are given narrower embedding-width ranges to keep the
# total parameter count under max_params.
while len(graphs) < N:
    n_layer = np.random.randint(3, 10)  # 3..9 transformer blocks
    if n_layer > 5:
        dim_min = 72
        dim_max = 176
    elif n_layer > 3:
        dim_min = 128
        dim_max = 176
    else:
        dim_min = 176
        dim_max = 256
    # dim_min is always a multiple of 8 and the step is 8, so n_embd % 8 == 0.
    n_embd = np.random.choice(np.arange(dim_min, dim_max + 1, 8))

    # Pick a head count that divides n_embd. With the sampling above the first
    # branch always fires (n_head == 8); np.random.choice is kept so the RNG
    # stream — and hence reproducibility under set_seed(1) — is unchanged.
    if n_embd % 8 == 0:
        n_head = np.random.choice([8])
    elif n_embd % 6 == 0:
        n_head = np.random.choice([6])
    elif n_embd % 4 == 0:
        n_head = 4
    else:
        # Fallback so n_head can never be unbound if the width sampling is
        # ever changed (the original chain had no else and could NameError).
        n_head = 2 if n_embd % 2 == 0 else 1

    config = GPT2Config(
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        n_embd=int(n_embd),
        n_layer=int(n_layer),
        n_head=int(n_head),
        tie_word_embeddings=False,  # keep input/output embeddings untied
    )
    net_args = {'n_embd': n_embd, 'n_layer': n_layer, 'n_head': n_head}
    print(net_args, flush=True)

    model = GPT2LMHeadModel(config)
    # capacity() comes from ppuda.utils; index [1] is presumably the total
    # parameter count — TODO confirm against ppuda.utils.utils.capacity.
    n = capacity(model)[1]
    if n > max_params:
        print('too large archi: %.2f M params \n' % (n / 1e6), flush=True)
        continue
    params.append(n / 1e6)

    # Build the graph representation of the model used downstream.
    graph = Graph_LLM(model, tokenizer, ve_cutoff=250, dense=True)
    graph.net_args = {'n_embd': n_embd, 'n_layer': n_layer, 'n_head': n_head}
    graph.config = config
    graphs.append(graph)
    print(len(graphs), '%.3f M params' % (n / 1e6), flush=True)
# Serialize the collected graphs and report summary statistics.
with open(h5_file, 'wb') as out_f:
    pickle.dump(graphs, out_f)
print('saved to %s' % h5_file)
mean_mparams = np.mean(params)
std_mparams = np.std(params)
print('params: %.3f +- %.3f M' % (mean_mparams, std_mparams), flush=True)
print('\n done')
# if __name__ == '__main__':
# main()