forked from tloen/alpaca-lora
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: export_hf_checkpoint_cmd.py
63 lines (48 loc) · 1.72 KB
/
export_hf_checkpoint_cmd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import argparse
import torch
import transformers
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer
def main(argv=None):
    """Merge a LoRA adapter into its base LLaMA model and save a plain HF checkpoint.

    Loads the tokenizer and the base model from ``--base_model`` (float16,
    on the CPU). Applies the PEFT adapter from ``--lora_model`` and folds the
    adapter into the base weights with ``merge_and_unload()``. Writes the
    merged weights, sharded at ``--max_shard_size``, plus the tokenizer to
    ``--output_dir`` (default ``./hf_ckpt``).

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``; ``None``
            uses the real command line.

    Raises:
        RuntimeError: If loading the adapter changed the base weights, or if
            the merge left the first layer's ``q_proj`` weight unchanged.
    """
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='Script for processing BASE_MODEL and LORA_MODEL.')
    parser.add_argument('--base_model', required=True, help='Specify the BASE_MODEL value')
    parser.add_argument('--lora_model', required=True, help='Specify the LORA_MODEL value')
    parser.add_argument('--max_shard_size', default='10GB', help='Specify the maximum shard size (default: 10GB)')
    parser.add_argument(
        '--output_dir',
        default='./hf_ckpt',
        help='Directory to write the merged checkpoint to (default: ./hf_ckpt)',
    )
    args = parser.parse_args(argv)

    tokenizer = LlamaTokenizer.from_pretrained(args.base_model)
    base_model = LlamaForCausalLM.from_pretrained(
        args.base_model,
        load_in_8bit=False,
        torch_dtype=torch.float16,
        device_map={"": "cpu"},
    )

    # Hold a reference to one base weight plus a copy of its original values.
    # This lets us verify that loading the adapter leaves the base weights
    # alone and that merging actually changes them.
    first_weight = base_model.model.layers[0].self_attn.q_proj.weight
    first_weight_old = first_weight.clone()

    lora_model = PeftModel.from_pretrained(
        base_model,
        args.lora_model,
        device_map={"": "cpu"},
        torch_dtype=torch.float16,
    )
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not torch.allclose(first_weight_old, first_weight):
        raise RuntimeError(
            "Loading the LoRA adapter unexpectedly modified the base model weights."
        )

    # merge weights - new merging method from peft
    lora_model = lora_model.merge_and_unload()
    lora_model.train(False)

    # Did we do anything? This relies on the merge updating `first_weight` in
    # place, as the original script expected.
    if torch.allclose(first_weight_old, first_weight):
        raise RuntimeError(
            "Merging the LoRA adapter left the base weights unchanged; nothing was merged."
        )

    # Drop any leftover LoRA tensors and strip a PEFT "base_model.model."
    # key prefix, so the keys match a plain LlamaForCausalLM checkpoint.
    deloreanized_sd = {
        k.replace("base_model.model.", ""): v
        for k, v in lora_model.state_dict().items()
        if "lora" not in k
    }

    base_model.save_pretrained(
        args.output_dir,
        state_dict=deloreanized_sd,
        max_shard_size=args.max_shard_size,
    )
    # Ship the tokenizer with the weights so the exported directory can be
    # loaded on its own with `from_pretrained`.
    tokenizer.save_pretrained(args.output_dir)


if __name__ == "__main__":
    main()