-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0bbb505
commit 7fdcfee
Showing
1 changed file
with
372 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,372 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Build Your own MoE with LLaMa3-based Experts" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Install mergoo" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!pip install mergoo" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Selecting Experts:\n", | ||
"\n", | ||
"With mergoo, you can easily build your own MoE-style LLM by integrating LLaMa3-based experts.\n", | ||
"\n", | ||
"In this tutorial, we have used following LLaMa3-based models:\n", | ||
"- [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B): The generic LLaMa3 8b model, provided by Meta team.\n", | ||
"- [Locutusque/Llama-3-Orca-1.0-8B](https://huggingface.co/Locutusque/Llama-3-Orca-1.0-8B): Fine-tuned LLaMa3 8b on [SlimOrca](https://huggingface.co/datasets/Open-Orca/SlimOrca) for enhancing performance in math, coding, and writing.\n", | ||
"- [mlabonne/OrpoLlama-3-8B](https://huggingface.co/mlabonne/OrpoLlama-3-8B): ORPO fine-tuned of LLaMa3 8b on 1k samples of [ORPO dataset](https://huggingface.co/datasets/mlabonne/orpo-dpo-mix-40k).\n", | ||
"\n", | ||
"\n", | ||
"\n", | ||
"**Preparing Config:**\n", | ||
"- `model_type`: `llama`\n", | ||
"- `num_experts_per_tok`: Total number of active experts at each step. These experts are selected sparsely.\n", | ||
"- `experts`: List of dictionaries of seed models that would get merged. For each expert, `model_id` is mandatory. The model_id can be either a local path or a Huggingface model id.\n", | ||
"- `router_layers`: These are the layer names that would be replaced with MOE layers. Weights of the rest of the layers are aggregated using averaging. In the future, we will support multiple aggregation methods from MergeKit.\n", | ||
"- `router_layers_index`: List of indexes. These are the indexes of transformer blocks, layers of these index would be converted to MOE. Default `router_layers_index` is empty meaning the MOE conversion gets applied on all the layers, given that `router_layers` identifier matches. `[None]` can be used when no MOE layer should be kept following the [BTM](https://arxiv.org/abs/2208.03306) architecture." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"MoE Layer Index : [*]\n" | ||
] | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00, 1.04it/s]\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"merging expert : unsloth/llama-3-8b\n" | ||
] | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"100%|██████████| 291/291 [00:06<00:00, 47.77it/s]\n", | ||
"Loading checkpoint shards: 100%|██████████| 9/9 [00:04<00:00, 1.92it/s]\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"merging expert : Locutusque/Llama-3-Orca-1.0-8B\n" | ||
] | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"100%|██████████| 291/291 [00:06<00:00, 45.51it/s]\n", | ||
"Downloading shards: 100%|██████████| 4/4 [00:29<00:00, 7.47s/it]\n", | ||
"Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00, 1.57it/s]\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"merging expert : mlabonne/OrpoLlama-3-8B\n" | ||
] | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"100%|██████████| 291/291 [00:06<00:00, 44.84it/s]\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"count_averaged_layers : 195\n", | ||
"count_router_layers : 96\n", | ||
"count_total_router_layers : 288\n", | ||
"The model is bigger than the maximum size per checkpoint (9GB) and is going to be split in 5 checkpoint shards. You can find where each parameters has been saved in the index located at data/llama3_moe/model.safetensors.index.json.\n" | ||
] | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"checkpoint saved at data/llama3_moe\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import torch\n", | ||
"from mergoo.compose_experts import ComposeExperts\n", | ||
"\n", | ||
"model_id = \"data/checkpoint_demo\"\n", | ||
"config = \\\n", | ||
"{\n", | ||
" \"model_type\": \"llama\",\n", | ||
" \"num_experts_per_tok\": 2,\n", | ||
" \"experts\":[\n", | ||
" {\n", | ||
" \"expert_name\" : \"base_expert\",\n", | ||
" \"model_id\" : \"unsloth/llama-3-8b\"\n", | ||
" },\n", | ||
" {\n", | ||
" \"expert_name\" : \"expert_1\",\n", | ||
" \"model_id\" : \"Locutusque/Llama-3-Orca-1.0-8B\"\n", | ||
" },\n", | ||
" {\n", | ||
" \"expert_name\" : \"expert_2\",\n", | ||
" \"model_id\" : \"mlabonne/OrpoLlama-3-8B\"\n", | ||
" }\n", | ||
" ],\n", | ||
" \"router_layers\":[\n", | ||
" \"gate_proj\",\n", | ||
" \"up_proj\",\n", | ||
" \"down_proj\"\n", | ||
" ],\n", | ||
"}\n", | ||
"# create checkpoint\n", | ||
"expertmerger = ComposeExperts( config, torch_dtype=torch.float16 )\n", | ||
"expertmerger.compose()\n", | ||
"expertmerger.save_checkpoint(\"data/llama3_moe\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Training\n", | ||
"\n", | ||
"Now that we have created an MOE checkpoint, all the layers of this model are pretrained except for the gating/routing layers that we added. The routing layer selects the top K experts, in our case K=2. We support HuggingFace trainers: Trainer, SFTrainer. In this example, we are using the [alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca) dataset for finetuning. We will train only the router layers, keeping all the other layers frozen." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Loading checkpoint shards: 100%|██████████| 5/5 [00:08<00:00, 1.67s/it]\n", | ||
"Some weights of LlamaForCausalLM were not initialized from the model checkpoint at data/llama3_moe and are newly initialized: ['model.layers.0.mlp.down_proj.gate.weight', 'model.layers.0.mlp.gate_proj.gate.weight', 'model.layers.0.mlp.up_proj.gate.weight', 'model.layers.1.mlp.down_proj.gate.weight', 'model.layers.1.mlp.gate_proj.gate.weight', 'model.layers.1.mlp.up_proj.gate.weight', 'model.layers.10.mlp.down_proj.gate.weight', 'model.layers.10.mlp.gate_proj.gate.weight', 'model.layers.10.mlp.up_proj.gate.weight', 'model.layers.11.mlp.down_proj.gate.weight', 'model.layers.11.mlp.gate_proj.gate.weight', 'model.layers.11.mlp.up_proj.gate.weight', 'model.layers.12.mlp.down_proj.gate.weight', 'model.layers.12.mlp.gate_proj.gate.weight', 'model.layers.12.mlp.up_proj.gate.weight', 'model.layers.13.mlp.down_proj.gate.weight', 'model.layers.13.mlp.gate_proj.gate.weight', 'model.layers.13.mlp.up_proj.gate.weight', 'model.layers.14.mlp.down_proj.gate.weight', 'model.layers.14.mlp.gate_proj.gate.weight', 'model.layers.14.mlp.up_proj.gate.weight', 'model.layers.15.mlp.down_proj.gate.weight', 'model.layers.15.mlp.gate_proj.gate.weight', 'model.layers.15.mlp.up_proj.gate.weight', 'model.layers.16.mlp.down_proj.gate.weight', 'model.layers.16.mlp.gate_proj.gate.weight', 'model.layers.16.mlp.up_proj.gate.weight', 'model.layers.17.mlp.down_proj.gate.weight', 'model.layers.17.mlp.gate_proj.gate.weight', 'model.layers.17.mlp.up_proj.gate.weight', 'model.layers.18.mlp.down_proj.gate.weight', 'model.layers.18.mlp.gate_proj.gate.weight', 'model.layers.18.mlp.up_proj.gate.weight', 'model.layers.19.mlp.down_proj.gate.weight', 'model.layers.19.mlp.gate_proj.gate.weight', 'model.layers.19.mlp.up_proj.gate.weight', 'model.layers.2.mlp.down_proj.gate.weight', 'model.layers.2.mlp.gate_proj.gate.weight', 'model.layers.2.mlp.up_proj.gate.weight', 'model.layers.20.mlp.down_proj.gate.weight', 'model.layers.20.mlp.gate_proj.gate.weight', 'model.layers.20.mlp.up_proj.gate.weight', 'model.layers.21.mlp.down_proj.gate.weight', 'model.layers.21.mlp.gate_proj.gate.weight', 'model.layers.21.mlp.up_proj.gate.weight', 'model.layers.22.mlp.down_proj.gate.weight', 'model.layers.22.mlp.gate_proj.gate.weight', 'model.layers.22.mlp.up_proj.gate.weight', 'model.layers.23.mlp.down_proj.gate.weight', 'model.layers.23.mlp.gate_proj.gate.weight', 'model.layers.23.mlp.up_proj.gate.weight', 'model.layers.24.mlp.down_proj.gate.weight', 'model.layers.24.mlp.gate_proj.gate.weight', 'model.layers.24.mlp.up_proj.gate.weight', 'model.layers.25.mlp.down_proj.gate.weight', 'model.layers.25.mlp.gate_proj.gate.weight', 'model.layers.25.mlp.up_proj.gate.weight', 'model.layers.26.mlp.down_proj.gate.weight', 'model.layers.26.mlp.gate_proj.gate.weight', 'model.layers.26.mlp.up_proj.gate.weight', 'model.layers.27.mlp.down_proj.gate.weight', 'model.layers.27.mlp.gate_proj.gate.weight', 'model.layers.27.mlp.up_proj.gate.weight', 'model.layers.28.mlp.down_proj.gate.weight', 'model.layers.28.mlp.gate_proj.gate.weight', 'model.layers.28.mlp.up_proj.gate.weight', 'model.layers.29.mlp.down_proj.gate.weight', 'model.layers.29.mlp.gate_proj.gate.weight', 'model.layers.29.mlp.up_proj.gate.weight', 'model.layers.3.mlp.down_proj.gate.weight', 'model.layers.3.mlp.gate_proj.gate.weight', 'model.layers.3.mlp.up_proj.gate.weight', 'model.layers.30.mlp.down_proj.gate.weight', 'model.layers.30.mlp.gate_proj.gate.weight', 'model.layers.30.mlp.up_proj.gate.weight', 'model.layers.31.mlp.down_proj.gate.weight', 'model.layers.31.mlp.gate_proj.gate.weight', 'model.layers.31.mlp.up_proj.gate.weight', 'model.layers.4.mlp.down_proj.gate.weight', 'model.layers.4.mlp.gate_proj.gate.weight', 'model.layers.4.mlp.up_proj.gate.weight', 'model.layers.5.mlp.down_proj.gate.weight', 'model.layers.5.mlp.gate_proj.gate.weight', 'model.layers.5.mlp.up_proj.gate.weight', 'model.layers.6.mlp.down_proj.gate.weight', 'model.layers.6.mlp.gate_proj.gate.weight', 'model.layers.6.mlp.up_proj.gate.weight', 'model.layers.7.mlp.down_proj.gate.weight', 'model.layers.7.mlp.gate_proj.gate.weight', 'model.layers.7.mlp.up_proj.gate.weight', 'model.layers.8.mlp.down_proj.gate.weight', 'model.layers.8.mlp.gate_proj.gate.weight', 'model.layers.8.mlp.up_proj.gate.weight', 'model.layers.9.mlp.down_proj.gate.weight', 'model.layers.9.mlp.gate_proj.gate.weight', 'model.layers.9.mlp.up_proj.gate.weight']\n", | ||
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", | ||
"WARNING:root:Some parameters are on the meta device device because they were offloaded to the cpu.\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# load the composed checkkpoint\n", | ||
"import torch\n", | ||
"from mergoo.models.modeling_llama import LlamaForCausalLM\n", | ||
"\n", | ||
"model = LlamaForCausalLM.from_pretrained(\n", | ||
" \"data/llama3_moe\", \n", | ||
" device_map=\"auto\", \n", | ||
" torch_dtype=torch.bfloat16,\n", | ||
")# 'gate' / router layers are untrained hence loaded warning would appeare for them" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"(579, 387)" | ||
] | ||
}, | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"# train only router (gating) layers\n", | ||
"n_weights, n_router_weights = 0,0\n", | ||
"for name, weight in model.named_parameters():\n", | ||
" if \"gate\" not in name:\n", | ||
" weight.requires_grad_(False)\n", | ||
" n_router_weights += 1\n", | ||
" n_weights += 1\n", | ||
"n_weights, n_router_weights" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Downloading readme: 100%|██████████| 7.47k/7.47k [00:00<00:00, 22.7MB/s]\n", | ||
"Downloading data: 100%|██████████| 24.2M/24.2M [00:00<00:00, 56.3MB/s]\n", | ||
"Generating train split: 52002 examples [00:00, 531424.92 examples/s]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import datasets\n", | ||
"import random\n", | ||
"\n", | ||
"dataset = datasets.load_dataset(\"tatsu-lab/alpaca\")['train']\n", | ||
"dataset = dataset['text']\n", | ||
"random.shuffle(dataset)\n", | ||
"dataset_train = datasets.Dataset.from_dict(dict(prompt=dataset[:-1000]))\n", | ||
"dataset_test = datasets.Dataset.from_dict(dict(prompt=dataset[-1000:]))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"(Dataset({\n", | ||
" features: ['prompt'],\n", | ||
" num_rows: 51002\n", | ||
" }),\n", | ||
" Dataset({\n", | ||
" features: ['prompt'],\n", | ||
" num_rows: 1000\n", | ||
" }))" | ||
] | ||
}, | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"dataset_train, dataset_test" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", | ||
"/home/ubuntu/miniconda3/envs/router/lib/python3.12/site-packages/trl/trainer/sft_trainer.py:246: UserWarning: You didn't pass a `max_seq_length` argument to the SFTTrainer, this will default to 1024\n", | ||
" warnings.warn(\n", | ||
"Map: 100%|██████████| 51002/51002 [00:03<00:00, 16822.74 examples/s]\n", | ||
"Map: 100%|██████████| 1000/1000 [00:00<00:00, 18198.36 examples/s]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from trl import SFTTrainer\n", | ||
"from transformers import TrainingArguments\n", | ||
"\n", | ||
"trainer_args = TrainingArguments(\n", | ||
" output_dir= \"checkpoints/llama_moe\",\n", | ||
" per_device_train_batch_size = 1,\n", | ||
" per_device_eval_batch_size = 1, \n", | ||
" learning_rate= 1e-5,\n", | ||
" save_total_limit=1,\n", | ||
" num_train_epochs=1,\n", | ||
" eval_steps= 5000,\n", | ||
" logging_strategy=\"steps\",\n", | ||
" logging_steps= 25,\n", | ||
" gradient_accumulation_steps=4,\n", | ||
" bf16=True\n", | ||
")\n", | ||
"\n", | ||
"trainer = SFTTrainer(\n", | ||
" model,\n", | ||
" args= trainer_args,\n", | ||
" train_dataset= dataset_train,\n", | ||
" eval_dataset= dataset_test,\n", | ||
" dataset_text_field=\"prompt\",\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"trainer.train()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.2" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |