hf_sample.py
import os
import time

import torch
from transformers import CodeLlamaTokenizer

from hf_model import Transformer, ModelArgs

DEVICE = "cuda:3"
DTYPE = torch.bfloat16
GROUP_SIZE = 64
# start by running in int8, then move on to int4
# file = "./CodeLlama-7b-Instruct-hf/pytorch_model-00003-of-00003.bin"
# model_dict = torch.load(file, map_location='cpu', mmap=True)
# print(model_dict.keys())
'''
Name Map:
original        | new
================|==========================
attention_norm  | input_layernorm
ffn_norm        | post_attention_layernorm
feed_forward.w1 | mlp.gate_proj
feed_forward.w2 | mlp.down_proj
feed_forward.w3 | mlp.up_proj
attention.wq    | self_attn.q_proj
attention.wk    | self_attn.k_proj
attention.wv    | self_attn.v_proj
attention.wo    | self_attn.o_proj
norm            | norm
output          | lm_head
tok_embeddings  | embed_tokens
'''
nameMap = {
    "attention_norm": "input_layernorm",
    "ffn_norm": "post_attention_layernorm",
    "feed_forward": "mlp",
    "w1": "gate_proj",
    "w2": "down_proj",
    "w3": "up_proj",
    "attention": "self_attn",
    "wq": "q_proj",
    "wk": "k_proj",
    "wv": "v_proj",
    "wo": "o_proj",
    "norm": "norm",
    "output": "lm_head",
    "tok_embeddings": "embed_tokens",
}
nameMap_reverse = {v: k for k, v in nameMap.items()}
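# For example, a shard key like "model.layers.0.self_attn.q_proj.weight" (a
# hypothetical key, shown for illustration) becomes "layers.0.attention.wq.weight"
# once the "model." prefix is stripped and each dotted component is looked up in
# nameMap_reverse, as done in remap_names below.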
# Load one checkpoint shard and remap Hugging Face parameter names back to the
# original Llama naming expected by hf_model.Transformer.
def remap_names(file):
    model_dict = torch.load(file, map_location='cpu', mmap=True)
    # strip the leading "model." prefix from every key
    unwanted_prefix = 'model.'
    for k, v in list(model_dict.items()):
        if k.startswith(unwanted_prefix):
            model_dict[k[len(unwanted_prefix):]] = model_dict.pop(k)
    # translate each dotted name component via the reverse name map
    for k, v in list(model_dict.items()):
        split_keys = k.split(".")
        for i in range(len(split_keys)):
            if split_keys[i] in nameMap_reverse:
                split_keys[i] = nameMap_reverse[split_keys[i]]
        model_dict[".".join(split_keys)] = model_dict.pop(k)
    # for k, v in list(model_dict.items()):
    #     model_dict[k] = v.to(torch.float16)
    return model_dict
model_dir = './CodeLlama-7b-Instruct-hf/'
files = os.listdir(model_dir)
# collect all pytorch_model-{x}-of-00003.bin shard files
print(files)
model_dict = {}
for file in files:
    if file.startswith("pytorch_model-"):
        print("Loading file: ", model_dir + file)
        curr_dict = remap_names(model_dir + file)
        model_dict.update(curr_dict)
# for k, v in list(model_dict.items()):
#     print(k, v.shape, v.dtype)
#     w, s = quantize_q40(v, 64)
#     dequant = dequantize_q40(w, s, 64, v.shape)
#     print("Avg error: ", torch.mean(torch.abs(v - dequant)))
model = Transformer(ModelArgs())  # default args correspond to Llama 7B
model.load_state_dict(model_dict, strict=True, assign=True)
model.to(DEVICE, dtype=DTYPE)
model.eval()
tokenizer = CodeLlamaTokenizer.from_pretrained("./CodeLlama-7b-Instruct-hf")
PROMPT = '''[INST] <<SYS>> You are a programmer, write the following python function that passes the given tests
<</SYS>>
Test Cases
assert max_chain_length([Pair(5, 24), Pair(15, 25),Pair(27, 40), Pair(50, 60)], 4) == 3
assert max_chain_length([Pair(1, 2), Pair(3, 4),Pair(5, 6), Pair(7, 8)], 4) == 4
assert max_chain_length([Pair(19, 10), Pair(11, 12),Pair(13, 14), Pair(15, 16), Pair(31, 54)], 5) == 5
Write a function to find the longest chain which can be formed from the given set of pairs.
[/INST]
'''
start = time.time()
print("Generating...")
input_ids = tokenizer(PROMPT, return_tensors="pt")["input_ids"]
input_ids = input_ids.to(DEVICE)
print(input_ids)
print(tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0])
generated_ids = model.generate(input_ids, 512, temperature=0.1, top_k=32, enc=tokenizer.batch_decode)
print("Time taken: ", time.time() - start)
print(generated_ids)
print(tokenizer.batch_decode(generated_ids[:, input_ids.shape[1]:], skip_special_tokens=True)[0])
# 29961, 20055, 4690, 1164, 29962 is [PYTHON]
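# As a quick sanity check of the comment above, one could decode those ids directly
# (sketch; the expected output is per the comment, not verified here):
# print(tokenizer.decode([29961, 20055, 4690, 1164, 29962]))  # -> "[PYTHON]"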