train.py
import numpy as np
from typing import Union
from transformers import EsmTokenizer, EsmForMaskedLM, EsmConfig
from transformers import TrainingArguments, Trainer
from transformers import LlamaForCausalLM, LlamaConfig
import sys
sys.path.append('./utils')
from utils import multimodal_dataset
_10TB = 10995116277760
def train(ckpt):
    train_struct_name = "/cto_studio/xtalpi_lab/Datasets/af_swissprot_vqvae.pkl"
    # train_lmdb_path = "/cto_studio/xtalpi_lab/temp/lmdb/train_dedup/data.lmdb"
    # valid_lmdb_path = "/cto_studio/xtalpi_lab/temp/lmdb/valid/data.lmdb"
    # output_dir = "./results"
    train_lmdb_path = "/cto_labs/liuzijing/lmdb/train_dedup/data.lmdb"
    valid_lmdb_path = "/cto_labs/liuzijing/lmdb/valid/data.lmdb"
    output_dir = "/cto_labs/liuzijing/outputs/gpt2small_ss2"

    struct_only = False
    batch_size = 128
    gradient_accumulation = 1
    seq_ratio = 1  # struct_data : seq_data = 1 : (seq_ratio - 1)

    if struct_only:
        exp_name = f"struct_only_b{batch_size}_{seq_ratio}"
    else:
        exp_name = f"seq_struct_b{batch_size}_ss{seq_ratio}"

    train_dataset = multimodal_dataset.SeqStructDataset(train_lmdb_path,
                                                        train_struct_name,
                                                        max_length=1024, seq_ratio=seq_ratio,
                                                        struct_only=struct_only)
    test_dataset = multimodal_dataset.SeqStructDataset(valid_lmdb_path,
                                                       train_struct_name,
                                                       max_length=1024, struct_only=struct_only)
    # config = EsmConfig.from_pretrained(model_checkpoint)
    # model = EsmForMaskedLM.from_pretrained(model_checkpoint)
    # model = EsmForMaskedLM(config)
    configuration = LlamaConfig()
    configuration.finetuning_task = exp_name
    configuration.pad_token_id = train_dataset.sequence_tokenizer.pad_token_id
    ## num params ~ 12 * hidden^2 * layers + vocab_size * hidden
    ## progen2 small: L=12 Head=16 hidden=1024; gpt2 small: L=12 Head=12 hidden=64*Head=768
    ## gpt2 medium: L=16 Head=16 hidden=1024, ~200M
    ## gpt2 medium: L=20 Head=16 hidden=1024
    configuration.hidden_size = 768
    configuration.intermediate_size = 768 * 4
    configuration.max_position_embeddings = 1028
    configuration.num_attention_heads = 12
    configuration.num_hidden_layers = 12
    configuration.num_key_value_heads = 12
    configuration.vocab_size = 4096 + 5 + 33
    configuration.bos_token_id = 0
    configuration.use_cache = False
    model = LlamaForCausalLM(configuration)
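    # Parameter-count sanity check against the 12*hidden^2*layers + vocab*hidden
    # estimate in the comment above. A hedged sketch, not part of the original
    # script: the estimate ignores LLaMA's gated (3-matrix) MLP, the untied
    # lm_head, and norm parameters, so the actual count comes out higher.
    approx_params = (12 * configuration.hidden_size ** 2 * configuration.num_hidden_layers
                     + configuration.vocab_size * configuration.hidden_size)
    print(f"estimated params: {approx_params / 1e6:.1f}M, "
          f"actual params: {model.num_parameters() / 1e6:.1f}M")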
    gradient_checkpointing = True
    save_steps = 5000
    eval_steps = 5000
    save_total_limit = 5

    args = TrainingArguments(
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation,  # 2 if 4 gpus
        warmup_steps=5000,
        num_train_epochs=100,
        # max_steps=500000,
        learning_rate=4e-4,
        fp16=True,
        logging_steps=1000,
        optim="adamw_torch",
        evaluation_strategy="steps",
        save_strategy="steps",
        eval_steps=eval_steps,
        save_steps=save_steps,
        output_dir=output_dir,
        save_total_limit=save_total_limit,
        load_best_model_at_end=True,
        # ddp_find_unused_parameters=True,
        report_to="tensorboard",
        run_name=None,
        dataloader_num_workers=0,
        gradient_checkpointing=gradient_checkpointing,
        data_seed=54,
        # use_cpu=True
    )
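    # Effective global batch size log, following the "2 if 4 gpus" note above.
    # A hedged sketch, not part of the original script: it assumes the WORLD_SIZE
    # env var is set by the launcher (e.g. torchrun) and defaults to 1 for
    # single-process runs.
    import os
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    print(f"effective global batch size: "
          f"{batch_size * gradient_accumulation * world_size}")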
    trainer = Trainer(
        model,
        args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        data_collator=multimodal_dataset.collate_fn_gpt,
    )
    if ckpt is None:
        trainer.train()
    else:
        trainer.train(resume_from_checkpoint=ckpt)
if __name__ == "__main__":
    if len(sys.argv) == 1:
        ckpt = None
    elif len(sys.argv) == 2:
        ckpt = sys.argv[1]
    else:
        sys.exit(f"usage: {sys.argv[0]} [checkpoint_dir]")
    train(ckpt)
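# Usage (a sketch; the data and output paths above are hard-coded for the author's cluster):
#   python train.py                                # train from scratch
#   python train.py /path/to/checkpoint-50000      # resume from a Trainer checkpoint dir
#   torchrun --nproc_per_node=4 train.py           # multi-GPU launch via torchrun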