Multi-Level Optimal Transport for Universal Cross-Tokenizer Knowledge Distillation on Language Models

PyTorch implementation of the paper, which has been accepted as an oral presentation at AAAI 2025.

Download Pre-trained Teacher Models

Teacher models can be downloaded from Hugging Face and placed under the following directory (see the download sketch after the student model list below):

$HOME/models/

Llama2-7b-chat-hf: meta-llama/Llama-2-7b-chat-hf

Meta-Llama-3-8B-Instruct: meta-llama/Meta-Llama-3-8B-Instruct

Meta-Llama-3.1-8B-Instruct: meta-llama/Meta-Llama-3.1-8B-Instruct

Mistral-7B-Instruct-v0.3: mistralai/Mistral-7B-Instruct-v0.3

Qwen-7B-Chat: Qwen/Qwen-7B-Chat

Qwen1.5-7B-Chat: Qwen/Qwen1.5-7B-Chat

Download Pre-trained Student Models

Student models can be downloaded from Hugging Face and placed under the following directory (see the download sketch after this list):

$HOME/llm-recipes/EleutherAI/

pythia-160m: EleutherAI/pythia-160m

opt-350m: facebook/opt-350m

pythia-410m: EleutherAI/pythia-410m

bloomz-560m: bigscience/bloomz-560m (set the batch size to 1 for dialogsum or fairytale if you only use a single A100-80G)
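
The following is a minimal download sketch using huggingface_hub (an assumption; any method that places the weights in the directories above works, and gated repositories such as meta-llama require logging in to Hugging Face first):

import os
from huggingface_hub import snapshot_download

home = os.getenv("HOME")

# Teacher models go under $HOME/models/
snapshot_download(
    repo_id="meta-llama/Llama-2-7b-chat-hf",
    local_dir=f"{home}/models/meta-llama/Llama-2-7b-chat-hf",
)

# Student models go under $HOME/llm-recipes/EleutherAI/
snapshot_download(
    repo_id="EleutherAI/pythia-410m",
    local_dir=f"{home}/llm-recipes/EleutherAI/pythia-410m",
)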

Student Checkpoints

The distilled student models for each task reported in the paper can be downloaded from the following links: https://drive.google.com/drive/folders/1O6k6THm_PjqNybDixppXhad0Nyk-xIjB?usp=drive_link and https://drive.google.com/drive/folders/1ZE_wu0Ey2KpKrjq3NA0VgAvyhynOR6a4?usp=sharing

Task-specific Student Model Distillation

For distillation, several parameters can be set:

  • --model_name: The ID of the student model (HuggingFace repository ID).
  • --lr: Learning rate for the training process.
  • --num_epochs: Number of epochs for training.
  • --batch_size_training: Batch size for training.
  • --val_batch_size: Batch size for validation.
  • --dataset.file: Path to the dataset file.
  • --output_dir: Directory to save the output.
  • --distillation: Activate distillation.
  • --distillation_config.model_name: The ID of the teacher model (HuggingFace repository ID).
  • --distillation_config.enable_fsdp: Enable Fully Sharded Data Parallelism (FSDP).
  • --distillation_config.pure_bf16: Use pure BF16 precision.
  • --distillation_config.distil_factor: Factor for distillation loss.
  • --save_step: Interval for saving checkpoints during training.
  • --encoder_decoder: Specify this parameter if the student model follows an encoder-decoder architecture.
  • --f: Choose the method. f=1: ours (fast); f=2: ours (greedy).

Example

Below is an example bash command for running the distillation process:

#export HOME=""

export CUDA_VISIBLE_DEVICES=0
python finetuning.py \
--model_name $HOME/llm-recipes/EleutherAI/pythia-410m \
--dataset.file $HOME/llm-recipes/llm_distillation/datasets/loader/qed.py \
--lr 1e-6 \
--num_epochs 5 \
--batch_size_training 2 \
--val_batch_size 2 \
--output_dir $HOME/llm-recipes/output2 \
--distillation_config_model_name $HOME/models/meta-llama/Llama-2-7b-chat-hf \
--distillation \
--distillation_config_enable_fsdp \
--distillation_config_pure_bf16 \
--distillation_config_distil_factor 1.5 \
--save_step 2000 \
--f 1

Dataset File

Most of the dataset files are provided in "llm-recipes/llm_distillation/datasets/hf/" and "llm-recipes/llm_distillation/datasets/hf/processed/".

Dialogsum: knkarthick/dialogsum

FairytaleQA: WorkInTheDark/FairytaleQA

You need to convert the dataset files into Arrow (streaming) format. We supply transfer.py as an example in llm_distillation/datasets/hf/fairyjsonbase/.
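
As a rough illustration of what such a conversion does (the file names and output directory here are placeholders; transfer.py is the actual reference), the datasets library can write a JSON dataset to Arrow as follows:

from datasets import load_dataset

# Placeholder file names; point these at your task's JSON files.
ds = load_dataset("json", data_files={"train": "train.json", "validation": "validation.json"})

# save_to_disk writes Arrow files that can later be read with load_from_disk.
ds.save_to_disk("processed/my_dataset")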

If you need to use the teacher model's answers as the student model's labels, you also need to convert the original dataset into a new Arrow dataset containing the answers generated by the teacher model. We use result.sh in llm-recipes/ and benchmark.py in llm-recipes/llm_distillation/benchmark/ to generate a JSON file with the answers, and then use the transfer.py in the datasets named like qedllama. Pay attention to the corresponding benchmark.py in "llm-recipes/llm_distillation/benchmark/" or the loader files in "llm-recipes/llm_distillation/datasets/loader/".
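
A hypothetical sketch of this relabeling step is shown below; the column name "answer" and the prediction field are assumptions and must be matched to the JSON actually emitted by benchmark.py for your task:

import json
from datasets import load_from_disk

# Load the original Arrow dataset and the teacher's generated answers.
ds = load_from_disk("processed/qed")["train"]
with open("teacher_predictions.json") as f:
    preds = json.load(f)  # assumed: a list aligned row-by-row with the dataset

# Replace the gold answers with the teacher-generated ones (hypothetical field names).
ds = ds.remove_columns(["answer"])
ds = ds.add_column("answer", [p["prediction"] for p in preds])

ds.save_to_disk("processed/qedllama")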

Evaluation

You can use results.sh in "llm-recipes/" to evaluate a teacher or student model, whether or not it has been distilled, and save the predicted answers in a JSON file.

For example:

#export HOME=

export CUDA_VISIBLE_DEVICES=0
python $HOME/llm-recipes/llm_distillation/benchmark/benchmark619.py \
  --model_id "$HOME/llm-recipes/results/output-qedllama-opt" \
  --model_tokenizer "$HOME/llm-recipes/EleutherAI/opt-350m" \
  --dataset_id "$HOME/llm-recipes/llm_distillation/datasets/processed/qed" \
  --split_name "validation" \
  --context \
  --title \
  --batch_size 1 \
  --num_workers 1 \
  --output_path "$HOME/llm-recipes/test/" \
  --number_few_shot 0 \
  --context_length 1024 \
  --from_disk \
  --task "qa" \
  --f 1 \
  --save_predictions

Environment Notes

All of the following files build paths with "{os.getenv('HOME')}":

llm_distillation/datasets/generator.py

llm_distillation/datasets/loader/*

llm_distillation/prompt/prompt.py

llm_distillation/benchtestfairy.py

llm_distillation/benchmark/*

If you run into errors on your machine because of these environment-dependent paths, you can try replacing them with direct (absolute) paths.
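
For example, paths in these files look roughly like the following (the variable name is illustrative):

import os

# The repository builds paths from the HOME environment variable.
dataset_dir = f"{os.getenv('HOME')}/llm-recipes/llm_distillation/datasets"

# If $HOME does not resolve to the expected location on your machine,
# replace such expressions with a direct absolute path, e.g.:
# dataset_dir = "/absolute/path/to/llm-recipes/llm_distillation/datasets"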

Citation

@article{cui2024multi,
  title={Multi-Level Optimal Transport for Universal Cross-Tokenizer Knowledge Distillation on Language Models},
  author={Cui, Xiao and Zhu, Mo and Qin, Yulei and Xie, Liang and Zhou, Wengang and Li, Houqiang},
  journal={arXiv preprint arXiv:2412.14528},
  year={2024}
}
