Add Text2GQL fine tuning framework and provide TuGraph examples (#287)
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
Co-authored-by: lorcanzhang <lorcanzhang@tencent.com>
Co-authored-by: luchun <71970539+zhanghy-sketchzh@users.noreply.github.com>
Co-authored-by: csunny <cfqsunny@163.com>
5 people authored Sep 1, 2024
1 parent 1c50727 commit e80d34a
Showing 57 changed files with 19,621 additions and 0 deletions.
17 changes: 17 additions & 0 deletions .gitignore
@@ -15,19 +15,36 @@ data/eval
output_pred/
wandb/
src/dbgpt-hub-sql/dbgpt_hub_sql/data/*
src/dbgpt-hub-gql/dbgpt_hub_gql/data/*
src/dbgpt-hub-sql/codellama/*
src/dbgpt-hub-gql/codellama/*
src/dbgpt-hub-sql/wandb/*
src/dbgpt-hub-gql/wandb/*
# But track the data/eval_data folder itself
!src/dbgpt-hub-sql/dbgpt_hub_sql/data/eval_data/
!src/dbgpt-hub-sql/dbgpt_hub_sql/data/dataset_info.json
!src/dbgpt-hub-sql/dbgpt_hub_sql/data/example_text2sql.json
!src/dbgpt-hub-gql/dbgpt_hub_gql/data/tugraph-db-example
!src/dbgpt-hub-gql/dbgpt_hub_gql/data/dataset_info.json
!src/dbgpt-hub-gql/dbgpt_hub_gql/data/example_text2sql.json

# Ignore everything under dbgpt_hub_sql/output/ except the adapter directory
src/dbgpt-hub-sql/dbgpt_hub_sql/output/
src/dbgpt-hub-sql/dbgpt_hub_sql/output/adapter/*
!src/dbgpt-hub-sql/dbgpt_hub_sql/output/adapter/.gitkeep
src/dbgpt-hub-sql/dbgpt_hub_sql/output/logs/*
!src/dbgpt-hub-sql/dbgpt_hub_sql/output/logs/.gitkeep
src/dbgpt-hub-sql/dbgpt_hub_sql/output/pred/*
!src/dbgpt-hub-sql/dbgpt_hub_sql/output/pred/.gitkeep

src/dbgpt-hub-gql/dbgpt_hub_gql/output/
src/dbgpt-hub-gql/dbgpt_hub_gql/output/adapter/*
!src/dbgpt-hub-gql/dbgpt_hub_gql/output/adapter/.gitkeep
src/dbgpt-hub-gql/dbgpt_hub_gql/output/logs/*
!src/dbgpt-hub-gql/dbgpt_hub_gql/output/logs/.gitkeep
src/dbgpt-hub-gql/dbgpt_hub_gql/output/pred/*
!src/dbgpt-hub-gql/dbgpt_hub_gql/output/pred/.gitkeep

# Ignore NLU output
src/dbgpt-hub-nlu/output
src/dbgpt-hub-nlu/data
1 change: 1 addition & 0 deletions README.md
@@ -31,6 +31,7 @@

## 🔥🔥🔥 News
- Support [Text2NLU](src/dbgpt-hub-nlu/README.zh.md) fine-tuning to improve semantic understanding accuracy.
- Support [Text2GQL](src/dbgpt-hub-gql/README.zh.md) fine-tuning to generate graph queries from natural language.

## Baseline

1 change: 1 addition & 0 deletions README.zh.md
@@ -29,6 +29,7 @@

## 🔥🔥🔥 News
- Support [Text2NLU](src/dbgpt-hub-nlu/README.zh.md) fine-tuning to improve intent recognition accuracy.
- Support [Text2GQL](src/dbgpt-hub-gql/README.zh.md) fine-tuning to generate graph queries from natural language.

## Baseline
- Last updated: 2023/12/08
236 changes: 236 additions & 0 deletions src/dbgpt-hub-gql/README.zh.md
@@ -0,0 +1,236 @@
# DB-GPT-GQL: Text-to-GQL with LLMs

## Baseline

- Last updated: 2024/08/26

<table style="text-align: center;">
<tr>
<th style="text-align: center;">Language</th>
<th style="text-align: center;">Dataset</th>
<th style="text-align: center;">Model</th>
<th>Method</th>
<th>Similarity</th>
<th>Grammar</th>
</tr>
<tr >
<td></td>
<td></td>
<td></td>
<td>base</td>
<td>0.769</td>
<td>0.703</td>
</tr>
<tr>
<td>Cypher <a href="https://github.com/TuGraph-family/tugraph-db">(tugraph-db)</a></td>
<td><a href="https://tugraph-web.oss-cn-beijing.aliyuncs.com/tugraph/datasets/text2gql/tugraph-db/tugraph-db.zip">TuGraph-DB Cypher dataset</a></td>
<td><a href="https://huggingface.co/tugraph/CodeLlama-7b-Cypher-hf/tree/1.0">CodeLlama-7B-Instruct</a></td>
<td>lora</td>
<td>0.928</td>
<td>0.946</td>
</tr>
<tr >
<td></td>
<td></td>
<td></td>
<td>base</td>
<td>0.493</td>
<td>0.002</td>
</tr>
<tr>
<td>GQL <a href="https://github.com/TuGraph-family/tugraph-analytics">(tugraph-analytics)</a></td>
<td><a href="https://tugraph-web.oss-cn-beijing.aliyuncs.com/tugraph/datasets/text2gql/tugraph-analytics/tugraph-analytics.zip">TuGraph-Analytics GQL dataset</a></td>
<td><a href="https://huggingface.co/tugraph/CodeLlama-7b-GQL-hf/tree/1.1">CodeLlama-7B-Instruct</a></td>
<td>lora</td>
<td>0.935</td>
<td>0.984</td>
</tr>
</table>

## Contents
- [DB-GPT-GQL](#db-gpt-gql-text-to-gql-with-llms)
  - [Baseline](#baseline)
  - [Contents](#contents)
  - [1. Introduction](#1-introduction)
  - [2. Text-to-GQL Fine-tuning](#2-text-to-gql-fine-tuning)
    - [2.1 Datasets](#21-datasets)
    - [2.2 Base Models](#22-base-models)
  - [3. Usage](#3-usage)
    - [3.1 Environment Setup](#31-environment-setup)
    - [3.2 Model Preparation](#32-model-preparation)
    - [3.3 Model Fine-tuning](#33-model-fine-tuning)
    - [3.4 Model Prediction](#34-model-prediction)
    - [3.5 Model Evaluation](#35-model-evaluation)
      - [3.5.1 Text Similarity Evaluation](#351-text-similarity-evaluation)
      - [3.5.2 Grammar Correctness Evaluation](#352-grammar-correctness-evaluation)
    - [3.6 Merging Model Weights](#36-merging-model-weights)

## 1. Introduction

DB-GPT-GQL is a module that uses LLMs to translate natural language into graph query languages (Text-to-GQL). It covers model fine-tuning, Text-to-GQL prediction, and evaluation of the predictions. Text-to-SQL translation for relational databases has by now accumulated a wealth of datasets and multi-dimensional evaluation pipelines. Text-to-GQL is less mature: graph query languages lack a unified standard, and ISO GQL, the intended international standard, has not yet landed in practice. Building complete corpora for the various graph query languages and establishing evaluation mechanisms for Text-to-GQL translations therefore remain two challenging tasks.

DB-GPT-GQL supports fine-tuning and prediction pipelines on multiple large models, and it provides two ways to evaluate translation results. The first is text-similarity evaluation, which measures how close a prediction is to the reference answer and works for any graph query language. The second is grammar-correctness evaluation, which parses the prediction with a grammar parser for the specific graph query language; the query languages of tugraph-db and tugraph-analytics are currently supported.

In the future, DB-GPT-GQL will add execution-plan correctness evaluation of translation results (requiring no actual data) and, beyond that, execution correctness evaluation (requiring actual data). It will also grade the complexity of the queries in its datasets, following the approach of the Spider dataset from the Text-to-SQL field.

## 2. Text-to-GQL Fine-tuning

We improve Text-to-GQL performance through supervised fine-tuning (SFT) of large language models.

### 2.1 Datasets

The sample dataset for this project is `Cypher(tugraph-db)`: 185 corpus entries provided by tugraph-db and executable on it, stored in the `/dbgpt_hub_gql/data/tugraph-db-example` folder. The currently available datasets are:

- [Cypher(tugraph-db)](https://tugraph-web.oss-cn-beijing.aliyuncs.com/tugraph/datasets/text2gql/tugraph-db/tugraph-db.zip): a dataset conforming to tugraph-db's Cypher grammar. It was built with the "[grammar-guided corpus generation strategy](https://mp.weixin.qq.com/s/rZdj8TEoHZg_f4C-V4lq2A)": query-language templates are combined with diverse schemas to generate queries, and a large model then generalizes the matching natural-language question descriptions. The [corpus generation framework](https://github.com/TuGraph-contrib/Awesome-Text2GQL) is now open source; contributions are welcome.

- [GQL(tugraph-analytics)](https://tugraph-web.oss-cn-beijing.aliyuncs.com/tugraph/datasets/text2gql/tugraph-analytics/tugraph-analytics.zip): a dataset conforming to tugraph-analytics' GQL grammar, built with the same "[grammar-guided corpus generation strategy](https://mp.weixin.qq.com/s/rZdj8TEoHZg_f4C-V4lq2A)": query-language templates are combined with diverse schemas to generate queries, and a large model then generalizes the matching natural-language question descriptions.

After downloading and unzipping a dataset, place it under the `/dbgpt_hub_gql/data/` directory.
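To check that a downloaded dataset is in the expected instruction-tuning format, peeking at the first record is enough. A minimal sketch, assuming `train.json` is a JSON array whose records carry the `instruction`/`input`/`output`/`history` keys mapped in `dataset_info.json` (see section 3.3):

```python
import json

# Load the bundled example corpus and inspect its first record.
with open("dbgpt_hub_gql/data/tugraph-db-example/train.json", encoding="utf-8") as f:
    records = json.load(f)

print(f"{len(records)} training records")
# Keys expected per the dataset_info.json column mapping in section 3.3.
for key in ("instruction", "input", "output", "history"):
    print(f"{key}: {str(records[0].get(key))[:80]}")
```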

### 2.2 Base Models

Base models currently supported by DB-GPT-GQL:

- [x] CodeLlama
- [x] Baichuan2
- [x] LLaMa/LLaMa2
- [x] Falcon
- [x] Qwen
- [x] XVERSE
- [x] ChatGLM2
- [x] ChatGLM3
- [x] internlm
- [x] sqlcoder-7b(mistral)
- [x] sqlcoder2-15b(starcoder)

The minimum hardware required to fine-tune a model with 4-bit quantization (QLoRA, quantization_bit=4) is roughly:

| Model size | GPU RAM | CPU RAM | Disk |
| -------- | ------- | ------- | ------ |
| 7b | 6GB | 3.6GB | 36.4GB |
| 13b | 13.4GB | 5.9GB | 60.2GB |

These figures use minimal settings throughout: batch_size of 1 and max_length of 512. Empirically, if compute allows, setting the length-related values to 1024 or 2048 gives better results.
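
For orientation, loading a model under 4-bit quantization in plain `transformers` looks roughly like the sketch below. This illustrates what the quantization_bit=4 setting corresponds to; it is not the project's own training entry point:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# 4-bit NF4 quantization: the kind of load QLoRA fine-tuning builds on.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model_id = "codellama/CodeLlama-7b-Instruct-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",  # spread layers across the available GPUs
)
```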

## 3. Usage
This chapter walks through the full DB-GPT-GQL workflow, using the 185 corpus entries bundled in `/dbgpt_hub_gql/data/tugraph-db-example` and `CodeLlama-7B-Instruct` (downloaded separately) as the example.

### 3.1 Environment Setup

Clone the project and create a conda environment:
```bash
git clone https://github.com/eosphoros-ai/DB-GPT-Hub.git
cd DB-GPT-Hub
conda create -n dbgpt_hub_gql python=3.10
conda activate dbgpt_hub_gql
```

Enter the DB-GPT-GQL project directory and install the dependencies:
```bash
cd src/dbgpt-hub-gql
pip install -e .
```

### 3.2 Model Preparation
Create and enter the directory that will hold the codellama model:
```bash
mkdir codellama
cd ./codellama
```

Inside the `codellama` folder, create a `download.py` file and copy the following into it:
```python
from modelscope import snapshot_download

# Download CodeLlama-7b-Instruct from ModelScope; by default the files are
# cached under ~/.cache/modelscope/hub/AI-ModelScope/CodeLlama-7b-Instruct-hf
model_dir = snapshot_download("AI-ModelScope/CodeLlama-7b-Instruct-hf")
```

Install the Python dependency and download the model:
```bash
pip install modelscope
python3 download.py
```

Once the download finishes, symlink the model files into the `codellama` directory:
```bash
ln -s /root/.cache/modelscope/hub/AI-ModelScope/CodeLlama-7b-Instruct-hf ./
```

### 3.3 Model Fine-tuning
Before fine-tuning, the training dataset must be registered manually in `./dbgpt_hub_gql/data/dataset_info.json`. The training dataset in the `./dbgpt_hub_gql/data/tugraph-db-example` folder is already registered there, in the following format:

```json
"tugraph_db_example_train": {
"file_name": "./tugraph-db-example/train.json",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"history": "history"
}
}
```
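
A quick sanity check catches registration mistakes before a training run is wasted. The helper below is a hypothetical sketch, not part of the repository; it assumes each entry's `file_name` is relative to the data directory, as in the example above:

```python
import json
import os

DATA_DIR = "dbgpt_hub_gql/data"

# Verify every registered dataset exists and its column mapping matches the records.
with open(os.path.join(DATA_DIR, "dataset_info.json"), encoding="utf-8") as f:
    info = json.load(f)

for name, spec in info.items():
    path = os.path.join(DATA_DIR, spec["file_name"])
    assert os.path.exists(path), f"{name}: missing file {path}"
    with open(path, encoding="utf-8") as f:
        first = json.load(f)[0]
    for role, column in spec.get("columns", {}).items():
        assert column in first, f"{name}: column '{column}' not found in records"
    print(f"{name}: OK ({path})")
```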

Set the dataset, the model, and the fine-tuning output path in `dbgpt_hub_gql/scripts/train_sft.sh`. The current defaults match the training dataset in `./dbgpt_hub_gql/data/tugraph-db-example` and the `CodeLlama-7B-Instruct` model, fine-tuned with LoRA:

```bash
dataset="tugraph_db_example_train"
model_name_or_path=${model_name_or_path-"codellama/CodeLlama-7b-Instruct-hf"}
output_dir="dbgpt_hub_gql/output/adapter/CodeLlama-7b-gql-lora"
```

Run the fine-tuning script to start training:
```bash
sh dbgpt_hub_gql/scripts/train_sft.sh
```
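
Progress can be followed through the trainer's JSONL log. A minimal sketch, assuming the trainer writes `trainer_log.jsonl` (the `LOG_FILE_NAME` defined in `configs/config.py`) into the adapter output directory, with per-step records that include a `loss` field; both the path and the record layout are assumptions:

```python
import json

# Print the loss curve from the trainer's JSONL log.
# Path and field names are assumptions; adjust to your actual output directory.
log_path = "dbgpt_hub_gql/output/adapter/CodeLlama-7b-gql-lora/trainer_log.jsonl"
with open(log_path, encoding="utf-8") as f:
    for line in f:
        record = json.loads(line)
        if "loss" in record:
            print(record.get("current_steps"), record["loss"])
```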

### 3.4 Model Prediction

Set the dataset to predict on, the model, the fine-tuned checkpoint path, and the prediction output path in `./dbgpt_hub_gql/scripts/predict_sft.sh`. The current defaults match the dev dataset in `./dbgpt_hub_gql/data/tugraph-db-example` and the LoRA fine-tuned `CodeLlama-7B-Instruct` model:

```bash
CUDA_VISIBLE_DEVICES=0,1 python dbgpt_hub_gql/predict/predict.py \
--model_name_or_path codellama/CodeLlama-7b-Instruct-hf \
--template llama2 \
--finetuning_type lora \
--predicted_input_filename dbgpt_hub_gql/data/tugraph-db-example/dev.json \
--checkpoint_dir dbgpt_hub_gql/output/adapter/CodeLlama-7b-gql-lora \
--predicted_out_filename dbgpt_hub_gql/output/pred/tugraph_db_example_dev.txt >> ${pred_log}
```

Run the prediction script to generate the predictions:

```bash
sh dbgpt_hub_gql/scripts/predict_sft.sh
```
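
Before evaluating, it is worth confirming that predictions and references line up one-to-one. A small check under the assumption that both files are plain text with one query per line; the paths are the defaults used by the evaluation commands in the next section:

```python
# Quick pre-evaluation check: one prediction per gold reference, in order.
with open("dbgpt_hub_gql/output/pred/tugraph_db_example_dev.txt", encoding="utf-8") as f:
    pred = f.read().splitlines()
with open("dbgpt_hub_gql/data/tugraph-db-example/gold_dev.txt", encoding="utf-8") as f:
    gold = f.read().splitlines()

assert len(pred) == len(gold), f"{len(pred)} predictions vs {len(gold)} references"
print("first prediction:", pred[0])
```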

### 3.5 Model Evaluation

The current version supports two ways of evaluating predictions: text-similarity evaluation based on the Jaro–Winkler distance, and grammar-correctness evaluation based on a `.g4` grammar file or the graph database's existing grammar parser.

#### 3.5.1 Text Similarity Evaluation

Text-similarity evaluation computes the Jaro–Winkler distance between each prediction and its gold reference:

```bash
python dbgpt_hub_gql/eval/evaluation.py --input ./dbgpt_hub_gql/output/pred/tugraph_db_example_dev.txt --gold ./dbgpt_hub_gql/data/tugraph-db-example/gold_dev.txt --etype similarity
```
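
For intuition about the metric, the sketch below is a plain-Python Jaro–Winkler similarity with the standard parameters (prefix weight p = 0.1, common prefix capped at 4 characters). It illustrates the measure; it is not the project's evaluator:

```python
def jaro(s1: str, s2: str) -> float:
    """Jaro similarity: matches within a window, penalized by transpositions."""
    if s1 == s2:
        return 1.0
    len1, len2 = len(s1), len(s2)
    if not len1 or not len2:
        return 0.0
    window = max(len1, len2) // 2 - 1
    match1, match2 = [False] * len1, [False] * len2
    matches = 0
    for i, c in enumerate(s1):
        for j in range(max(0, i - window), min(len2, i + window + 1)):
            if not match2[j] and s2[j] == c:
                match1[i] = match2[j] = True
                matches += 1
                break
    if matches == 0:
        return 0.0
    # Count transpositions among the matched characters.
    k = transpositions = 0
    for i in range(len1):
        if match1[i]:
            while not match2[k]:
                k += 1
            if s1[i] != s2[k]:
                transpositions += 1
            k += 1
    m, t = matches, transpositions / 2
    return (m / len1 + m / len2 + (m - t) / m) / 3


def jaro_winkler(s1: str, s2: str, p: float = 0.1) -> float:
    """Jaro similarity boosted by the length of the common prefix (max 4)."""
    j = jaro(s1, s2)
    prefix = 0
    for a, b in zip(s1[:4], s2[:4]):
        if a != b:
            break
        prefix += 1
    return j + prefix * p * (1 - j)


print(jaro_winkler("MATCH (n) RETURN n", "MATCH (m) RETURN m"))
```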

#### 3.5.2 Grammar Correctness Evaluation

`tugraph-db-example` is a corpus that follows tugraph-db's LCypher grammar. Grammar-correctness evaluation uses the ANTLR4 toolchain: a parser generated from `./dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/Lcypher.g4` checks whether each predicted query is syntactically valid.

```bash
python dbgpt_hub_gql/eval/evaluation.py --input ./dbgpt_hub_gql/output/pred/tugraph_db_example_dev.txt --gold ./dbgpt_hub_gql/data/tugraph-db-example/gold_dev.txt --etype grammar --impl tugraph-db
```
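
Conceptually, the check tries to parse each prediction and counts syntax errors. A minimal sketch using the ANTLR4 Python runtime follows; the `LcypherLexer`/`LcypherParser` class names and the `oC_Cypher` entry rule are assumptions about what `antlr4 -Dlanguage=Python3 Lcypher.g4` would generate, not verified against the repository:

```python
from antlr4 import CommonTokenStream, InputStream

# Hypothetical imports: the classes ANTLR4 would generate from Lcypher.g4.
from LcypherLexer import LcypherLexer
from LcypherParser import LcypherParser


def is_grammatical(query: str) -> bool:
    """Return True if the query parses without any syntax errors."""
    lexer = LcypherLexer(InputStream(query))
    parser = LcypherParser(CommonTokenStream(lexer))
    parser.removeErrorListeners()  # keep the console quiet on bad input
    parser.oC_Cypher()  # assumed entry rule of the grammar
    return parser.getNumberOfSyntaxErrors() == 0


print(is_grammatical("MATCH (n:person) RETURN n LIMIT 10"))
```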

### 3.6 Merging Model Weights

To merge the weights of the base model and the fine-tuned PEFT adapter and export a single complete model, run the export script:

```bash
sh dbgpt_hub_gql/scripts/export_merge.sh
```
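
Under the hood, the merge amounts to loading the base model, attaching the LoRA adapter, and folding it into the base weights. A minimal sketch with the `peft` library; the adapter path is the default from section 3.3, and the output subdirectory name is illustrative:

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_id = "codellama/CodeLlama-7b-Instruct-hf"
adapter_dir = "dbgpt_hub_gql/output/adapter/CodeLlama-7b-gql-lora"
out_dir = "dbgpt_hub_gql/output/merged_models/CodeLlama-7b-gql"  # illustrative name

# Load the base model, attach the LoRA adapter, then fold it into the weights.
base = AutoModelForCausalLM.from_pretrained(base_id, torch_dtype="auto")
merged = PeftModel.from_pretrained(base, adapter_dir).merge_and_unload()

merged.save_pretrained(out_dir)
AutoTokenizer.from_pretrained(base_id).save_pretrained(out_dir)
```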
Empty file.
9 changes: 9 additions & 0 deletions src/dbgpt-hub-gql/dbgpt_hub_gql/configs/__init__.py
@@ -0,0 +1,9 @@
from .data_args import DataArguments, Llama2Template
from .model_args import ModelArguments, TrainingArguments

__all__ = [
"DataArguments",
"Llama2Template",
"ModelArguments",
"TrainingArguments",
]
53 changes: 53 additions & 0 deletions src/dbgpt-hub-gql/dbgpt_hub_gql/configs/config.py
@@ -0,0 +1,53 @@
import os

### path config
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# ROOT_PATH = "/root/autodl-tmp"
# MODELS_PARENT_PATH = "/home/model_files/codellama/"
# DEFAULT_FT_MODEL_NAME = "CodeLlama-7b-Instruct-hf"
MODELS_PARENT_PATH = "/home/model/"
DEFAULT_FT_MODEL_NAME = "Baichuan2-13B-Chat"
MODEL_PATH = os.path.join(MODELS_PARENT_PATH, DEFAULT_FT_MODEL_NAME)

# MODEL_PATH = os.path.join(ROOT_PATH, "model")
ADAPTER_PATH = os.path.join(ROOT_PATH, "dbgpt_hub_gql/output/adapter")
MERGED_MODELS = os.path.join(ROOT_PATH, "dbgpt_hub_gql/output/merged_models")

# DATA_PATH = "/root/autodl-tmp/data/spider/pre_processed_data"
# OUT_DIR= "/root/autodl-tmp/codellama"

DATA_PATH = os.path.join(ROOT_PATH, "dbgpt_hub_gql/data")
PREDICTED_DATA_PATH = os.path.join(
ROOT_PATH, "dbgpt_hub_gql/data/tugraph-db-example/dev.json"
)
PREDICTED_OUT_FILENAME = "pred_gql.txt"
# OUT_DIR = os.path.join(DATA_PATH, "out_pred")
OUT_DIR = os.path.join(ROOT_PATH, "dbgpt_hub_gql/output/")

## model constants
IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"


LOG_FILE_NAME = "trainer_log.jsonl"

# file name used when saving the value head state dict
VALUE_HEAD_FILE_NAME = "value_head.bin"

# file name for the finetuning args saved as JSON in the output dir
FINETUNING_ARGS_NAME = "finetuning_args.json"

# layer-norm module names referenced by prepare_model_for_training
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]
EXT2TYPE = {"csv": "csv", "json": "json", "jsonl": "json", "txt": "text"}

INSTRUCTION_PROMPT = """\
I want you to act as a GQL terminal in front of an example database, \
you need only to return the gql command to me. Below is an instruction that describes a task. \
Write a response that appropriately completes the request.\n
##Instruction:\n{}\n"""
INPUT_PROMPT = "###Input:\n{}\n\n###Response:"