-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain.py
executable file
·122 lines (108 loc) · 4.73 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from pytorch_lightning import Trainer, seed_everything
from factcg.dataloader import AlignmentDataLoader
from factcg.grounding_model import GroundingModel
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from argparse import ArgumentParser
import os
import torch
import warnings
warnings.simplefilter(action='ignore')
def train(datasets, args):
    """Train (or resume training of) a GroundingModel on the given datasets.

    Args:
        datasets: mapping of dataset name -> config dict with at least
            'task_type', 'data_path', and 'size' keys (assembled in __main__).
        args: parsed CLI namespace; see the ArgumentParser in __main__.

    Side effects:
        Writes periodic and final checkpoints under ``args.ckpt_save_path``
        and TensorBoard logs under ``./logs``.
    """
    dm = AlignmentDataLoader(
        dataset_config=datasets,
        model_name=args.model_name,
        sample_mode='seq',
        train_batch_size=args.batch_size,
        eval_batch_size=16,
        num_workers=args.num_workers,
        train_eval_split=0.95,
    )
    dm.setup()

    # Resume from a checkpoint when one is given; otherwise start fresh.
    if args.ckpt_path != "":
        model = GroundingModel.load_from_checkpoint(
            args.ckpt_path,
            model_name=args.model_name,
            adam_epsilon=args.adam_epsilon,
            learning_rate=args.learning_rate,
            weight_decay=args.weight_decay,
            warmup_steps_portion=args.warm_up_proportion
        )
    else:
        model = GroundingModel(
            model_name=args.model_name,
            adam_epsilon=args.adam_epsilon,
            learning_rate=args.learning_rate,
            weight_decay=args.weight_decay,
            warmup_steps_portion=args.warm_up_proportion
        )

    # Checkpoint files are tagged with the (slash-sanitized) model name plus
    # an optional user-supplied comment prefix.
    checkpoint_name = f"{args.ckpt_comment}{args.model_name.replace('/', '-')}"
    checkpoint_callback = ModelCheckpoint(
        dirpath=args.ckpt_save_path,
        filename=checkpoint_name + "_{epoch:02d}_{step}",
        every_n_train_steps=5000,
        monitor="train_loss",
        save_top_k=5
    )

    # T5-family models are trained in full fp32; everything else uses 16-bit.
    # NOTE(review): presumably T5 is numerically unstable in fp16 — confirm.
    if "t5" in args.model_name.lower():
        precision = 32
    else:
        precision = 16

    logger = TensorBoardLogger(
        "logs", name=args.ckpt_save_path.replace("/", "").replace(".", ""))

    trainer = Trainer(
        accelerator = 'gpu',
        max_epochs = args.num_epoch,
        devices = args.devices,
        strategy = "ddp_find_unused_parameters_true",
        precision = precision,
        callbacks = [checkpoint_callback],
        accumulate_grad_batches = args.accumulate_grad_batch,
        # Bug fix: --val-check-interval was parsed but never used, so the
        # flag was silently ignored; wire it through to the Trainer.
        val_check_interval = args.val_check_interval,
        logger = logger
    )
    trainer.fit(model, datamodule=dm)

    # Always persist a final checkpoint regardless of the top-k monitor.
    trainer.save_checkpoint(os.path.join(
        args.ckpt_save_path, f"{checkpoint_name}_final.ckpt"))
    print("Training is finished.")
if __name__ == "__main__":
    # Registry of available datasets: name -> task type + JSON file name
    # (resolved against --data-path below).
    TRAINING_DATASETS = {
        # Stage 1 training
        'anli_minicheck': {'task_type': 'bin_grounding', 'data_path': 'anli_minicheck.json'},
        'CG2C_hotpot_qa_rbt_mnli_failed': {'task_type': 'bin_grounding', 'data_path': 'CG2C_hotpot_qa_rbt_mnli_failed.json'},
        'CG2C_musique_minhop3_rbt_mnli_failed': {'task_type': 'bin_grounding', 'data_path': 'CG2C_musique_minhop3_rbt_mnli_failed.json'},
        'minicheck_c2d': {'task_type': 'bin_grounding', 'data_path': 'minicheck_c2d.json'},
        # Stage 2 training
        'CG2C_doc': {'task_type': 'bin_grounding', 'data_path': 'CG2C_doc.json'},
        'minicheck_d2c': {'task_type': 'bin_grounding', 'data_path': 'minicheck_d2c.json'},
    }

    dataset_names = list(TRAINING_DATASETS.keys())

    # CLI options as a (flag, kwargs) table; registered in one pass below.
    cli_options = [
        ('--seed', dict(type=int, default=2024)),
        ('--batch-size', dict(type=int, default=32)),
        ('--accumulate-grad-batch', dict(type=int, default=1)),
        ('--num-epoch', dict(type=int, default=3)),
        ('--num-workers', dict(type=int, default=8)),
        ('--warm-up-proportion', dict(type=float, default=0.06)),
        ('--adam-epsilon', dict(type=float, default=1e-6)),
        ('--weight-decay', dict(type=float, default=0)),
        ('--learning-rate', dict(type=float, default=5e-5)),
        ('--val-check-interval', dict(type=float, default=1. / 4)),
        ('--devices', dict(nargs='+', type=int, required=True)),
        ('--model-name', dict(type=str, default="microsoft/deberta-v3-large")),
        ('--ckpt-save-path', dict(type=str, required=True)),
        ('--ckpt-comment', dict(type=str, default="")),
        ('--training-datasets', dict(nargs='+', type=str, default=dataset_names, choices=dataset_names)),
        ('--data-path', dict(type=str, required=True)),
        ('--max-samples-per-dataset', dict(type=int, default=500000)),
        ('--ckpt-path', dict(type=str, default="")),
    ]

    parser = ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()
    print(args)

    seed_everything(args.seed)

    # Materialize the selected dataset configs: cap the sample count and
    # resolve each JSON file against the data root.
    selected = {}
    for name in args.training_datasets:
        entry = dict(TRAINING_DATASETS[name])
        entry["size"] = args.max_samples_per_dataset
        entry["data_path"] = os.path.join(args.data_path, TRAINING_DATASETS[name]['data_path'])
        selected[name] = entry

    train(selected, args)