-
Notifications
You must be signed in to change notification settings - Fork 1
/
main_zh.py
88 lines (71 loc) · 3.7 KB
/
main_zh.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import argparse
# Command-line interface for the pipeline. No option is required and none has an
# explicit default, so an omitted option is None (False for the store_true flag).
parser = argparse.ArgumentParser(description='Parser for APC-based DPO.')
for _flag in ('--api_key', '--hf_token', '--character', '--model_engine'):
    parser.add_argument(_flag, type=str)
parser.add_argument('--use_pretrained_discriminator', action='store_true')
for _flag in ('--relevance_finetune_epoch', '--rag_top_k', '--nli_finetune_epoch',
              '--max_dpo_data', '--lora_rank', '--prp_dpo_epoch'):
    parser.add_argument(_flag, type=int)
for _flag in ('--prp_scale', '--device'):
    parser.add_argument(_flag, type=str)
# Keep the module namespace identical to the explicit one-call-per-option form.
del _flag
args = parser.parse_args()
import os
# Process-wide environment configuration. This runs before the torch/transformers
# imports further down, presumably so CUDA and Hugging Face pick these values up
# at initialization.
# The --device and --hf_token options have no default (None when omitted), and
# os.environ only accepts str values. Assigning None would raise TypeError, so
# each variable is set only when a value was supplied. When --device is omitted,
# CUDA_VISIBLE_DEVICES is left as inherited from the environment.
if args.device is not None:
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device
# Turn off Weights & Biases reporting.
os.environ["WANDB_DISABLED"] = "true"
if args.hf_token is not None:
    os.environ["HF_TOKEN"] = args.hf_token
import openai
import json
import re
import os
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from datasets import Dataset
import torch
from torch import nn
from torch.optim import AdamW
from peft import LoraConfig, PeftModel, PeftConfig, get_peft_model, get_peft_model_state_dict
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from classifier_zh import Classifier, get_relevance_discriminator_zh, get_nli_discriminator_zh
from prompt import (convert_to_statement_zh, build_relevant_query_dataset_zh, build_statement_query_relevance_dataset_zh,
build_statement_to_response_nli_dataset_zh, discriminate_statement_to_response_nli_dataset_zh)
from score import score_APC
from prp_model import load_generator, generate_rag_dpo_dataset_zh, train_prp
# Ensure the "statement" working directory exists in the current directory.
# It is presumably where the dataset-synthesis helpers write their outputs —
# confirm in prompt.py.
# exist_ok=True avoids the exists()/makedirs() check-then-create race and makes
# re-running the script a no-op.
os.makedirs("statement", exist_ok=True)
# Copy the parsed CLI options into plain module-level names used by the stages below.
api_key, character, model_engine = args.api_key, args.character, args.model_engine
use_pretrained_discriminator = args.use_pretrained_discriminator
relevance_finetune_epoch, rag_top_k = args.relevance_finetune_epoch, args.rag_top_k
nli_finetune_epoch, max_dpo_data = args.nli_finetune_epoch, args.max_dpo_data
lora_rank, prp_dpo_epoch, prp_scale = args.lora_rank, args.prp_dpo_epoch, args.prp_scale
# Configure the OpenAI client globally through the module-level key attribute.
openai.api_key = api_key
# Stage 1: Dataset Synthesis
# Persona statements and statement-relevant queries are always synthesized,
# because Stage 3 consumes both when building the DPO data.
persona_statement_dataset = convert_to_statement_zh(character, model_engine)
relevant_query_dataset = build_relevant_query_dataset_zh(
    character, persona_statement_dataset, model_engine
)
if use_pretrained_discriminator:
    # Discriminator training data is only synthesized when the discriminators are
    # fine-tuned here; with pretrained discriminators these names are all None.
    (
        statement_query_relevance_dataset,
        statement_to_response_nli_dataset,
        statement_to_response_nli_v2_dataset,
    ) = (None, None, None)
else:
    statement_query_relevance_dataset = build_statement_query_relevance_dataset_zh(
        character, relevant_query_dataset, model_engine
    )
    statement_to_response_nli_dataset = build_statement_to_response_nli_dataset_zh(
        character, relevant_query_dataset, model_engine
    )
    # Derived from the NLI dataset just built; only this v2 dataset is used in Stage 2.
    statement_to_response_nli_v2_dataset = discriminate_statement_to_response_nli_dataset_zh(
        character, statement_to_response_nli_dataset, model_engine
    )
# Stage 2: Discriminator Fine-tuning
# Both getters receive the use_pretrained_discriminator flag. When it is set, the
# dataset arguments are None (see Stage 1), so the getters presumably load
# pretrained weights instead of fine-tuning (TODO: confirm in classifier_zh).
relevance_discriminator = get_relevance_discriminator_zh(
    character,
    statement_query_relevance_dataset,
    relevance_finetune_epoch,
    use_pretrained_discriminator,
)
nli_discriminator = get_nli_discriminator_zh(
    character,
    statement_to_response_nli_v2_dataset,
    nli_finetune_epoch,
    use_pretrained_discriminator,
)
# Stage 3: APC-based DPO
prp_tokenizer, prp_model = load_generator(prp_scale)
# Build the DPO dataset from the generator together with both Stage 2
# discriminators and the Stage 1 statements/queries. max_dpo_data and rag_top_k
# are passed through (presumably a size cap and retrieval depth — confirm in
# prp_model).
rag_dpo_dataset = generate_rag_dpo_dataset_zh(
    character,
    prp_model,
    prp_tokenizer,
    relevance_discriminator,
    nli_discriminator,
    persona_statement_dataset,
    relevant_query_dataset,
    max_dpo_data,
    rag_top_k,
)
# train_prp returns a (tokenizer, model) pair that replaces the loaded one.
prp_tokenizer, prp_model = train_prp(
    character, prp_model, prp_tokenizer, prp_scale,
    rag_dpo_dataset, lora_rank, prp_dpo_epoch,
)