Multi-Level Optimal Transport for Universal Cross-Tokenizer Knowledge Distillation on Language Models
This paper has been accepted as an AAAI 2025 oral.
Teacher models can be downloaded from Hugging Face and placed in:
$HOME/models/
Llama2-7b-chat-hf: meta-llama/Llama-2-7b-chat-hf
Meta-Llama-3-8B-Instruct: meta-llama/Meta-Llama-3-8B-Instruct
Meta-Llama-3.1-8B-Instruct: meta-llama/Meta-Llama-3.1-8B-Instruct
Mistral-7B-Instruct-v0.3: mistralai/Mistral-7B-Instruct-v0.3
Qwen-7B-Chat: Qwen/Qwen-7B-Chat
Qwen1.5-7B-Chat: Qwen/Qwen1.5-7B-Chat
Student models can be downloaded from Hugging Face and placed in the following directory (a download sketch follows the list below):
$HOME/llm-recipes/EleutherAI/
pythia-160m: EleutherAI/pythia-160m
opt-350m: facebook/opt-350m
pythia-410m: EleutherAI/pythia-410m
bloomz-560m: bigscience/bloomz-560m (set the batch size to 1 for DialogSum or FairytaleQA if you only use a single A100-80G)
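As a minimal download sketch (assuming the huggingface_hub package is installed and, for gated Llama repositories, that you have run huggingface-cli login), the checkpoints can be fetched into the directories used by the commands below:

import os
from huggingface_hub import snapshot_download  # pip install huggingface_hub

HOME = os.getenv("HOME")

# Teacher checkpoints go under $HOME/models/<org>/<name>
snapshot_download(
    repo_id="meta-llama/Llama-2-7b-chat-hf",
    local_dir=os.path.join(HOME, "models", "meta-llama", "Llama-2-7b-chat-hf"),
)

# Student checkpoints go under $HOME/llm-recipes/EleutherAI/<name>
snapshot_download(
    repo_id="EleutherAI/pythia-410m",
    local_dir=os.path.join(HOME, "llm-recipes", "EleutherAI", "pythia-410m"),
)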
The distilled student models for each task reported in the paper can be downloaded from the following links: https://drive.google.com/drive/folders/1O6k6THm_PjqNybDixppXhad0Nyk-xIjB?usp=drive_link and https://drive.google.com/drive/folders/1ZE_wu0Ey2KpKrjq3NA0VgAvyhynOR6a4?usp=sharing
For distillation, several parameters can be set:
--model_name: The ID of the student model (Hugging Face repository ID).
--lr: Learning rate for the training process.
--num_epochs: Number of epochs for training.
--batch_size_training: Batch size for training.
--val_batch_size: Batch size for validation.
--dataset.file: Path to the dataset file.
--output_dir: Directory to save the output.
--distillation: Activate distillation.
--distillation_config.model_name: The ID of the teacher model (Hugging Face repository ID).
--distillation_config.enable_fsdp: Enable Fully Sharded Data Parallelism (FSDP).
--distillation_config.pure_bf16: Use pure BF16 precision.
--distillation_config.distil_factor: Factor for the distillation loss.
--save_step: Interval for saving checkpoints during training.
--encoder_decoder: Specify this parameter if the student model follows an encoder-decoder architecture.
--f: Choose the method. f=1: ours (fast); f=2: ours (greedy).
Below is an example bash command for running the distillation process:
#export HOME = ""
export CUDA_VISIBLE_DEVICES=0 python finetuning.py \
--model_name $HOME/llm-recipes/EleutherAI/pythia-410m \
--dataset.file $HOME/llm-recipes/llm_distillation/datasets/loader/qed.py \
--lr 1e-6 \
--num_epochs 5 \
--batch_size_training 2 \
--val_batch_size 2 \
--output_dir $HOME/llm-recipes/output2 \
--distillation_config.model_name $HOME/models/meta-llama/Llama-2-7b-chat-hf \
--distillation \
--distillation_config.enable_fsdp \
--distillation_config.pure_bf16 \
--distillation_config.distil_factor 1.5 \
--save_step 2000 \
--f 1
Most of the dataset files are provided in "llm-recipes/llm_distillation/datasets/hf/" and "llm-recipes/llm_distillation/datasets/hf/processed/".
Dialogsum: knkarthick/dialogsum
FairytaleQA: WorkInTheDark/FairytaleQA
You need to convert the dataset files into the Arrow (streaming) format. We provide transfer.py as an example in llm_distillation/datasets/hf/fairyjsonbase/.
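The repository's transfer.py contains the actual conversion; as a rough, hedged illustration of the idea with the datasets library (the paths below are examples, not the repository defaults):

from datasets import load_dataset, load_from_disk

# Load a dataset from the Hub (DialogSum here; FairytaleQA works the same way).
dataset = load_dataset("knkarthick/dialogsum")

# Save it to disk in Arrow format so later scripts can read it back without re-downloading.
dataset.save_to_disk("llm_distillation/datasets/hf/dialogsum")

# Reading it back later, e.g. inside a loader file:
dataset = load_from_disk("llm_distillation/datasets/hf/dialogsum")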
If you need to use the teacher model's answers as the student model's labels, you also need to convert the original dataset into a new Arrow dataset containing the answers generated by the teacher model. We use result.sh in llm-recipes/ and benchmark.py in llm-recipes/llm_distillation/benchmark/ to generate a JSON file with the answers, and then use the transfer.py provided in the dataset directories named like qedllama. Also pay attention to the corresponding benchmark.py in "llm-recipes/llm_distillation/benchmark/" and the loader files in "llm-recipes/llm_distillation/datasets/loader/".
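A hedged sketch of that second conversion (the JSON layout, the column name "answers", and the paths are assumptions for illustration; the repository's transfer.py and loader files define the real schema):

import json
from datasets import DatasetDict, load_from_disk

# Load the original Arrow dataset and pick a split if it is a DatasetDict.
dataset = load_from_disk("llm_distillation/datasets/hf/processed/qed")
if isinstance(dataset, DatasetDict):
    dataset = dataset["train"]

# Teacher answers produced by result.sh / benchmark.py, assumed here to be a
# JSON list of strings in the same order as the split.
with open("teacher_predictions.json") as f:
    teacher_answers = json.load(f)

# Replace the gold labels with the teacher's answers and save a new Arrow dataset.
dataset = dataset.remove_columns("answers")
dataset = dataset.add_column("answers", teacher_answers)
dataset.save_to_disk("llm_distillation/datasets/hf/processed/qedllama")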
You can use results.sh in "llm-recipes/" to evaluate a teacher or student model (whether distilled or not) and save the predicted answers in a JSON file.
For example:
# export HOME=""
CUDA_VISIBLE_DEVICES=0 python $HOME/llm-recipes/llm_distillation/benchmark/benchmark619.py \
--model_id "$HOME/llm-recipes/results/output-qedllama-opt" \
--model_tokenizer "$HOME/llm-recipes/EleutherAI/opt-350m" \
--dataset_id "$HOME/llm-recipes/llm_distillation/datasets/processed/qed" \
--split_name "validation" \
--context \
--title \
--batch_size 1 \
--num_workers 1 \
--output_path "$HOME/llm-recipes/test/" \
--number_few_shot 0 \
--context_length 1024 \
--from_disk \
--task "qa" \
--f 1 \
--save_predictions
All of the following files use "{os.getenv('HOME')}" to build paths:
llm_distillation/datasets/generator.py
llm_distillation/datasets/loader/*
llm_distillation/prompt/prompt.py
llm_distillation/benchtestfairy.py
llm_distillation/benchmark/*
If you encounter errors on your machine due to environment issues, you may try replacing these with direct paths.
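For example, a path built with the HOME environment variable can be swapped for a hard-coded prefix (the path below is a placeholder, not a repository default):

import os

# Original pattern used in the scripts listed above:
dataset_path = f"{os.getenv('HOME')}/llm-recipes/llm_distillation/datasets/hf/processed/qed"

# Direct-path fallback if HOME does not resolve as expected on your machine:
dataset_path = "/absolute/path/to/llm-recipes/llm_distillation/datasets/hf/processed/qed"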
@article{cui2024multi,
title={Multi-Level Optimal Transport for Universal Cross-Tokenizer Knowledge Distillation on Language Models},
author={Cui, Xiao and Zhu, Mo and Qin, Yulei and Xie, Liang and Zhou, Wengang and Li, Houqiang},
journal={arXiv preprint arXiv:2412.14528},
year={2024}
}