diff --git a/.readthedocs.yaml b/.readthedocs.yaml index e2be1893e..bab4a6dce 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -8,7 +8,7 @@ version: 2 build: os: ubuntu-22.04 tools: - python: "3.12" + python: "3.9" mkdocs: configuration: mkdocs.yml diff --git a/examples/question_answer/bidaf_squad_scratch.ipynb b/examples/question_answer/bidaf_squad_scratch.ipynb index 83701b71f..30994ca6d 100644 --- a/examples/question_answer/bidaf_squad_scratch.ipynb +++ b/examples/question_answer/bidaf_squad_scratch.ipynb @@ -20,7 +20,7 @@ "from mindspore.common.initializer import Uniform, HeUniform, initializer\n", "\n", "import mindnlp\n", - "from mindnlp.abc import Seq2vecModel\n", + "from mindnlp._legacy.abc import Seq2vecModel\n", "from mindnlp.modules import Glove, StaticLSTM\n", "from mindnlp.transforms import BasicTokenizer\n", "\n", diff --git a/llm/peft/train_adalora_seq2seq/peft_adalora_seq2seq.ipynb b/llm/peft/adalora/train_adalora_seq2seq/peft_adalora_seq2seq.ipynb similarity index 100% rename from llm/peft/train_adalora_seq2seq/peft_adalora_seq2seq.ipynb rename to llm/peft/adalora/train_adalora_seq2seq/peft_adalora_seq2seq.ipynb diff --git a/llm/peft/train_mto_large_adaption_prompt/peft_adaption_prompt_CausalLM.ipynb b/llm/peft/adaption_prompt/train_mto_large_adaption_prompt/peft_adaption_prompt_CausalLM.ipynb similarity index 100% rename from llm/peft/train_mto_large_adaption_prompt/peft_adaption_prompt_CausalLM.ipynb rename to llm/peft/adaption_prompt/train_mto_large_adaption_prompt/peft_adaption_prompt_CausalLM.ipynb diff --git a/llm/peft/train_mt0_large_ia3/peft_ia3_seq2seq.ipynb b/llm/peft/ia3/train_mt0_large_ia3/peft_ia3_seq2seq.ipynb similarity index 100% rename from llm/peft/train_mt0_large_ia3/peft_ia3_seq2seq.ipynb rename to llm/peft/ia3/train_mt0_large_ia3/peft_ia3_seq2seq.ipynb diff --git a/llm/peft/train_mt0_large_lokr/peft_lokr_seq2seq.ipynb b/llm/peft/lokr/train_mt0_large_lokr/peft_lokr_seq2seq.ipynb similarity index 100% rename from llm/peft/train_mt0_large_lokr/peft_lokr_seq2seq.ipynb rename to llm/peft/lokr/train_mt0_large_lokr/peft_lokr_seq2seq.ipynb diff --git a/llm/peft/bert_mrpc_lora.py b/llm/peft/lora/bert_mrpc_lora.py similarity index 100% rename from llm/peft/bert_mrpc_lora.py rename to llm/peft/lora/bert_mrpc_lora.py diff --git a/llm/peft/train_convbert/squad_dataset.py b/llm/peft/lora/train_convbert/squad_dataset.py similarity index 100% rename from llm/peft/train_convbert/squad_dataset.py rename to llm/peft/lora/train_convbert/squad_dataset.py diff --git a/llm/peft/train_convbert/train.py b/llm/peft/lora/train_convbert/train.py similarity index 100% rename from llm/peft/train_convbert/train.py rename to llm/peft/lora/train_convbert/train.py diff --git a/llm/peft/train_falcon/mrpc_dataset.py b/llm/peft/lora/train_falcon/mrpc_dataset.py similarity index 100% rename from llm/peft/train_falcon/mrpc_dataset.py rename to llm/peft/lora/train_falcon/mrpc_dataset.py diff --git a/llm/peft/train_falcon/readme.md b/llm/peft/lora/train_falcon/readme.md similarity index 100% rename from llm/peft/train_falcon/readme.md rename to llm/peft/lora/train_falcon/readme.md diff --git a/llm/peft/train_falcon/train_mrpc.py b/llm/peft/lora/train_falcon/train_mrpc.py similarity index 100% rename from llm/peft/train_falcon/train_mrpc.py rename to llm/peft/lora/train_falcon/train_mrpc.py diff --git a/llm/peft/train_gpt_bigcode/bigcode_dataset.py b/llm/peft/lora/train_gpt_bigcode/bigcode_dataset.py similarity index 100% rename from llm/peft/train_gpt_bigcode/bigcode_dataset.py 
rename to llm/peft/lora/train_gpt_bigcode/bigcode_dataset.py diff --git a/llm/peft/train_gpt_bigcode/run_peft.sh b/llm/peft/lora/train_gpt_bigcode/run_peft.sh similarity index 100% rename from llm/peft/train_gpt_bigcode/run_peft.sh rename to llm/peft/lora/train_gpt_bigcode/run_peft.sh diff --git a/llm/peft/train_gpt_bigcode/train.py b/llm/peft/lora/train_gpt_bigcode/train.py similarity index 100% rename from llm/peft/train_gpt_bigcode/train.py rename to llm/peft/lora/train_gpt_bigcode/train.py diff --git a/llm/peft/train_llama_lora/mrpc_dataset.py b/llm/peft/lora/train_llama_lora/mrpc_dataset.py similarity index 100% rename from llm/peft/train_llama_lora/mrpc_dataset.py rename to llm/peft/lora/train_llama_lora/mrpc_dataset.py diff --git a/llm/peft/train_llama_lora/train.py b/llm/peft/lora/train_llama_lora/train.py similarity index 100% rename from llm/peft/train_llama_lora/train.py rename to llm/peft/lora/train_llama_lora/train.py diff --git a/llm/peft/prompt_tuning/roberta_sequence_classification.ipynb b/llm/peft/prompt_tuning/roberta_sequence_classification.ipynb new file mode 100644 index 000000000..83d9c88bd --- /dev/null +++ b/llm/peft/prompt_tuning/roberta_sequence_classification.ipynb @@ -0,0 +1,646 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "7228a58b-4f81-4f5d-ac6c-d9439b3f4447", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "env: HF_ENDPOINT=https://hf-mirror.com\n" + ] + } + ], + "source": [ + "%env HF_ENDPOINT=https://hf-mirror.com" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "9ff5004e", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Building prefix dict from the default dictionary ...\n", + "Loading model from cache /tmp/jieba.cache\n", + "Loading model cost 0.612 seconds.\n", + "Prefix dict has been built successfully.\n" + ] + } + ], + "source": [ + "import argparse\n", + "import os\n", + "\n", + "import mindspore\n", + "from mindspore.experimental.optim import AdamW\n", + "from tqdm import tqdm\n", + "import evaluate\n", + "from mindnlp.dataset import load_dataset\n", + "from mindnlp.engine import set_seed\n", + "from mindnlp.transformers import AutoModelForSequenceClassification, AutoTokenizer\n", + "from mindnlp.modules.optimization import get_linear_schedule_with_warmup\n", + "from mindnlp.peft import (\n", + " get_peft_config,\n", + " get_peft_model,\n", + " get_peft_model_state_dict,\n", + " set_peft_model_state_dict,\n", + " PeftType,\n", + " PromptTuningConfig,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "e32c4a9e", + "metadata": {}, + "outputs": [], + "source": [ + "batch_size = 32\n", + "model_name_or_path = \"roberta-large\"\n", + "task = \"mrpc\"\n", + "peft_type = PeftType.PROMPT_TUNING\n", + "num_epochs = 20" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "622fe9c8", + "metadata": {}, + "outputs": [], + "source": [ + "peft_config = PromptTuningConfig(task_type=\"SEQ_CLS\", num_virtual_tokens=10)\n", + "lr = 1e-3" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "74e9efe0", + "metadata": {}, + "outputs": [], + "source": [ + "if any(k in model_name_or_path for k in (\"gpt\", \"opt\", \"bloom\")):\n", + " padding_side = \"left\"\n", + "else:\n", + " padding_side = \"right\"\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side)\n", + "if getattr(tokenizer, \"pad_token_id\") is 
None:\n", + " tokenizer.pad_token_id = tokenizer.eos_token_id" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "41a63e71-e7c4-4e5d-9e22-6953d981d4b8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'sentence1': Tensor(shape=[], dtype=String, value= 'Amrozi accused his brother , whom he called \" the witness \" , of deliberately distorting his evidence .'), 'sentence2': Tensor(shape=[], dtype=String, value= 'Referring to him as only \" the witness \" , Amrozi accused his brother of deliberately distorting his evidence .'), 'label': Tensor(shape=[], dtype=Int64, value= 1), 'idx': Tensor(shape=[], dtype=Int64, value= 0)}\n" + ] + } + ], + "source": [ + "datasets = load_dataset(\"glue\", task)\n", + "print(next(datasets['train'].create_dict_iterator()))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "bd2d7cd5-62b8-4b7a-ac69-338e6319152e", + "metadata": {}, + "outputs": [], + "source": [ + "from mindnlp.dataset import BaseMapFuction\n", + "\n", + "class MapFunc(BaseMapFuction):\n", + " def __call__(self, sentence1, sentence2, label, idx):\n", + " outputs = tokenizer(sentence1, sentence2, truncation=True, max_length=None)\n", + " return outputs['input_ids'], outputs['attention_mask'], label\n", + "\n", + "\n", + "def get_dataset(dataset, tokenizer):\n", + " input_colums=['sentence1', 'sentence2', 'label', 'idx']\n", + " output_columns=['input_ids', 'attention_mask', 'labels']\n", + " dataset = dataset.map(MapFunc(input_colums, output_columns),\n", + " input_colums, output_columns)\n", + " dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),\n", + " 'attention_mask': (None, 0)})\n", + " return dataset\n", + "\n", + "train_dataset = get_dataset(datasets['train'], tokenizer)\n", + "eval_dataset = get_dataset(datasets['validation'], tokenizer)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "3b1fd5fc-2285-409e-a4e5-cc3c9759d77a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'input_ids': Tensor(shape=[32, 70], dtype=Int64, value=\n", + "[[ 0, 10127, 1001 ... 1, 1, 1],\n", + " [ 0, 975, 26802 ... 1, 1, 1],\n", + " [ 0, 1213, 56 ... 1, 1, 1],\n", + " ...\n", + " [ 0, 133, 1154 ... 1, 1, 1],\n", + " [ 0, 12667, 8423 ... 1, 1, 1],\n", + " [ 0, 32478, 1033 ... 1, 1, 1]]), 'attention_mask': Tensor(shape=[32, 70], dtype=Int64, value=\n", + "[[1, 1, 1 ... 0, 0, 0],\n", + " [1, 1, 1 ... 0, 0, 0],\n", + " [1, 1, 1 ... 0, 0, 0],\n", + " ...\n", + " [1, 1, 1 ... 0, 0, 0],\n", + " [1, 1, 1 ... 0, 0, 0],\n", + " [1, 1, 1 ... 
0, 0, 0]]), 'labels': Tensor(shape=[32], dtype=Int64, value= [1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, \n", + " 1, 1, 0, 0, 1, 1, 1, 0])}\n" + ] + } + ], + "source": [ + "print(next(train_dataset.create_dict_iterator()))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "efb606a2-1fb5-415c-bf12-7e6fd324fe0a", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "metric = evaluate.load(\"glue\", task)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "a3c15af0", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The following parameters in checkpoint files are not loaded:\n", + "['lm_head.bias', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight', 'roberta.embeddings.position_ids']\n", + "The following parameters in models are missing parameter:\n", + "['classifier.dense.weight', 'classifier.dense.bias', 'classifier.out_proj.weight', 'classifier.out_proj.bias']\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "trainable params: 2,113,540 || all params: 356,423,684 || trainable%: 0.5929852854559463\n" + ] + } + ], + "source": [ + "model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True)\n", + "model = get_peft_model(model, peft_config)\n", + "model.print_trainable_parameters()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "6d3c5edb", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = AdamW(params=model.trainable_params(), lr=lr)\n", + "\n", + "# Instantiate scheduler\n", + "lr_scheduler = get_linear_schedule_with_warmup(\n", + " optimizer=optimizer,\n", + " num_warmup_steps=0.06 * (len(train_dataset) * num_epochs),\n", + " num_training_steps=(len(train_dataset) * num_epochs),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "dbd66774-4482-448d-a1ee-f09f33cb8579", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Parameter (Tensor(shape=[1024, 1024], dtype=Float32, value=[...], name=base_model.classifier.original_module.dense.weight), requires_grad=True),\n", + " Parameter (Tensor(shape=[1024], dtype=Float32, value=[...], name=base_model.classifier.original_module.dense.bias), requires_grad=True),\n", + " Parameter (Tensor(shape=[2, 1024], dtype=Float32, value=[...], name=base_model.classifier.original_module.out_proj.weight), requires_grad=True),\n", + " Parameter (Tensor(shape=[2], dtype=Float32, value=[0. 
0.], name=base_model.classifier.original_module.out_proj.bias), requires_grad=True),\n", + " Parameter (Tensor(shape=[1024, 1024], dtype=Float32, value=[...], name=base_model.classifier.modules_to_save.default.dense.weight), requires_grad=True),\n", + " Parameter (Tensor(shape=[1024], dtype=Float32, value=[...], name=base_model.classifier.modules_to_save.default.dense.bias), requires_grad=True),\n", + " Parameter (Tensor(shape=[2, 1024], dtype=Float32, value=[...], name=base_model.classifier.modules_to_save.default.out_proj.weight), requires_grad=True),\n", + " Parameter (Tensor(shape=[2], dtype=Float32, value=[ 0.00000000e+00 0.00000000e+00], name=base_model.classifier.modules_to_save.default.out_proj.bias), requires_grad=True),\n", + " Parameter (Tensor(shape=[10, 1024], dtype=Float32, value=[...], name=prompt_encoder.default.embedding.weight), requires_grad=True)]" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.trainable_params()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "4d279225", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████| 115/115 [01:02<00:00, 1.83it/s]\n", + "100%|███████████████████████████████████████████████████████████████████████████████| 13/13 [00:02<00:00, 5.37it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 0: {'accuracy': 0.6446078431372549, 'f1': 0.7377938517179023}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████| 115/115 [01:02<00:00, 1.84it/s]\n", + "100%|███████████████████████████████████████████████████████████████████████████████| 13/13 [00:02<00:00, 6.33it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 1: {'accuracy': 0.7254901960784313, 'f1': 0.8181818181818181}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████| 115/115 [01:03<00:00, 1.82it/s]\n", + "100%|███████████████████████████████████████████████████████████████████████████████| 13/13 [00:02<00:00, 6.17it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 2: {'accuracy': 0.7426470588235294, 'f1': 0.8335974643423137}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████| 115/115 [01:03<00:00, 1.81it/s]\n", + "100%|███████████████████████████████████████████████████████████████████████████████| 13/13 [00:02<00:00, 6.22it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 3: {'accuracy': 0.7009803921568627, 'f1': 0.8157099697885196}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████| 115/115 [01:02<00:00, 1.85it/s]\n", + "100%|███████████████████████████████████████████████████████████████████████████████| 13/13 [00:02<00:00, 6.18it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 4: {'accuracy': 0.7058823529411765, 'f1': 0.8142414860681114}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"100%|█████████████████████████████████████████████████████████████████████████████| 115/115 [01:03<00:00, 1.80it/s]\n", + "100%|███████████████████████████████████████████████████████████████████████████████| 13/13 [00:02<00:00, 6.19it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 5: {'accuracy': 0.7058823529411765, 'f1': 0.8159509202453988}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████| 115/115 [01:04<00:00, 1.79it/s]\n", + "100%|███████████████████████████████████████████████████████████████████████████████| 13/13 [00:02<00:00, 6.18it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 6: {'accuracy': 0.696078431372549, 'f1': 0.8037974683544306}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████| 115/115 [01:03<00:00, 1.80it/s]\n", + "100%|███████████████████████████████████████████████████████████████████████████████| 13/13 [00:02<00:00, 6.18it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 7: {'accuracy': 0.6838235294117647, 'f1': 0.7895595432300163}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████| 115/115 [01:03<00:00, 1.80it/s]\n", + "100%|███████████████████████████████████████████████████████████████████████████████| 13/13 [00:02<00:00, 6.20it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 8: {'accuracy': 0.7303921568627451, 'f1': 0.8338368580060422}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████| 115/115 [01:03<00:00, 1.80it/s]\n", + "100%|███████████████████████████████████████████████████████████████████████████████| 13/13 [00:02<00:00, 6.18it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 9: {'accuracy': 0.7181372549019608, 'f1': 0.8238897396630934}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████| 115/115 [01:03<00:00, 1.80it/s]\n", + "100%|███████████████████████████████████████████████████████████████████████████████| 13/13 [00:02<00:00, 6.18it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 10: {'accuracy': 0.7107843137254902, 'f1': 0.7993197278911565}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████| 115/115 [01:03<00:00, 1.81it/s]\n", + "100%|███████████████████████████████████████████████████████████████████████████████| 13/13 [00:02<00:00, 6.12it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 11: {'accuracy': 0.7083333333333334, 'f1': 0.8221225710014948}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████| 115/115 [01:03<00:00, 1.82it/s]\n", + "100%|███████████████████████████████████████████████████████████████████████████████| 13/13 [00:02<00:00, 6.18it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": 
[ + "epoch 12: {'accuracy': 0.7083333333333334, 'f1': 0.8205128205128206}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████| 115/115 [01:03<00:00, 1.81it/s]\n", + "100%|███████████████████████████████████████████████████████████████████████████████| 13/13 [00:02<00:00, 6.21it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 13: {'accuracy': 0.7205882352941176, 'f1': 0.8277945619335348}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████| 115/115 [01:03<00:00, 1.81it/s]\n", + "100%|███████████████████████████████████████████████████████████████████████████████| 13/13 [00:02<00:00, 6.18it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 14: {'accuracy': 0.7156862745098039, 'f1': 0.8263473053892216}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████| 115/115 [01:02<00:00, 1.85it/s]\n", + "100%|███████████████████████████████████████████████████████████████████████████████| 13/13 [00:02<00:00, 6.12it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 15: {'accuracy': 0.7254901960784313, 'f1': 0.8297872340425533}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████| 115/115 [01:02<00:00, 1.84it/s]\n", + "100%|███████████████████████████████████████████████████████████████████████████████| 13/13 [00:02<00:00, 6.19it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 16: {'accuracy': 0.7279411764705882, 'f1': 0.8294930875576038}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████| 115/115 [01:04<00:00, 1.79it/s]\n", + "100%|███████████████████████████████████████████████████████████████████████████████| 13/13 [00:02<00:00, 6.21it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 17: {'accuracy': 0.7254901960784313, 'f1': 0.8308157099697885}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████| 115/115 [01:04<00:00, 1.80it/s]\n", + "100%|███████████████████████████████████████████████████████████████████████████████| 13/13 [00:02<00:00, 6.17it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 18: {'accuracy': 0.7156862745098039, 'f1': 0.8237082066869301}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████| 115/115 [01:03<00:00, 1.80it/s]\n", + "100%|███████████████████████████████████████████████████████████████████████████████| 13/13 [00:02<00:00, 6.20it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 19: {'accuracy': 0.7156862745098039, 'f1': 0.8237082066869301}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "def forward_fn(**batch):\n", + " outputs = model(**batch)\n", + " loss = outputs.loss\n", + " return loss\n", + "\n", + 
"grad_fn = mindspore.value_and_grad(forward_fn, None, model.trainable_params())\n", + "\n", + "def train_step(**batch):\n", + " loss, grads = grad_fn(**batch)\n", + " optimizer(grads)\n", + " return loss\n", + "\n", + "for epoch in range(num_epochs):\n", + " model.set_train()\n", + " train_total_size = train_dataset.get_dataset_size()\n", + " for step, batch in enumerate(tqdm(train_dataset.create_dict_iterator(), total=train_total_size)):\n", + " loss = train_step(**batch)\n", + " lr_scheduler.step()\n", + "\n", + " model.set_train(False)\n", + " eval_total_size = eval_dataset.get_dataset_size()\n", + " for step, batch in enumerate(tqdm(eval_dataset.create_dict_iterator(), total=eval_total_size)):\n", + " outputs = model(**batch)\n", + " predictions = outputs.logits.argmax(axis=-1)\n", + " predictions, references = predictions, batch[\"labels\"]\n", + " metric.add_batch(\n", + " predictions=predictions,\n", + " references=references,\n", + " )\n", + "\n", + " eval_metric = metric.compute()\n", + " print(f\"epoch {epoch}:\", eval_metric)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + }, + "vscode": { + "interpreter": { + "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/mindnlp/__init__.py b/mindnlp/__init__.py index 193952dc8..3ed15460d 100644 --- a/mindnlp/__init__.py +++ b/mindnlp/__init__.py @@ -24,8 +24,7 @@ from mindspore import jit as ms_jit from mindnlp import injection from mindnlp import transformers -from mindnlp.dataset import load_dataset -from mindnlp.workflow.workflow import Workflow -from mindnlp.vocab import Vocab +from mindnlp import dataset +from mindnlp import evaluate -__all__ = ['ms_jit', 'load_dataset', 'Workflow', 'Vocab'] +__all__ = ['ms_jit'] diff --git a/mindnlp/evaluate.py b/mindnlp/evaluate.py new file mode 100644 index 000000000..57a917780 --- /dev/null +++ b/mindnlp/evaluate.py @@ -0,0 +1,54 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +evaluate module. 
+""" +from typing import Optional, Union +from datasets import DownloadConfig, DownloadMode +from datasets.utils.version import Version +from evaluate import config +from evaluate import load as eval_load +from evaluate.module import EvaluationModule + +config.HUB_EVALUATE_URL = "https://openi.pcl.ac.cn/{path}/raw/branch/{revision}/{name}" + +def load( + path: str, + config_name: Optional[str] = None, + module_type: Optional[str] = None, + process_id: int = 0, + num_process: int = 1, + cache_dir: Optional[str] = None, + experiment_id: Optional[str] = None, + keep_in_memory: bool = False, + download_config: Optional[DownloadConfig] = None, + download_mode: Optional[DownloadMode] = None, + revision: Optional[Union[str, Version]] = None, + **init_kwargs, +) -> EvaluationModule: + return eval_load( + path, + config_name, + module_type, + process_id, + num_process, + cache_dir, + experiment_id, + keep_in_memory, + download_config, + download_mode, + revision, + **init_kwargs, + ) diff --git a/mindnlp/injection.py b/mindnlp/injection.py index a3fe9c00c..40c34f81b 100644 --- a/mindnlp/injection.py +++ b/mindnlp/injection.py @@ -17,7 +17,7 @@ Injection mindspore.nn for MindNLP """ import operator -from typing import OrderedDict +from typing import OrderedDict, List from functools import reduce, partial import math import types @@ -25,18 +25,19 @@ import mindspore.experimental import mindspore.experimental.optim from packaging import version - import numpy as np import mindspore import mindspore.common.dtype as mstype from mindspore._c_expression import Tensor as Tensor_ -from mindspore import nn, ops, Tensor, Parameter +from mindspore import nn, ops, Tensor, Parameter, ParameterTuple from mindspore.common._stub_tensor import StubTensor from mindspore.nn.layer.conv import _Conv, _deconv_output_length from mindspore.common.initializer import initializer, Normal, HeUniform, Uniform, _calculate_fan_in_and_fan_out from mindspore import _checkparam as Validator from mindspore.ops import functional as F from mindspore.ops._primitive_cache import _get_cache_prim +from mindspore.common.parameter import PARAMETER_NAME_DEFAULT + from mindnlp._legacy.functional import einsum from .utils.logging import get_logger from .amp import OP_WHITE_LIST, OP_BLACK_LIST, CELL_WHITE_LIST, get_global_amp @@ -1201,3 +1202,87 @@ def _cell_call(self, *args, **kwargs): return old_cell_call(self, *args, **kwargs) nn.Cell.__call__ = old_cell_call + +def get_cell(self, target): + if target == "": + return self + + atoms: List[str] = target.split(".") + mod: nn.Cell = self + + for item in atoms: + + if not hasattr(mod, item): + raise AttributeError(mod._get_name() + " has no " + "attribute `" + item + "`") + + mod = getattr(mod, item) + + if not isinstance(mod, nn.Cell): + raise AttributeError("`" + item + "` is not " + "an nn.Cell") + + return mod + +nn.Cell.get_cell = get_cell + +def _set_attr_for_parameter_in_list_or_tuple(self, name, value): + """Set attr for parameter in list or tuple.""" + for item in value: + if item in self.exist_objs: + # If there are multiple identical objects, their names only check once. 
+ continue + self.exist_objs.add(item) + if item.name == PARAMETER_NAME_DEFAULT: + item.name = item.name + "$" + str(self._id) + self._id += 1 + object.__setattr__(self, name, value) + +nn.Cell._set_attr_for_parameter_in_list_or_tuple = _set_attr_for_parameter_in_list_or_tuple + +def _set_attr_for_parameter_tuple(self, name, value): + """Set attr for parameter in ParameterTuple.""" + params = self.__dict__.get('_params') + params_list = self.__dict__.get('_params_list') + if params is None: + raise AttributeError("For 'Cell', can not assign params before Cell.__init__() is called.") + exist_objs = set() + for item in value: + if item in exist_objs: + # If there are multiple identical objects, their names only check once. + continue + exist_objs.add(item) + if item.name == PARAMETER_NAME_DEFAULT: + logger.warning(f"For 'Cell', the parameter definition is deprecated.\n" + f"Please set a unique name for the parameter in ParameterTuple '{value}'.") + item.name = item.name + "$" + str(self._id) + self._id += 1 + self.insert_param_to_cell(item.name, item, check_name_contain_dot=False) + + if name in self.__dict__: + del self.__dict__[name] + if name in params: + del params[name] + params_list[name] = value + +nn.Cell._set_attr_for_parameter_tuple = _set_attr_for_parameter_tuple + +def _check_names(self): + pass + +nn.Cell.check_names = _check_names + +def __new__(cls, iterable): + """Create instance object of ParameterTuple.""" + data = tuple(iterable) + ids = set() + for x in data: + if not isinstance(x, Parameter): + raise TypeError(f"For ParameterTuple initialization, " + f"ParameterTuple input should be 'Parameter' collection, " + f"but got a {type(iterable)}. ") + if id(x) not in ids: + ids.add(id(x)) + return tuple.__new__(ParameterTuple, tuple(data)) + +ParameterTuple.__new__ = __new__ diff --git a/mindnlp/peft/__init__.py b/mindnlp/peft/__init__.py index 579bf5a02..a18ba786c 100644 --- a/mindnlp/peft/__init__.py +++ b/mindnlp/peft/__init__.py @@ -41,6 +41,7 @@ LoKrModel, AdaLoraConfig, AdaLoraModel, + PromptTuningConfig ) from .utils import ( diff --git a/mindnlp/peft/mapping.py b/mindnlp/peft/mapping.py index a20db4002..7ff7d1fd6 100644 --- a/mindnlp/peft/mapping.py +++ b/mindnlp/peft/mapping.py @@ -37,6 +37,7 @@ IA3Model, LoKrConfig, LoKrModel, + PromptTuningConfig, ) MODEL_TYPE_TO_PEFT_MODEL_MAPPING = { @@ -51,7 +52,7 @@ PEFT_TYPE_TO_CONFIG_MAPPING = { # "ADAPTION_PROMPT": AdaptionPromptConfig, - # "PROMPT_TUNING": PromptTuningConfig, + "PROMPT_TUNING": PromptTuningConfig, # "PREFIX_TUNING": PrefixTuningConfig, # "P_TUNING": PromptEncoderConfig, "ADAPTION_PROMPT": AdaptionPromptConfig, diff --git a/mindnlp/peft/peft_model.py b/mindnlp/peft/peft_model.py index b16e46fb5..36c4fb646 100644 --- a/mindnlp/peft/peft_model.py +++ b/mindnlp/peft/peft_model.py @@ -18,29 +18,31 @@ import inspect from contextlib import contextmanager from copy import deepcopy -from typing import Dict +from typing import Dict, Optional +import mindspore from mindspore import nn, ops from mindspore.train.serialization import _exec_save -from mindspore.nn import CrossEntropyLoss from .config import PeftConfig, PromptLearningConfig +from .._legacy.abc import CellDict +from ..transformers import PreTrainedModel from .tuners import ( AdaLoraModel, AdaptionPromptModel, LoraModel, IA3Model, - LoKrModel - # LoraConfig + LoKrModel, + # LoraConfig, + PromptEmbedding ) from .utils import ( # SAFETENSORS_WEIGHTS_NAME, - # TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, + 
TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, WEIGHTS_NAME, PeftType, - # TaskType, - # _get_batch_size, + TaskType, _prepare_prompt_learning_config, # _set_adapter, _set_trainable, @@ -50,7 +52,7 @@ load_peft_weights, set_peft_model_state_dict, shift_tokens_right, - # _get_batch_size, # will be used for prompt learning methods + _get_batch_size, # will be used for prompt learning methods ) @@ -78,7 +80,7 @@ def __init__(self, model, peft_config: PeftConfig, adapter_name="default"): self.peft_config: Dict[str, PeftConfig] = {} self.active_adapter = adapter_name self.peft_type = peft_config.peft_type - # self.base_model_torch_dtype = getattr(model, "dtype", None) + self.base_model_dtype = getattr(model, "dtype", None) if not peft_config.is_prompt_learning: self.peft_config[adapter_name] = peft_config self.base_model = PEFT_TYPE_TO_MODEL_MAPPING[peft_config.peft_type]( @@ -160,6 +162,47 @@ def from_pretrained(cls, model, model_id, adapter_name="default", is_trainable=F model.load_adapter(model_id, adapter_name, **kwargs) return model + def _setup_prompt_encoder(self, adapter_name: str): + config = self.peft_config[adapter_name] + if not hasattr(self, "prompt_encoder"): + self.prompt_encoder = CellDict({}) + self.prompt_tokens = {} + transformer_backbone = None + for name, module in self.base_model.cells_and_names(): + for param in module.get_parameters(): + param.requires_grad = False + if isinstance(module, PreTrainedModel): + # Make sure to freeze Tranformers model + if transformer_backbone is None: + transformer_backbone = module + self.transformer_backbone_name = name + if transformer_backbone is None: + transformer_backbone = self.base_model + + if config.num_transformer_submodules is None: + config.num_transformer_submodules = 2 if config.task_type == TaskType.SEQ_2_SEQ_LM else 1 + + for named_param, value in list(transformer_backbone.parameters_and_names()): + + if value.shape[0] == self.base_model.config.vocab_size: + self.word_embeddings = transformer_backbone.get_cell(named_param.replace(".weight", "")) + break + + if config.peft_type == PeftType.PROMPT_TUNING: + prompt_encoder = PromptEmbedding(config, self.word_embeddings) + # elif config.peft_type == PeftType.MULTITASK_PROMPT_TUNING: + # prompt_encoder = MultitaskPromptEmbedding(config, self.word_embeddings) + # elif config.peft_type == PeftType.P_TUNING: + # prompt_encoder = PromptEncoder(config) + # elif config.peft_type == PeftType.PREFIX_TUNING: + # prompt_encoder = PrefixEncoder(config) + else: + raise ValueError("Not supported") + + self.prompt_encoder.update(CellDict({adapter_name: prompt_encoder})) + self.prompt_tokens[adapter_name] = ops.arange( + config.num_virtual_tokens * config.num_transformer_submodules + ).long() def load_adapter(self, model_id: str, adapter_name: str, is_trainable: bool = False, **kwargs): """load adapter to peft model, called by `model.from_pretrained`.""" @@ -204,6 +247,50 @@ def get_nb_trainable_parameters(self): return trainable_params, all_param + def get_prompt(self, batch_size: int, task_ids: Optional[mindspore.Tensor] = None) -> mindspore.Tensor: + """ + Returns the virtual prompts to use for Peft. Only applicable when using a prompt learning method. 
+ """ + peft_config = self.active_peft_config + prompt_encoder = self.prompt_encoder[self.active_adapter] + prompt_tokens = ( + self.prompt_tokens[self.active_adapter] + .unsqueeze(0) + .expand(batch_size, -1) + ) + if peft_config.peft_type == PeftType.PREFIX_TUNING: + prompt_tokens = prompt_tokens[:, : peft_config.num_virtual_tokens] + if peft_config.inference_mode: + past_key_values = prompt_encoder.embedding.weight.repeat(batch_size, 1, 1) + else: + past_key_values = prompt_encoder(prompt_tokens) + if self.base_model_dtype is not None: + past_key_values = past_key_values.to(self.base_model_dtype) + past_key_values = past_key_values.view( + batch_size, + peft_config.num_virtual_tokens, + peft_config.num_layers * 2, + peft_config.num_attention_heads, + peft_config.token_dim // peft_config.num_attention_heads, + ) + if peft_config.num_transformer_submodules == 2: + past_key_values = ops.cat([past_key_values, past_key_values], axis=2) + past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split( + peft_config.num_transformer_submodules * 2 + ) + if TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING.get(self.config.model_type, None) is not None: + post_process_fn = TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING[self.config.model_type] + past_key_values = post_process_fn(past_key_values) + return past_key_values + else: + if peft_config.peft_type == PeftType.MULTITASK_PROMPT_TUNING: + prompts = prompt_encoder(prompt_tokens, task_ids) + else: + if peft_config.inference_mode: + prompts = prompt_encoder.embedding.weight.repeat(batch_size, 1, 1) + else: + prompts = prompt_encoder(prompt_tokens) + return prompts def print_trainable_parameters(self): """ @@ -345,41 +432,41 @@ def construct( **kwargs, ) - raise NotImplementedError - # batch_size = _get_batch_size(input_ids, inputs_embeds) - # if attention_mask is not None: - # # concat prompt attention mask - # prefix_attention_mask = ops.ones(batch_size, peft_config.num_virtual_tokens) - # attention_mask = ops.cat((prefix_attention_mask, attention_mask), axis=1) - # if kwargs.get("position_ids", None) is not None: - # warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.") - # kwargs["position_ids"] = None - # kwargs.update( - # { - # "attention_mask": attention_mask, - # "labels": labels, - # "output_attentions": output_attentions, - # "output_hidden_states": output_hidden_states, - # "return_dict": return_dict, - # } - # ) + batch_size = _get_batch_size(input_ids, inputs_embeds) + if attention_mask is not None: + # concat prompt attention mask + prefix_attention_mask = ops.ones(batch_size, peft_config.num_virtual_tokens, dtype=attention_mask.dtype) + attention_mask = ops.cat((prefix_attention_mask, attention_mask), axis=1) + if kwargs.get("position_ids", None) is not None: + warnings.warn("Position ids are not supported for parameter efficient tuning. 
Ignoring position ids.") + kwargs["position_ids"] = None + kwargs.update( + { + "attention_mask": attention_mask, + "labels": labels, + "output_attentions": output_attentions, + "output_hidden_states": output_hidden_states, + "return_dict": return_dict, + } + ) # if peft_config.peft_type == PeftType.PREFIX_TUNING: # return self._prefix_tuning_forward(input_ids=input_ids, **kwargs) - # if kwargs.get("token_type_ids", None) is not None: - # kwargs["token_type_ids"] = ops.cat( - # ( - # ops.zeros(batch_size, peft_config.num_virtual_tokens), - # kwargs["token_type_ids"], - # ), - # axis=1, - # ).long() - # if inputs_embeds is None: - # inputs_embeds = self.word_embeddings(input_ids) - # prompts = self.get_prompt(batch_size=batch_size) - # prompts = prompts.to(inputs_embeds.dtype) - # inputs_embeds = ops.cat((prompts, inputs_embeds), axis=1) - # return self.base_model(inputs_embeds=inputs_embeds, **kwargs) + if kwargs.get("token_type_ids", None) is not None: + kwargs["token_type_ids"] = ops.cat( + ( + ops.zeros(batch_size, peft_config.num_virtual_tokens, dtype=kwargs["token_type_ids"].dtype), + kwargs["token_type_ids"], + ), + axis=1, + ) + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + prompts = self.get_prompt(batch_size=batch_size) + prompts = prompts.to(inputs_embeds.dtype) + inputs_embeds = ops.cat((prompts, inputs_embeds), axis=1) + return self.base_model(inputs_embeds=inputs_embeds, **kwargs) + class PeftModelForCausalLM(PeftModel): """ @@ -849,8 +936,7 @@ def _prefix_tuning_forward( loss = None if labels is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + loss = ops.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1)) output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output diff --git a/mindnlp/peft/tuners/__init__.py b/mindnlp/peft/tuners/__init__.py index 43e282538..a620578d2 100644 --- a/mindnlp/peft/tuners/__init__.py +++ b/mindnlp/peft/tuners/__init__.py @@ -18,4 +18,5 @@ from .ia3 import IA3Config, IA3Model from .adaption_prompt import AdaptionPromptConfig, AdaptionPromptModel from .adalora import AdaLoraConfig, AdaLoraModel -from .lokr import LoKrConfig,LoKrModel +from .lokr import LoKrConfig, LoKrModel +from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit diff --git a/mindnlp/peft/tuners/adalora/config.py b/mindnlp/peft/tuners/adalora/config.py index eb77f769b..64857af30 100644 --- a/mindnlp/peft/tuners/adalora/config.py +++ b/mindnlp/peft/tuners/adalora/config.py @@ -19,8 +19,8 @@ from dataclasses import dataclass, field from typing import Optional -from mindnlp.peft.tuners import LoraConfig -from mindnlp.peft.utils import PeftType +from ..lora import LoraConfig +from ...utils import PeftType @dataclass diff --git a/mindnlp/peft/tuners/lokr/layer.py b/mindnlp/peft/tuners/lokr/layer.py index c284ab9e8..3dab4487b 100644 --- a/mindnlp/peft/tuners/lokr/layer.py +++ b/mindnlp/peft/tuners/lokr/layer.py @@ -23,7 +23,7 @@ from mindspore.common.initializer import initializer, HeUniform, Zero # import mindnlp._legacy.functional as F -from mindnlp.abc import ParameterDict +from mindnlp._legacy.abc import ParameterDict # from ..import_utils import is_bnb_4bit_available, is_bnb_available diff --git a/mindnlp/peft/tuners/lora.py b/mindnlp/peft/tuners/lora.py deleted file mode 100644 index e07a021a9..000000000 --- a/mindnlp/peft/tuners/lora.py +++ /dev/null @@ -1,855 +0,0 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# 
Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Lora.""" -import math -import re -import warnings -from dataclasses import asdict, dataclass, field -from enum import Enum -from typing import List, Optional, Union - -import mindspore -from mindspore import nn, ops -from mindspore.common.initializer import initializer, HeUniform, Zero, Normal - -# import mindnlp._legacy.functional as F -from mindnlp.transformers.ms_utils import Conv1D -from mindnlp._legacy.abc import CellDict - -from ..config import PeftConfig -# from ..import_utils import is_bnb_4bit_available, is_bnb_available -from ..utils import ( - # CLAMP_QUANTILE, - COMMON_LAYERS_PATTERN, - TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, - ModulesToSaveWrapper, - PeftType, - # _freeze_adapter, - _get_submodules, - transpose, -) - -from .tuners_utils import BaseTuner - -# if is_bnb_available(): -# import bitsandbytes as bnb - - -@dataclass -class LoraConfig(PeftConfig): - """ - This is the configuration class to store the configuration of a [`LoraModel`]. - - Args: - r (`int`): Lora attention dimension. - target_modules (`Union[List[str],str]`): The names of the modules to apply Lora to. - lora_alpha (`float`): The alpha parameter for Lora scaling. - lora_dropout (`float`): The dropout probability for Lora layers. - fan_in_fan_out (`bool`): Set this to True if the layer to replace stores weight like (fan_in, fan_out). - For example, gpt-2 uses `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`.: - bias (`str`): Bias type for Lora. Can be 'none', 'all' or 'lora_only' - modules_to_save (`List[str]`):List of modules apart from LoRA layers to be set as trainable - and saved in the final checkpoint. - layers_to_transform (`Union[List[int],int]`): - The layer indexes to transform, if this argument is specified, it will apply the LoRA transformations on - the layer indexes that are specified in this list. If a single integer is passed, it will apply the LoRA - transformations on the layer at this index. - layers_pattern (`str`): - The layer pattern name, used only if `layers_to_transform` is different from `None` and if the layer - pattern is not in the common layers pattern. - """ - - r: int = field(default=8, metadata={"help": "Lora attention dimension"}) - target_modules: Optional[Union[List[str], str]] = field( - default=None, - metadata={ - "help": "List of module names or regex expression of the module names to replace with Lora." - "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' " - }, - ) - lora_alpha: int = field(default=8, metadata={"help": "Lora alpha"}) - lora_dropout: float = field(default=0.0, metadata={"help": "Lora dropout"}) - fan_in_fan_out: bool = field( - default=False, - metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"}, - ) - bias: str = field(default="none", metadata={"help": "Bias type for Lora. 
Can be 'none', 'all' or 'lora_only'"}) - modules_to_save: Optional[List[str]] = field( - default=None, - metadata={ - "help": "List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. " - "For example, in Sequence Classification or Token Classification tasks, " - "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved." - }, - ) - init_lora_weights: bool = field( - default=True, - metadata={"help": "Whether to initialize the weights of the Lora layers."}, - ) - layers_to_transform: Optional[Union[List, int]] = field( - default=None, - metadata={ - "help": "The layer indexes to transform, is this argument is specified, \ - PEFT will transform only the layers indexes that are specified inside this list. \ - If a single integer is passed, PEFT will transform only the layer at this index." - }, - ) - layers_pattern: Optional[str] = field( - default=None, - metadata={ - "help": "The layer pattern name, used only if `layers_to_transform` is different to None and \ - if the layer pattern is not in the common layers pattern." - }, - ) - - def __post_init__(self): - self.peft_type = PeftType.LORA - - @property - def is_prompt_learning(self): - r""" - Utility method to check if the configuration is for prompt learning. - """ - return False - -class LoraModel(BaseTuner): - """ - Creates Low Rank Adapter (Lora) model from a pretrained transformers model. - - Args: - model ([`~mindnlp.PreTrainedModel`]): The model to be adapted. - config ([`LoraConfig`]): The configuration of the Lora model. - - Returns: - `mindspore.nn.Cell`: The Lora model. - - Example: - - ```py - >>> from mindnlp.transformers import GPTForSequenceClassification - >>> from mindnlp.modules import LoraModel, LoraConfig - - >>> config = LoraConfig( - ... peft_type="LORA", - ... task_type="SEQ_2_SEQ_LM", - ... r=8, - ... lora_alpha=32, - ... target_modules=["q", "v"], - ... lora_dropout=0.01, - ... ) - - >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") - >>> lora_model = LoraModel(config, model) - ``` - """ - - def __init__(self, model: nn.Cell, config, adapter_name): - # call BaseTuner.__init__ - # setup config and inject lora adapter - super().__init__(model, config, adapter_name) - - @staticmethod - def _prepare_adapter_config(peft_config, model_config): - if peft_config.target_modules is None: - # If target_modules is not specified, use the default target_modules for the model type - if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING: - raise ValueError("Please specify `target_modules` in `peft_config`") - peft_config.target_modules = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]] - return peft_config - def _check_new_adapter_config(self, config: LoraConfig): - """ - A helper method to check the config when a new adapter is being added. - Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters. - """ - # TODO: there should be a check if any of the existing adapters actually has bias != "none", or else the check - # does not fully correspond to the error message. - if (len(self.peft_config) > 1) and (config.bias != "none"): - raise ValueError( - f"{self.__class__.__name__} supports only 1 adapter with bias. When using multiple adapters, " - "set bias to 'none' for all adapters." 
- ) - - @staticmethod - def _check_target_module_exists(lora_config, key): - if isinstance(lora_config.target_modules, str): - target_module_found = re.fullmatch(lora_config.target_modules, key) - else: - target_module_found = any( - re.match(f".*\.{target_key}$", key) for target_key in lora_config.target_modules - ) or any(target_key == key for target_key in lora_config.target_modules) - is_using_layer_indexes = getattr(lora_config, "layers_to_transform", None) is not None - layer_indexing_pattern = getattr(lora_config, "layers_pattern", None) - - if is_using_layer_indexes and target_module_found: - layers_pattern = COMMON_LAYERS_PATTERN if layer_indexing_pattern is None else layer_indexing_pattern - layers_pattern = [layers_pattern] if isinstance(layers_pattern, str) else layers_pattern - - for pattern in layers_pattern: - layer_index = re.match(f".*.{pattern}\.(\d+)\.*", key) - if layer_index is not None: - layer_index = int(layer_index.group(1)) - if isinstance(lora_config.layers_to_transform, int): - target_module_found = layer_index == lora_config.layers_to_transform - else: - target_module_found = layer_index in lora_config.layers_to_transform - - break - else: - target_module_found = False - return target_module_found - - def _create_and_replace( - self, - lora_config, - adapter_name, - target, - target_name, - parent, - **optionnal_kwargs, - ): - bias = hasattr(target, "bias") and target.bias is not None - kwargs = { - "r": lora_config.r, - "lora_alpha": lora_config.lora_alpha, - "lora_dropout": lora_config.lora_dropout, - "fan_in_fan_out": lora_config.fan_in_fan_out, - "init_lora_weights": lora_config.init_lora_weights, - } - - kwargs["loaded_in_8bit"] = optionnal_kwargs.pop("loaded_in_8bit", False) - kwargs["loaded_in_4bit"] = optionnal_kwargs.pop("loaded_in_4bit", False) - kwargs["bias"] = bias - - # TODO: better deal with that - # if isinstance(target, LoraLayer) and isinstance(target, torch.nn.Conv2d): - # target.update_layer_conv2d( - # adapter_name, - # lora_config.r, - # lora_config.lora_alpha, - # lora_config.lora_dropout, - # lora_config.init_lora_weights, - # ) - if isinstance(target, LoraLayer) and isinstance(target, mindspore.nn.Embedding): - target.update_layer_embedding( - adapter_name, - lora_config.r, - lora_config.lora_alpha, - lora_config.lora_dropout, - lora_config.init_lora_weights, - ) - - elif isinstance(target, LoraLayer): - target.update_layer( - adapter_name, - lora_config.r, - lora_config.lora_alpha, - lora_config.lora_dropout, - lora_config.init_lora_weights, - ) - else: - new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs) - self._replace_module(parent, target_name, new_module, target) - - @staticmethod - def _replace_module(parent, child_name, new_module, child): - setattr(parent, child_name, new_module) - if isinstance(parent, nn.SequentialCell): - parent.cell_list = list(parent._cells.values()) - - new_module.weight = child.weight - if hasattr(child, "bias"): - if child.bias is not None: - new_module.bias = child.bias - - if getattr(child, "state", None) is not None: - new_module.state = child.state - # TODO: .to(device) not support in mindspore - # new_module.to(child.weight.device) # error - - # TODO: dispatch to correct device - # for name, module in new_module.parameters_and_names(): - # if "lora_" in name: - # module.to(child.weight.device) # error - # if "ranknum" in name: - # module.to(child.weight.device) # error - - - def __getattr__(self, name: str): - """Forward missing attributes to the wrapped module.""" - 
try: - return super().__getattr__(name) # defer to nn.Module's logic - except AttributeError: - return getattr(self.model, name) - - def get_peft_config_as_dict(self, inference: bool = False): - """get peft config as dict""" - config_dict = {} - for key, value in self.peft_config.items(): - config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()} - if inference: - config["inference_mode"] = True - config_dict[key] = config # pylint: disable=undefined-loop-variable - return config - - def _set_adapter_layers(self, enabled=True): - for module in self.model.cells(): - module.disable_adapters = not isinstance(module, LoraLayer) - - def enable_adapter_layers(self): - """enable_adapter_layers""" - self._set_adapter_layers(enabled=True) - - def disable_adapter_layers(self): - """disable_adapter_layers""" - self._set_adapter_layers(enabled=False) - - def set_adapter(self, adapter_name): - """set_adapter""" - for module in self.model.modules(): - if isinstance(module, LoraLayer): - if module.merged: - warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") - module.unmerge() - module.active_adapter = adapter_name - - def merge_adapter(self): - """merge_adapter""" - for module in self.model.modules(): - if isinstance(module, LoraLayer): - module.merge() - - def unmerge_adapter(self): - """unmerge_adapter""" - for module in self.model.modules(): - if isinstance(module, LoraLayer): - module.unmerge() - - @staticmethod - def _prepare_lora_config(peft_config, model_config): - if peft_config.target_modules is None: - if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING: - raise ValueError("Please specify `target_modules` in `peft_config`") - peft_config.target_modules = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]] - - if peft_config.inference_mode: - peft_config.merge_weights = True - - return peft_config - - def merge_and_unload(self): - r""" - This method merges the LoRa layers into the base model. This is needed if someone wants to use the base model - as a standalone model. 
- """ - if getattr(self.config, "model_type", None) == "gpt2": - raise ValueError("GPT2 models are not supported for merging LORA layers") - - if getattr(self.model, "is_loaded_in_8bit", False): - raise ValueError("Cannot merge LORA layers when the model is loaded in 8-bit mode") - - key_list = [key for key, _ in self.model.named_modules() if "lora" not in key] - for key in key_list: - try: - parent, target, target_name = _get_submodules(self.model, key) - except AttributeError: - continue - if isinstance(target, LoraLayer): - if isinstance(target, nn.Embedding): - new_module = nn.Embedding(target.in_channels, target.out_channels) - elif isinstance(target, nn.Conv2d): - new_module = nn.Conv2d( - target.in_channels, - target.out_channels, - kernel_size=target.kernel_size, - stride=target.stride, - padding=target.padding, - dilation=target.dilation, - ) - elif isinstance(target, nn.Dense): - bias = target.bias is not None - new_module = nn.Dense(target.in_channels, target.out_channels, has_bias=bias) - else: - raise ValueError(f"Not support {type(target)}.") - target.merge() - self._replace_module(parent, target_name, new_module, target) - - # save any additional trainable modules part of `modules_to_save` - if isinstance(target, ModulesToSaveWrapper): - setattr(parent, target_name, target.modules_to_save[target.active_adapter]) - - return self.model - - # def add_weighted_adapter(self, adapters, weights, adapter_name): - # """add_weighted_adapter""" - # if len({self.peft_config[adapter].r for adapter in adapters}) != 1: - # raise ValueError("All adapters must have the same r value") - # self.peft_config[adapter_name] = replace( - # self.peft_config[adapters[0]], lora_alpha=self.peft_config[adapters[0]].r - # ) - # self._find_and_replace(adapter_name) - # mark_only_lora_as_trainable(self.model, self.peft_config[adapter_name].bias) - # _freeze_adapter(self.model, adapter_name) - # key_list = [key for key, _ in self.model.named_modules() if "lora" not in key] - # for key in key_list: - # _, target, _ = _get_submodules(self.model, key) - # if isinstance(target, LoraLayer): - # if adapter_name in target.lora_A: - # target.lora_A[adapter_name].weight.data = target.lora_A[adapter_name].weight.data * 0.0 - # target.lora_B[adapter_name].weight.data = target.lora_B[adapter_name].weight.data * 0.0 - # for adapter, weight in zip(adapters, weights): - # if adapter not in target.lora_A: - # continue - # target.lora_A[adapter_name].weight.data += ( - # target.lora_A[adapter].weight.data * weight * target.scaling[adapter] - # ) - # target.lora_B[adapter_name].weight.data += target.lora_B[adapter].weight.data * weight - - # elif adapter_name in target.lora_embedding_A: - # target.lora_embedding_A[adapter_name].data = target.lora_embedding_A[adapter_name].data * 0.0 - # target.lora_embedding_B[adapter_name].data = target.lora_embedding_B[adapter_name].data * 0.0 - # for adapter, weight in zip(adapters, weights): - # if adapter not in target.lora_embedding_A: - # continue - # target.lora_embedding_A[adapter_name].data += ( - # target.lora_embedding_A[adapter].data * weight * target.scaling[adapter] - # ) - # target.lora_embedding_B[adapter_name].data += target.lora_embedding_B[adapter].data * weight - - def _get_active_adapter(self) -> str: - active_adapter = None - for _, module in self.model.cells_and_names(): - if isinstance(module, LoraLayer): - active_adapter = module.active_adapter - - if active_adapter is None: - raise ValueError( - "Something went wrong, no active adapter could be found, please report 
the issue on GitHub" - ) - return active_adapter - - def _mark_only_adapters_as_trainable(self,model) -> None: - """mark_only_lora_as_trainable""" - # get bias - active_adapter = self._get_active_adapter() - bias = self.peft_config[active_adapter].bias - - for n, p in model.parameters_and_names(): # named_parameters() -> parameters_and_names() - if "lora_" not in n: - p.requires_grad = False - # print(n, p, "requires_grad = False") - if bias == "none": - return - elif bias == "all": - for n, p in model.parameters_and_names(): - if "bias" in n: - p.requires_grad = True - elif bias == "lora_only": - for m in model.cells(): # .cells() for modules() - if isinstance(m, LoraLayer) and hasattr(m, "bias") and m.bias is not None: - m.bias.requires_grad = True - else: - raise NotImplementedError - - - @staticmethod - def _create_new_module( - lora_config: PeftConfig, - adapter_name: str, - target: mindspore.nn.Cell, - **kwargs - ): - """""" - # TODO: support loaded_in_8bit & loaded_in_4bit later, just pop now. - loaded_in_8bit = kwargs.pop("loaded_in_8bit", False) - loaded_in_4bit = kwargs.pop("loaded_in_4bit", False) - bias = kwargs.pop("bias", False) - - # if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt): - # eightbit_kwargs = kwargs.copy() - # eightbit_kwargs.update( - # { - # "has_fp16_weights": target.state.has_fp16_weights, - # "memory_efficient_backward": target.state.memory_efficient_backward, - # "threshold": target.state.threshold, - # "index": target.index, - # } - # ) - # new_module = Linear8bitLt( - # adapter_name, target.in_channels, target.out_channels, bias=bias, **eightbit_kwargs - # ) - # elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit): - # fourbit_kwargs = kwargs.copy() - # fourbit_kwargs.update( - # { - # "compute_dtype": target.compute_dtype, - # "compress_statistics": target.weight.compress_statistics, - # "quant_type": target.weight.quant_type, - # } - # ) - # new_module = Linear4bit(adapter_name, target.in_channels, target.out_channels, bias=bias, **fourbit_kwargs) - if isinstance(target, nn.Embedding): - embedding_kwargs = kwargs.copy() - embedding_kwargs.pop("fan_in_fan_out", None) - in_features, out_features = target.vocab_size, target.embedding_size # target.num_embeddings, target.embedding_dim - new_module = Embedding(adapter_name, in_features, out_features, **embedding_kwargs) - # elif isinstance(target, torch.nn.Conv2d): - # out_channels, in_channels = target.weight.size()[:2] - # kernel_size = target.weight.size()[2:] - # stride = target.stride`` - # padding = target.padding - # new_module = Conv2d(adapter_name, in_channels, out_channels, kernel_size, stride, padding, **kwargs) - else: - if isinstance(target, nn.Dense): # Linear - # get - in_features, out_features = target.in_channels, target.out_channels - if kwargs["fan_in_fan_out"]: - warnings.warn( - "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " - "Setting fan_in_fan_out to False." - ) - kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False - elif isinstance(target, Conv1D): - in_features, out_features = ( - target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape - ) - kwargs["is_target_conv_1d_layer"] = True - if not kwargs["fan_in_fan_out"]: - warnings.warn( - "fan_in_fan_out is set to False but the target module is `Conv1D`. " - "Setting fan_in_fan_out to True." 
- ) - kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True - else: - raise ValueError( - f"Target module {target} is not supported. " - f"Currently, only `torch.nn.Linear` and `Conv1D` are supported." - ) - new_module = Linear(adapter_name, in_features, out_features, has_bias=bias, **kwargs) - - return new_module - - -class LoraLayer(): - """Lora Layer""" - # TODO add CellDict Support - def __init__(self, in_features: int, out_features: int, **kwargs): - self.r = {} - self.lora_alpha = {} - self.scaling = {} - # TODO: there is no nn.CellDict() in mindspore - self.lora_dropout = CellDict() - self.lora_A = CellDict() - self.lora_B = CellDict() - # For Embedding layer - self.lora_embedding_A = CellDict() - self.lora_embedding_B = CellDict() - - # Mark the weight as unmerged - self.merged = False - self.disable_adapters = False - self.in_features = in_features - self.out_features = out_features - self.kwargs = kwargs - - def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights): - """ - update lora layer. - """ - self.r[adapter_name] = r - self.lora_alpha[adapter_name] = lora_alpha - if lora_dropout > 0.0: - lora_dropout_layer = nn.Dropout(p=lora_dropout) - else: - lora_dropout_layer = nn.Identity() - - # self.lora_dropout.append({adapter_name: lora_dropout_layer}) - self.lora_dropout.update(CellDict({adapter_name: lora_dropout_layer})) - # Actual trainable parameters - if r > 0: - self.lora_A.update({adapter_name: nn.Dense(self.in_features, r, has_bias=False)}) - self.lora_B.update({adapter_name: nn.Dense(r, self.out_features, has_bias=False)}) - # self.lora_A.append(nn.Dense(self.in_features, r, has_bias=False)) - # self.lora_B.append(nn.Dense(r, self.out_features, has_bias=False)) - self.scaling[adapter_name] = lora_alpha / r - if init_lora_weights: - self.reset_lora_parameters(adapter_name) - # TODO: to device - # self.to(self.weight.device) - - # TODO: add conv2d - # def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights): - # self.r[adapter_name] = r - # self.lora_alpha[adapter_name] = lora_alpha - # if lora_dropout > 0.0: - # lora_dropout_layer = nn.Dropout(p=lora_dropout) - # else: - # lora_dropout_layer = nn.Identity() - - # self.lora_dropout.update(CellList({adapter_name: lora_dropout_layer})) - # # Actual trainable parameters - # if r > 0: - # kernel_size = self.kwargs["kernel_size"] - # stride = self.kwargs["stride"] - # padding = self.kwargs["padding"] - # self.lora_A.update( - # CellList({adapter_name: nn.Conv2d(self.in_features, r, kernel_size, stride, padding, bias=False)}) - # ) - # self.lora_B.update( - # CellList({adapter_name: nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False)}) - # ) - # self.scaling[adapter_name] = lora_alpha / r - # if init_lora_weights: - # self.reset_lora_parameters(adapter_name) - # self.to(self.weight.device) - - def update_layer_embedding(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights): - """ - update layer embedding. 
- """ - self.r[adapter_name] = r - self.lora_alpha[adapter_name] = lora_alpha - if lora_dropout > 0.0: - lora_dropout_layer = nn.Dropout(p=lora_dropout) - else: - lora_dropout_layer = nn.Identity() - - self.lora_dropout.update({adapter_name: lora_dropout_layer}) - # Actual trainable parameters - if r > 0: - weight_A = mindspore.ops.randn((r, self.in_features)) # dtype=self.weight.dtype, device=self.weight.device - weight_B = mindspore.ops.randn((self.out_features, r)) # dtype=self.weight.dtype, device=self.weight.device - self.lora_embedding_A.update({adapter_name: mindspore.Parameter(weight_A)}) - self.lora_embedding_B.update({adapter_name: mindspore.Parameter(weight_B)}) - self.scaling[adapter_name] = lora_alpha / r - if init_lora_weights: - self.reset_lora_parameters(adapter_name) - # self.to(self.weight.device) - - def reset_lora_parameters(self, adapter_name): - """ - reset lora parameters. - """ - if adapter_name in self.lora_A.keys(): - self.lora_A[adapter_name].weight.set_data(initializer( - HeUniform(negative_slope=math.sqrt(5)), - self.lora_A[adapter_name].weight.shape, - self.lora_A[adapter_name].weight.dtype - )) - self.lora_B[adapter_name].weight.set_data(initializer( - Zero(), - self.lora_B[adapter_name].weight.shape, - self.lora_B[adapter_name].weight.dtype - )) - - if adapter_name in self.lora_embedding_A.keys(): - # initialize a the same way as the default for nn.linear and b to zero - self.lora_embedding_A[adapter_name].weight.set_data(initializer( - Zero(), - self.lora_embedding_A[adapter_name].weight.shape, - self.lora_embedding_A[adapter_name].weight.dtype - )) - self.lora_embedding_B[adapter_name].weight.set_data(initializer( - Normal(), - self.lora_embedding_B[adapter_name].weight.shape, - self.lora_embedding_B[adapter_name].weight.dtype - )) - # TODO embedding not ok - # if adapter_name in self.lora_embedding_A.keys(): - # # initialize a the same way as the default for nn.Dense and b to zero - # Zero()(self.lora_embedding_A[adapter_name]) - # Normal(mean=0, sigma=1.)(self.lora_embedding_B[adapter_name]) - - -class Linear(nn.Dense, LoraLayer): - """Lora implemented in a dense layer""" - def __init__( - self, - adapter_name: str, - in_features: int, - out_features: int, - r: int = 0, - lora_alpha: int = 1, - lora_dropout: float = 0.0, - fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) - is_target_conv_1d_layer: bool = False, - **kwargs, - ): - init_lora_weights = kwargs.pop("init_lora_weights", True) - - nn.Dense.__init__(self, in_features, out_features, **kwargs) - LoraLayer.__init__(self, in_features=in_features, out_features=out_features) - # Freezing the pre-trained weight matrix - self.weight.requires_grad = False - - self.fan_in_fan_out = fan_in_fan_out - if fan_in_fan_out: - self.weight.data = self.weight.data.T - - # nn.Linear.reset_parameters(self) - self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) # call # LoraLayer.update_layer - self.active_adapter = adapter_name - self.is_target_conv_1d_layer = is_target_conv_1d_layer - - def merge(self): - """merge""" - if self.active_adapter not in self.lora_A.keys(): - return - if self.merged: - warnings.warn("Already merged. Nothing to do.") - return - if self.r[self.active_adapter] > 0: - self.weight.data += self.get_delta_weight(self.active_adapter) - self.merged = True - - def unmerge(self): - """unmerge""" - if self.active_adapter not in self.lora_A.keys(): - return - if not self.merged: - warnings.warn("Already unmerged. 
Nothing to do.") - return - if self.r[self.active_adapter] > 0: - self.weight.data -= self.get_delta_weight(self.active_adapter) - self.merged = False - - def get_delta_weight(self, adapter): - """ - get delta weight. Add or Sub to origin. - """ - return ( - transpose( - self.lora_B[adapter].weight @ self.lora_A[adapter].weight, - self.fan_in_fan_out, - ) - * self.scaling[adapter] - ) - - def extend_repr(self): - s = f'input_channels={self.in_channels}, output_channels={self.out_channels}' - if self.has_bias: - s += f', has_bias={self.has_bias}' - if self.activation_flag: - s += f', activation={self.activation}' - s += f', requires_grad={self.weight.requires_grad}' - return s - - def _linear(self, x: mindspore.Tensor) -> mindspore.Tensor: - return ops.dense(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias) - - def construct(self, x: mindspore.Tensor): - if self.active_adapter not in self.lora_A.keys(): - return self._linear(x) - - previous_dtype = x.dtype - - if self.disable_adapters: - if (self.r[self.active_adapter] > 0) and self.merged: - self.unmerge() - result = self._linear(x) - elif (self.r[self.active_adapter] == 0) or self.merged: - result = self._linear(x) - else: - lora_A = self.lora_A[self.active_adapter] - lora_B = self.lora_B[self.active_adapter] - dropout = self.lora_dropout[self.active_adapter] - scaling = self.scaling[self.active_adapter] - - result = self._linear(x) - x = x.to(lora_A.weight.dtype) - result += lora_B(lora_A(dropout(x))) * scaling - - result = result.to(previous_dtype) - return result - - -class Embedding(nn.Embedding, LoraLayer): - """LoRA implemented in a Embedding layer""" - def __init__( - self, - adapter_name: str, - num_embeddings: int, - embedding_dim: int, - r: int = 0, - lora_alpha: int = 1, - lora_dropout: float = 0.0, - **kwargs, - ): - init_lora_weights = kwargs.pop("init_lora_weights", True) - - nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs) - LoraLayer.__init__(self, in_features=num_embeddings, out_features=embedding_dim) - - self.weight.requires_grad = False - - # TODO: check nesissary - # check the api of mindspore.nn.Embedding initialization - # nn.Embedding.reset_parameters(self) - self.update_layer_embedding(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) - self.active_adapter = adapter_name - - def unmerge(self, mode: bool = True): - """unmerge""" - if not self.merged: - warnings.warn("Already unmerged. Nothing to do.") - return - if self.r[self.active_adapter] > 0: - self.weight.data -= ( - transpose( - self.lora_embedding_B[self.active_adapter] @ self.lora_embedding_A[self.active_adapter], True - ) - * self.scaling[self.active_adapter] - ) - self.merged = False - - def merge(self): - """merge""" - if self.merged: - warnings.warn("Already merged. 
Nothing to do.") - return - if self.r[self.active_adapter] > 0: - self.weight.data += ( - transpose( - self.lora_embedding_B[self.active_adapter] @ self.lora_embedding_A[self.active_adapter], True - ) - * self.scaling[self.active_adapter] - ) - self.merged = True - - def construct(self, ids: mindspore.Tensor): - if self.disable_adapters: - if self.r[self.active.adapter] > 0 and self.merged: - self.weight.data -= ( - transpose( - self.lora_embedding_B[self.active_adapter].weight - @ self.lora_embedding_A[self.active_adapter].weight, - True, - ) - * self.scaling[self.active_adapter] - ) - self.merged = False - return nn.Embedding.construct(self, ids) - - elif self.r[self.active_adapter] > 0 and not self.merged: - result = nn.Embedding.construct(self, ids) - if self.r[self.active_adapter] > 0: - after_A = ops.gather( - self.lora_embedding_A[self.active_adapter].T, - ids, - 0 - ) - result += (after_A @ self.lora_embedding_B[self.active_adapter].T) * self.scaling[self.active_adapter] - return result - else: - return nn.Embedding.construct(self, ids) diff --git a/mindnlp/peft/tuners/lora/__init__.py b/mindnlp/peft/tuners/lora/__init__.py new file mode 100644 index 000000000..56606cccf --- /dev/null +++ b/mindnlp/peft/tuners/lora/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""lora""" +from .config import LoftQConfig, LoraConfig +from .layer import Conv2d, Embedding, Linear, LoraLayer +from .model import LoraModel + + +__all__ = ["LoraConfig", "LoftQConfig", "Conv2d", "Embedding", "LoraLayer", "Linear", "LoraModel"] diff --git a/mindnlp/peft/tuners/lora/config.py b/mindnlp/peft/tuners/lora/config.py new file mode 100644 index 000000000..2b774d554 --- /dev/null +++ b/mindnlp/peft/tuners/lora/config.py @@ -0,0 +1,300 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""lora config""" +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Literal, Optional, Union + +from ...config import PeftConfig +from ...utils import PeftType + + +@dataclass +class LoftQConfig: + """ + This is the sub-configuration class to store the configuration of a [`LoraModel`]. + + Args: + bits_pattern (`dict`): The mapping from layer names or regexp expression to bits which are different from the + default bits specified by `bits`. For example, `{model.decoder.layers.0.encoder_attn.k_proj: 2`}. + bits (`int`): Quantization bits for LoftQ. 
+ iter (`int`): Alternating iterations for LoftQ. + fake (`bool`): True: use fp16/fp32; used for first time to save weights. False: use bitsandbytes 4bit linear + models. weights can't be saved. Recommend to set to True, save the weights and load the saved weights in 4 + bits. + """ + + loftq_bits: int = field(default=4, metadata={"help": "Quantization bits for LoftQ"}) + loftq_iter: int = field(default=1, metadata={"help": "Alternating iterations for LoftQ"}) + + +@dataclass +class LoraConfig(PeftConfig): + """ + This is the configuration class to store the configuration of a [`LoraModel`]. + + Args: + r (`int`): + Lora attention dimension (the "rank"). + target_modules (`Optional[Union[List[str], str]]`): + The names of the modules to apply the adapter to. If this is specified, only the modules with the specified + names will be replaced. When passing a string, a regex match will be performed. When passing a list of + strings, either an exact match will be performed or it is checked if the name of the module ends with any + of the passed strings. If this is specified as 'all-linear', then all linear/Conv1D modules are chosen, + excluding the output layer. If this is not specified, modules will be chosen according to the model + architecture. If the architecture is not known, an error will be raised -- in this case, you should specify + the target modules manually. + lora_alpha (`int`): + The alpha parameter for Lora scaling. + lora_dropout (`float`): + The dropout probability for Lora layers. + fan_in_fan_out (`bool`): + Set this to True if the layer to replace stores weight like (fan_in, fan_out). For example, gpt-2 uses + `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`. + bias (`str`): + Bias type for LoRA. Can be 'none', 'all' or 'lora_only'. If 'all' or 'lora_only', the corresponding biases + will be updated during training. Be aware that this means that, even when disabling the adapters, the model + will not produce the same output as the base model would have without adaptation. + use_rslora (`bool`): + When set to True, uses Rank-Stabilized LoRA which + sets the adapter scaling factor to `lora_alpha/math.sqrt(r)`, since it was proven to work better. + Otherwise, it will use the original default value of `lora_alpha/r`. + modules_to_save (`List[str]`): + List of modules apart from adapter layers to be set as trainable and saved in the final checkpoint. + init_lora_weights (`bool` | `Literal["gaussian", "loftq"]`): + How to initialize the weights of the adapter layers. Passing True (default) results in the default + initialization from the reference implementation from Microsoft. Passing 'gaussian' results in Gaussian + initialization scaled by the LoRA rank for linear and layers. Setting the initialization to False leads to + completely random initialization and is discouraged. Pass `'loftq'` to use LoftQ initialization. + layers_to_transform (`Union[List[int], int]`): + The layer indices to transform. If a list of ints is passed, it will apply the adapter to the layer indices + that are specified in this list. If a single integer is passed, it will apply the transformations on the + layer at this index. + layers_pattern (`str`): + The layer pattern name, used only if `layers_to_transform` is different from `None`. + rank_pattern (`dict`): + The mapping from layer names or regexp expression to ranks which are different from the default rank + specified by `r`. 
+ alpha_pattern (`dict`): + The mapping from layer names or regexp expression to alphas which are different from the default alpha + specified by `lora_alpha`. + megatron_config (`Optional[dict]`): + The TransformerConfig arguments for Megatron. It is used to create LoRA's parallel linear layer. You can + get it like this, `core_transformer_config_from_args(get_args())`, these two functions being from Megatron. + The arguments will be used to initialize the TransformerConfig of Megatron. You need to specify this + parameter when you want to apply LoRA to the ColumnParallelLinear and RowParallelLinear layers of megatron. + megatron_core (`Optional[str]`): + The core module from Megatron to use, defaults to `"megatron.core"`. + loftq_config (`Optional[LoftQConfig]`): + The configuration of LoftQ. If this is not None, then LoftQ will be used to quantize the backbone weights + and initialize Lora layers. Also pass `init_lora_weights='loftq'`. Note that you should not pass a + quantized model in this case, as LoftQ will quantize the model itself. + use_dora (`bool`): + Enable 'Weight-Decomposed Low-Rank Adaptation' (DoRA). This technique decomposes the updates of the weights + into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the magnitude is + handled by a separate learnable parameter. This can improve the performance of LoRA especially at low + ranks. Right now, DoRA only supports linear and Conv2D layers. DoRA introduces a bigger overhead than pure + LoRA, so it is recommended to merge weights for inference. For more information, see + https://arxiv.org/abs/2402.09353. + layer_replication (`List[Tuple[int, int]]`): + Build a new stack of layers by stacking the original model layers according to the ranges specified. This + allows expanding (or shrinking) the model without duplicating the base model weights. The new layers will + all have separate LoRA adapters attached to them. + """ + + r: int = field(default=8, metadata={"help": "Lora attention dimension"}) + target_modules: Optional[Union[list[str], str]] = field( + default=None, + metadata={ + "help": ( + "List of module names or regex expression of the module names to replace with LoRA." + "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'." + "This can also be a wildcard 'all-linear' which matches all linear/Conv1D layers except the output layer." + "If not specified, modules will be chosen according to the model architecture, If the architecture is " + "not known, an error will be raised -- in this case, you should specify the target modules manually." + ), + }, + ) + lora_alpha: int = field(default=8, metadata={"help": "Lora alpha"}) + lora_dropout: float = field(default=0.0, metadata={"help": "Lora dropout"}) + fan_in_fan_out: bool = field( + default=False, + metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"}, + ) + bias: Literal["none", "all", "lora_only"] = field( + default="none", metadata={"help": "Bias type for Lora. Can be 'none', 'all' or 'lora_only'"} + ) + use_rslora: bool = field( + default=False, + metadata={ + "help": ( + "When set to True, uses Rank-Stabilized LoRA doi.org/10.48550/arXiv.2312.03732" + " which sets the adapter scaling factor to `lora_alpha/math.sqrt(r)`, since it" + " was proven to work better. Otherwise, it will use the original default" + " value of `lora_alpha/r`." 
+ ) + }, + ) + modules_to_save: Optional[list[str]] = field( + default=None, + metadata={ + "help": "List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. " + "For example, in Sequence Classification or Token Classification tasks, " + "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved." + }, + ) + init_lora_weights: bool | Literal["gaussian", "loftq"] = field( + default=True, + metadata={ + "help": ( + "How to initialize the weights of the LoRA layers. Passing True (default) results in the default " + "initialization from the reference implementation from Microsoft. Passing 'gaussian' results " + "in Gaussian initialization scaled by the LoRA rank for linear and layers. Setting the initialization " + "to False leads to completely random initialization and is discouraged." + "Pass `'loftq'` to use LoftQ initialization" + ), + }, + ) + layers_to_transform: Optional[Union[list[int], int]] = field( + default=None, + metadata={ + "help": "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers indexes that are specified inside this list. " + "If a single integer is passed, PEFT will transform only the layer at this index. " + "This only works when target_modules is a list of str." + }, + ) + layers_pattern: Optional[Union[list[str], str]] = field( + default=None, + metadata={ + "help": "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer pattern is not in the common layers pattern." + "This only works when target_modules is a list of str." + }, + ) + rank_pattern: Optional[dict] = field( + default_factory=dict, + metadata={ + "help": ( + "The mapping from layer names or regexp expression to ranks which are different from the default rank specified by `r`. " + "For example, `{model.decoder.layers.0.encoder_attn.k_proj: 8`}" + ) + }, + ) + alpha_pattern: Optional[dict] = field( + default_factory=dict, + metadata={ + "help": ( + "The mapping from layer names or regexp expression to alphas which are different from the default alpha specified by `lora_alpha`. " + "For example, `{model.decoder.layers.0.encoder_attn.k_proj: 32`}" + ) + }, + ) + megatron_config: Optional[dict] = field( + default=None, + metadata={ + "help": ( + "The TransformerConfig from Megatron. It is used to create LoRA's parallel linear layer." + "You can get it like this, `core_transformer_config_from_args(get_args())`, " + "these two functions being from Megatron." + "You need to specify this parameter when you want to apply LoRA to the ColumnParallelLinear and " + "RowParallelLinear layers of megatron." + "It should be noted that we may not be able to use the `save_pretrained` and `from_pretrained` " + "functions, because TransformerConfig may not necessarily be serialized." + "But when using megatron, we can use `get_peft_model_state_dict` function and " + "megatron's framework, they can also save and load models and configurations." + ) + }, + ) + megatron_core: Optional[str] = field( + default="megatron.core", + metadata={ + "help": ( + "The core module from Megatron, it is used to create LoRA's parallel linear layer. " + "It only needs to be passed in when you need to use your own modified megatron core module. " + "Otherwise, it will use the default value `megatron.core`. 
" + ) + }, + ) + # dict type is used when loading config.json + loftq_config: Union[LoftQConfig, dict] = field( + default_factory=dict, + metadata={ + "help": ( + "The configuration of LoftQ. If this is passed, then LoftQ will be used to quantize the backbone " + "weights and initialize Lora layers. Also set `init_lora_weights='loftq'` in this case." + ) + }, + ) + use_dora: bool = field( + default=False, + metadata={ + "help": ( + "Enable 'Weight-Decomposed Low-Rank Adaptation' (DoRA). This technique decomposes the updates of the " + "weights into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the " + "magnitude is handled by a separate learnable parameter. This can improve the performance of LoRA, " + "especially at low ranks. Right now, DoRA only supports linear and Conv2D layers. DoRA introduces a bigger" + "overhead than pure LoRA, so it is recommended to merge weights for inference. For more information, " + "see https://arxiv.org/abs/2402.09353." + ) + }, + ) + # Enables replicating layers in a model to expand it to a larger model. + layer_replication: Optional[list[tuple[int, int]]] = field( + default=None, + metadata={ + "help": ( + "This enables using LoRA to effectively expand a transformer model to a larger size by repeating some layers. " + "The transformation handles models (currently Llama, Bert or Falcon compatible architectures) with " + "a module list in the model which it modifies to expand the number of modules. " + "Base weights are shared so the memory usage is close to the original model. The intended use is these base weights " + "remain fixed during finetuning but each layer has a separate LoRA adapter so the layers can be specialed via " + "the adapter layers fit during fine tuning." + "The format is a list of [start, end) pairs which specify the layer ranges to stack. For example:\n" + " Original model has 5 layers labelled by their position in the model: `[0, 1, 2, 3, 4]`\n" + " layer_replication: `[[0, 4], [2, 5]]`\n" + " Final model will have this arrangement of original layers: `[0, 1, 2, 3, 2, 3, 4]`\n" + "This format is based on what is used for pass-through merges in mergekit. It makes it simple to select sequential " + "ranges of a model and stack them while reusing layers at either end of each sequence." + ) + }, + ) + + def __post_init__(self): + self.peft_type = PeftType.LORA + self.target_modules = ( + set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules + ) + # if target_modules is a regex expression, then layers_to_transform should be None + if isinstance(self.target_modules, str) and self.layers_to_transform is not None: + raise ValueError("`layers_to_transform` cannot be used when `target_modules` is a str.") + + # if target_modules is a regex expression, then layers_pattern should be None + if isinstance(self.target_modules, str) and self.layers_pattern is not None: + raise ValueError("`layers_pattern` cannot be used when `target_modules` is a str.") + + if self.use_dora and self.megatron_config: + raise ValueError("DoRA does not support megatron_core, please set `use_dora=False`.") + + # handle init_lora_weights and loftq_config + if self.init_lora_weights == "loftq": + import importlib + + if not importlib.util.find_spec("scipy"): + raise ImportError("The required package 'scipy' is not installed. 
Please install it to continue.") + if self.loftq_config is None: + raise ValueError("`loftq_config` must be specified when `init_lora_weights` is 'loftq'.") + + # convert loftq_config to dict + if self.loftq_config and not isinstance(self.loftq_config, dict): + self.loftq_config = vars(self.loftq_config) diff --git a/mindnlp/peft/tuners/lora/layer.py b/mindnlp/peft/tuners/lora/layer.py new file mode 100644 index 000000000..d3948bec7 --- /dev/null +++ b/mindnlp/peft/tuners/lora/layer.py @@ -0,0 +1,955 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""lora layer""" +from __future__ import annotations + +import math +import warnings +from typing import Any, Optional, Union + +import mindspore +from mindspore import nn, ops, Parameter +from ....transformers.ms_utils import Conv1D +from ...._legacy.abc import ParameterDict +from ....modules.functional import normalize, embedding +from ..tuners_utils import BaseTunerLayer, check_adapters_to_merge +from ...utils.other import transpose + +from .config import LoraConfig + + +class LoraLayer(BaseTunerLayer): + # All names of layers that may contain (trainable) adapter weights + adapter_layer_names = ("lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B") + # All names of other parameters that may contain adapter-related parameters + other_param_names = ("r", "lora_alpha", "scaling", "lora_dropout") + + def __init__(self, base_layer: nn.Cell, **kwargs) -> None: + self.base_layer = base_layer + self.r = {} + self.lora_alpha = {} + self.scaling = {} + self.lora_dropout = nn.CellDict({}) + self.lora_A = nn.CellDict({}) + self.lora_B = nn.CellDict({}) + # For Embedding layer + self.lora_embedding_A = ParameterDict({}) + self.lora_embedding_B = ParameterDict({}) + # Mark the weight as unmerged + self._disable_adapters = False + self.merged_adapters = [] + self.use_dora: dict[str, bool] = {} + self.lora_magnitude_vector: Optional[ParameterDict] = None # for DoRA + self._caches: dict[str, Any] = {} + self.kwargs = kwargs + + base_layer = self.get_base_layer() + if isinstance(base_layer, nn.Dense): + in_features, out_features = base_layer.in_channels, base_layer.out_channels + elif isinstance(base_layer, nn.Conv2d): + in_features, out_features = base_layer.in_channels, base_layer.out_channels + elif isinstance(base_layer, nn.Embedding): + in_features, out_features = base_layer.num_embeddings, base_layer.embedding_dim + elif isinstance(base_layer, Conv1D): + in_features, out_features = ( + base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape + ) + elif hasattr(base_layer, "infeatures") and hasattr(base_layer, "outfeatures"): + # QuantLinear + in_features, out_features = base_layer.infeatures, base_layer.outfeatures + elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"): + # Megatron ColumnParallelLinear,RowParallelLinear + in_features, out_features = base_layer.input_size, base_layer.output_size + elif hasattr(base_layer, 
"codebooks") and base_layer.__class__.__name__ == "QuantizedLinear": + # AQLM QuantLinear + in_features, out_features = base_layer.in_features, base_layer.out_features + elif hasattr(base_layer, "w_bit") and base_layer.__class__.__name__ == "WQLinear_GEMM": + # Awq layers + in_features, out_features = base_layer.in_features, base_layer.out_features + elif base_layer.__class__.__name__ == "EetqLinear": + # Eetq layers + in_features, out_features = base_layer.in_features, base_layer.out_features + elif hasattr(base_layer, "W_q") and base_layer.__class__.__name__ == "HQQLinear": + # HQQ layers + in_features, out_features = base_layer.in_features, base_layer.out_features + else: + raise ValueError(f"Unsupported layer type {type(base_layer)}") + + self.in_features = in_features + self.out_features = out_features + + def update_layer( + self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora: bool = False + ): + # This code works for linear layers, override for other layer types + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + + self.r[adapter_name] = r + self.lora_alpha[adapter_name] = lora_alpha + if lora_dropout > 0.0: + lora_dropout_layer = nn.Dropout(p=lora_dropout) + else: + lora_dropout_layer = nn.Identity() + + self.lora_dropout.update(nn.CellDict({adapter_name: lora_dropout_layer})) + # Actual trainable parameters + self.lora_A[adapter_name] = nn.Dense(self.in_features, r, has_bias=False) + self.lora_B[adapter_name] = nn.Dense(r, self.out_features, has_bias=False) + if use_rslora: + self.scaling[adapter_name] = lora_alpha / math.sqrt(r) + else: + self.scaling[adapter_name] = lora_alpha / r + + if init_lora_weights == "loftq": + self.loftq_init(adapter_name) + elif init_lora_weights: + self.reset_lora_parameters(adapter_name, init_lora_weights) + + # check weight and qweight (for GPTQ) + for weight_name in ("weight", "qweight"): + weight = getattr(self.get_base_layer(), weight_name, None) + if weight is not None: + # the layer is already completely initialized, this is an update + if weight.dtype.is_floating_point or weight.dtype.is_complex: + self.to(weight.device, dtype=weight.dtype) + else: + self.to(weight.device) + break + + if use_dora: + self.dora_init(adapter_name) + self.use_dora[adapter_name] = True + else: + self.use_dora[adapter_name] = False + + self.set_adapter(self.active_adapters) + + def reset_lora_parameters(self, adapter_name, init_lora_weights): + if init_lora_weights is False: + return + + if adapter_name in self.lora_A.keys(): + if init_lora_weights is True: + # initialize A the same way as the default for nn.Dense and B to zero + # https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L124 + nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5)) + elif init_lora_weights.lower() == "gaussian": + nn.init.normal_(self.lora_A[adapter_name].weight, std=1 / self.r[adapter_name]) + else: + raise ValueError(f"Unknown initialization {init_lora_weights=}") + nn.init.zeros_(self.lora_B[adapter_name].weight) + if adapter_name in self.lora_embedding_A.keys(): + # initialize a the same way as the default for nn.Dense and b to zero + nn.init.zeros_(self.lora_embedding_A[adapter_name]) + nn.init.normal_(self.lora_embedding_B[adapter_name]) + + def _get_weight_norm(self, weight, lora_weight, scaling) -> mindspore.Tensor: + # calculate L2 norm of weight matrix, column-wise + weight = transpose(weight, self.fan_in_fan_out) + 
weight = weight + scaling * lora_weight + weight_norm = normalize(weight, dim=1).to(weight.dtype) + return weight_norm + + def _cache_store(self, key: str, value: Any) -> None: + self._caches[key] = value + + def _cache_pop(self, key: str) -> Any: + value = self._caches.pop(key) + return value + + def set_scale(self, adapter, scale): + if adapter not in self.scaling: + # Ignore the case where the adapter is not in the layer + return + self.scaling[adapter] = scale * self.lora_alpha[adapter] / self.r[adapter] + + def scale_layer(self, scale: float) -> None: + if scale == 1: + return + + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + + self.scaling[active_adapter] *= scale + + def unscale_layer(self, scale=None) -> None: + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + + if scale is None: + self.scaling[active_adapter] = self.lora_alpha[active_adapter] / self.r[active_adapter] + else: + self.scaling[active_adapter] /= scale + + def _check_forward_args(self, x, *args, **kwargs): + """Check if the arguments are compatible with the configs and state of the model""" + adapter_names = kwargs.get("adapter_names", None) + if adapter_names is None: + return + + if len(x) != len(adapter_names): + msg = ( + "Length of `adapter_names` should be the same as the number of inputs, but got " + f"{len(adapter_names)} and {len(x)} respectively." + ) + raise ValueError(msg) + + if self.merged: + # It is unclear what would be the right thing to do if users pass adapter_names and there are merged + # adapters. Therefore, it is better to raise an error in this case. + msg = "Cannot pass `adapter_names` when there are merged adapters, please call `unmerge_adapter` first." + raise ValueError(msg) + + unique_adapters = set(self.active_adapters) + for adapter_name in unique_adapters: + if self.use_dora.get(adapter_name, False): + msg = "Cannot pass `adapter_names` when DoRA is enabled." + raise ValueError(msg) + + def _mixed_batch_forward( + self, x: mindspore.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any + ) -> mindspore.Tensor: + # This is a special method that handles the case when users pass the argument `adapter_names`. This is an + # extra argument that allows mixing different adapters in the same batch at inference time. 
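# Illustrative sketch (hypothetical adapter names; not part of the patch):
# `adapter_names` is grouped into per-adapter sub-batches before any LoRA delta is
# applied, e.g. for a batch of four inputs:
#
#     adapter_names = ["lora_a", "__base__", "lora_a", "lora_b"]
#     sub_batches = {}
#     for index, name in enumerate(adapter_names):
#         sub_batches.setdefault(name, []).append(index)
#     # -> {"lora_a": [0, 2], "__base__": [1], "lora_b": [3]}
#
# Rows routed to "__base__" keep the plain base-layer output; each remaining group
# has its own adapter's lora_B(lora_A(dropout(x))) * scaling added to its rows.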
+ result = self.base_layer(x, *args, **kwargs) + torch_result_dtype = result.dtype + + unique_adapters = set(adapter_names) + sub_batch_indices_list = [] + for adapter in unique_adapters: + sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter]) + + for i, active_adapter in enumerate(unique_adapters): + if active_adapter == "__base__": + continue + if active_adapter not in self.lora_A.keys(): + continue + + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + + # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear + # layer output + sub_batch = x[sub_batch_indices_list[i]].to(lora_A.weight.dtype) + lora_output = lora_B(lora_A(dropout(sub_batch))) * scaling + result[sub_batch_indices_list[i]] += lora_output.to(torch_result_dtype) + + return result + + +# Below code is based on https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# and modified to work with PyTorch FSDP + + +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + + +class Linear(nn.Cell, LoraLayer): + # Lora implemented in a dense layer + def __init__( + self, + base_layer, + adapter_name: str, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + is_target_conv_1d_layer: bool = False, + init_lora_weights: Union[bool, str] = True, + use_rslora: bool = False, + use_dora: bool = False, + **kwargs, + ) -> None: + super().__init__() + LoraLayer.__init__(self, base_layer, **kwargs) + self.fan_in_fan_out = fan_in_fan_out + + self._active_adapter = adapter_name + self.update_layer( + adapter_name, + r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + init_lora_weights=init_lora_weights, + use_rslora=use_rslora, + use_dora=use_dora, + ) + self.is_target_conv_1d_layer = is_target_conv_1d_layer + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter in self.lora_A.keys(): + base_layer = self.get_base_layer() + if safe_merge: + # Note that safe_merge will be slower than the normal merge + # because of the copy operation. 
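# Illustrative note (not part of the patch): the clone means the NaN/Inf check
# further down runs against a merged copy first, so a broken adapter raises before
# the base layer's weights are touched; with safe_merge=False the delta is written
# into the base weights directly.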
+ orig_weights = base_layer.weight.data.clone() + delta_weight = self.get_delta_weight(active_adapter) + if not self.use_dora[active_adapter]: + orig_weights = orig_weights + delta_weight + else: + # handle dora + # since delta_weight already includes scaling, set it to 1 here + weight_norm = self._get_weight_norm( + orig_weights, transpose(delta_weight, self.fan_in_fan_out), scaling=1 + ) + # We need to cache weight_norm because it has to be based on the original weights. We + # cannot calculate it on the fly based on the merged weights when unmerging because its a + # different value + self._cache_store(f"{active_adapter}-weight_norm", weight_norm) + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + dora_factor = transpose(dora_factor.view(-1, 1), self.fan_in_fan_out) + orig_weights = dora_factor * (orig_weights + delta_weight) + + if not ops.isfinite(orig_weights).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + + base_layer.weight.data = orig_weights + else: + delta_weight = self.get_delta_weight(active_adapter) + if not self.use_dora[active_adapter]: + base_layer.weight.data = base_layer.weight.data + delta_weight + else: + # handle dora + # since delta_weight already includes scaling, set it to 1 here + weight_norm = self._get_weight_norm( + base_layer.weight, transpose(delta_weight, self.fan_in_fan_out), scaling=1 + ) + # We need to cache weight_norm because it has to be based on the original weights. We + # cannot calculate it on the fly based on the merged weights when unmerging because its a + # different value + self._cache_store(f"{active_adapter}-weight_norm", weight_norm) + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + dora_factor = transpose(dora_factor.view(-1, 1), self.fan_in_fan_out) + new_weight = dora_factor * (base_layer.weight.data + delta_weight) + base_layer.weight.data = new_weight + + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self.lora_A.keys(): + weight = self.get_base_layer().weight + delta_weight = self.get_delta_weight(active_adapter) + if not self.use_dora[active_adapter]: + weight.data -= delta_weight + else: + weight_norm = self._cache_pop(f"{active_adapter}-weight_norm") + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + weight_orig = weight.data / dora_factor.view(-1, 1) - delta_weight + weight.data = weight_orig + + def get_delta_weight(self, adapter) -> mindspore.Tensor: + """ + Compute the delta weight for the given adapter. + + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. 
+ """ + device = self.lora_B[adapter].weight.device + dtype = self.lora_B[adapter].weight.dtype + + + weight_A = self.lora_A[adapter].weight + weight_B = self.lora_B[adapter].weight + + output_tensor = transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter] + + return output_tensor + + def construct(self, x: mindspore.Tensor, *args: Any, **kwargs: Any) -> mindspore.Tensor: + self._check_forward_args(x, *args, **kwargs) + adapter_names = kwargs.pop("adapter_names", None) + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif adapter_names is not None: + result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + torch_result_dtype = result.dtype + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + x = x.to(lora_A.weight.dtype) + + if not self.use_dora[active_adapter]: + result = result + lora_B(lora_A(dropout(x))) * scaling + else: + x = dropout(x) + result = result + self._apply_dora(x, lora_A, lora_B, scaling, active_adapter) + + result = result.to(torch_result_dtype) + + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "lora." + rep + + +class Embedding(nn.Cell, LoraLayer): + # LoRA implemented in a Embedding layer + def __init__( + self, + base_layer: nn.Cell, + adapter_name: str, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + init_lora_weights: Union[bool, str] = True, + use_rslora: bool = False, + use_dora: bool = False, + **kwargs, + ) -> None: + super().__init__() + LoraLayer.__init__(self, base_layer) + + if use_dora: + raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False") + + self._active_adapter = adapter_name + self.update_layer( + adapter_name, + r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + init_lora_weights=init_lora_weights, + use_rslora=use_rslora, + use_dora=use_dora, + ) + + def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora): + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + + self.r[adapter_name] = r + self.lora_alpha[adapter_name] = lora_alpha + if lora_dropout > 0.0: + lora_dropout_layer = nn.Dropout(p=lora_dropout) + else: + lora_dropout_layer = nn.Identity() + + self.lora_dropout[adapter_name] = lora_dropout_layer + # Actual trainable parameters + weight_A = ops.randn((r, self.in_features)) + weight_B = ops.randn((self.out_features, r)) + self.lora_embedding_A[adapter_name] = Parameter(weight_A) + self.lora_embedding_B[adapter_name] = Parameter(weight_B) + if use_rslora: + self.scaling[adapter_name] = lora_alpha / math.sqrt(r) + else: + self.scaling[adapter_name] = lora_alpha / r + + if init_lora_weights == "loftq": + self.loftq_init(adapter_name) + elif init_lora_weights: + self.reset_lora_parameters(adapter_name, init_lora_weights) + + base_layer = self.get_base_layer() + weight = getattr(base_layer, "weight", None) + if weight is not None: + # the layer is already completely initialized, this is an update + self.to(base_layer.weight.device, dtype=weight.dtype) + + self.set_adapter(self.active_adapters) + + 
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter in self.lora_embedding_A.keys(): + base_layer = self.get_base_layer() + if safe_merge: + # Note that safe_merge will be slower than the normal merge + # because of the copy operation. + orig_weights = base_layer.weight.data.clone() + orig_weights = orig_weights + self.get_delta_weight(active_adapter) + + if not ops.isfinite(orig_weights).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + + base_layer.weight.data = orig_weights + else: + base_layer.weight.data = base_layer.weight.data + self.get_delta_weight(active_adapter) + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self.lora_embedding_A.keys(): + self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter) + + def get_delta_weight(self, adapter) -> mindspore.Tensor: + """ + Compute the delta weight for the given adapter. + + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. + """ + weight_A = self.lora_embedding_A[adapter] + weight_B = self.lora_embedding_B[adapter] + + output_tensor = transpose(weight_B @ weight_A, True) * self.scaling[adapter] + + return output_tensor + + def _mixed_batch_forward( + self, x: mindspore.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any + ) -> mindspore.Tensor: + # This is a special method that handles the case when users pass the argument `adapter_names`. This is an + # extra argument that allows mixing different adapters in the same batch at inference time. 
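# Illustrative note (assumed shapes; not part of the patch): the embedding adapter
# is equivalent to adding transpose(B @ A) * scaling to the embedding table.
# Assuming num_embeddings=30522, embedding_dim=768, r=8:
#     lora_embedding_A: (8, 30522), lora_embedding_B: (768, 8)
#     B @ A -> (768, 30522); transposed -> (30522, 768), the table's shape.
# The code below avoids materializing that full delta: per sub-batch it embeds the
# ids against A.T (a (30522, 8) lookup table) and projects the result with B.T.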
+ result = self.base_layer(x, *args, **kwargs) + + unique_adapters = set(adapter_names) + sub_batch_indices_list = [] + for adapter in unique_adapters: + sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter]) + + for i, active_adapter in enumerate(unique_adapters): + if active_adapter == "__base__": + continue + if active_adapter not in self.lora_embedding_A.keys(): + continue + + embedding_A = self.lora_embedding_A[active_adapter].T + embedding_B = self.lora_embedding_B[active_adapter].T + scaling = self.scaling[active_adapter] + + # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear + # layer output + sub_batch = x[sub_batch_indices_list[i]] + after_A = self._embed(sub_batch, embedding_A) + result[sub_batch_indices_list[i]] += (after_A @ embedding_B) * scaling + + return result + + def _embed(self, input: mindspore.Tensor, weight: mindspore.Tensor) -> mindspore.Tensor: + # base_layer = self.get_base_layer() + return embedding( + input, + weight, + # padding_idx=base_layer.padding_idx, + # max_norm=base_layer.max_norm, + # norm_type=base_layer.norm_type, + # scale_grad_by_freq=base_layer.scale_grad_by_freq, + # sparse=base_layer.sparse, + ) + + def construct(self, x: mindspore.Tensor, *args: Any, **kwargs: Any) -> mindspore.Tensor: + # TODO: no dtype conversion here, unlike in Linear, is that correct? + self._check_forward_args(x, *args, **kwargs) + adapter_names = kwargs.pop("adapter_names", None) + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif adapter_names is not None: + result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + torch_result_dtype = result.dtype + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_embedding_A: + continue + embedding_A = self.lora_embedding_A[active_adapter].T + embedding_B = self.lora_embedding_B[active_adapter].T + scaling = self.scaling[active_adapter] + after_A = self._embed(x, embedding_A) + result = result + (after_A @ embedding_B) * scaling + result = result.to(torch_result_dtype) + + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "lora." 
+ rep + + +class Conv2d(nn.Cell, LoraLayer): + # Lora implemented in a conv2d layer + def __init__( + self, + base_layer: nn.Cell, + adapter_name: str, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + init_lora_weights: Union[bool, str] = True, + use_rslora: bool = False, + use_dora: bool = False, + **kwargs, + ) -> None: + super().__init__() + LoraLayer.__init__(self, base_layer) + + self._active_adapter = adapter_name + self.update_layer( + adapter_name, + r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + init_lora_weights=init_lora_weights, + use_rslora=use_rslora, + use_dora=use_dora, + ) + + def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora): + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + + self.r[adapter_name] = r + self.lora_alpha[adapter_name] = lora_alpha + if lora_dropout > 0.0: + lora_dropout_layer = nn.Dropout(p=lora_dropout) + else: + lora_dropout_layer = nn.Identity() + + self.lora_dropout[adapter_name] = lora_dropout_layer + # Actual trainable parameters + base_layer = self.get_base_layer() + kernel_size = base_layer.kernel_size + stride = base_layer.stride + padding = base_layer.padding + self.lora_A[adapter_name] = nn.Conv2d(self.in_features, r, kernel_size, stride, padding, has_bias=False) + self.lora_B[adapter_name] = nn.Conv2d(r, self.out_features, (1, 1), (1, 1), has_bias=False) + if use_rslora: + self.scaling[adapter_name] = lora_alpha / math.sqrt(r) + else: + self.scaling[adapter_name] = lora_alpha / r + + if init_lora_weights == "loftq": + self.loftq_init(adapter_name) + elif init_lora_weights: + self.reset_lora_parameters(adapter_name, init_lora_weights) + + weight = getattr(base_layer, "weight", None) + if weight is not None: + # the layer is already completely initialized, this is an update + self.to(base_layer.weight.device, dtype=weight.dtype) + + if use_dora: + self.dora_init(adapter_name) + self.use_dora[adapter_name] = True + else: + self.use_dora[adapter_name] = False + + self.set_adapter(self.active_adapters) + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights inside the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter in self.lora_A.keys(): + base_layer = self.get_base_layer() + if safe_merge: + # Note that safe_merge will be slower than the normal merge + # because of the copy operation. + orig_weights = base_layer.weight.data.clone() + delta_weight = self.get_delta_weight(active_adapter) + + if not self.use_dora[active_adapter]: + orig_weights = orig_weights + delta_weight + else: + # handle dora + # since delta_weight already includes scaling, set it to 1 here + weight_norm = self._get_weight_norm(orig_weights, delta_weight, scaling=1) + # We need to cache weight_norm because it has to be based on the original weights. 
We + # cannot calculate it on the fly based on the merged weights when unmerging because its a + # different value + self._cache_store(f"{active_adapter}-weight_norm", weight_norm) + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + orig_weights = dora_factor.view(-1, 1, 1, 1) * (orig_weights + delta_weight) + + if not ops.isfinite(orig_weights).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + base_layer.weight.data = orig_weights + else: + delta_weight = self.get_delta_weight(active_adapter) + if not self.use_dora[active_adapter]: + base_layer.weight.data = base_layer.weight.data + delta_weight + else: + # handle dora + # since delta_weight already includes scaling, set it to 1 here + weight_norm = self._get_weight_norm(base_layer.weight, delta_weight, scaling=1) + # We need to cache weight_norm because it has to be based on the original weights. We + # cannot calculate it on the fly based on the merged weights when unmerging because its a + # different value + self._cache_store(f"{active_adapter}-weight_norm", weight_norm) + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + new_weight = dora_factor.view(-1, 1, 1, 1) * (base_layer.weight.data + delta_weight) + base_layer.weight.data = new_weight + + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self.lora_A.keys(): + weight = self.get_base_layer().weight + delta_weight = self.get_delta_weight(active_adapter) + if not self.use_dora[active_adapter]: + weight.data -= delta_weight + else: + weight_norm = self._cache_pop(f"{active_adapter}-weight_norm") + dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + weight_orig = weight.data / dora_factor.view(-1, 1, 1, 1) - delta_weight + weight.data = weight_orig + + def get_delta_weight(self, adapter) -> mindspore.Tensor: + """ + Compute the delta weight for the given adapter. + + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. + """ + device = self.lora_B[adapter].weight.device + dtype = self.lora_A[adapter].weight.dtype + + # In case users wants to merge the adapter weights that are in + # float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to + # float16 because the `@` and matmul operation in general is not supported in torch + cpu + fp16. 
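As an aside on the merge/unmerge path above: outside of the DoRA branch it reduces to adding and then subtracting the scaled low-rank delta, and `safe_merge` only adds a finiteness check before committing the result. A minimal NumPy sketch with made-up shapes (an illustration of the bookkeeping, not code from this patch):

```py
import numpy as np

# Toy shapes for a Dense-style base layer: W is (out, in), A is (r, in), B is (out, r).
rng = np.random.default_rng(0)
out_f, in_f, r, lora_alpha = 8, 16, 4, 16
base_w = rng.normal(size=(out_f, in_f))
lora_a = rng.normal(size=(r, in_f))
lora_b = rng.normal(size=(out_f, r))
scaling = lora_alpha / r

delta = (lora_b @ lora_a) * scaling      # analogous to get_delta_weight()
merged = base_w + delta                  # merge()
assert np.isfinite(merged).all()         # the check safe_merge performs before overwriting
restored = merged - delta                # unmerge()
np.testing.assert_allclose(restored, base_w, atol=1e-12)
```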
+ weight_A = self.lora_A[adapter].weight + weight_B = self.lora_B[adapter].weight + + # https://github.com/bmaltais/kohya_ss/blob/feb6728762a8f463d15ba936d189d4c3abfaa1ab/networks/lora.py#L117 + if self.get_base_layer().weight.shape[2:4] == (1, 1): + # conv2d 1x1 + output_tensor = (weight_B.squeeze(3).squeeze(2) @ weight_A.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze( + 3 + ) * self.scaling[adapter] + else: + # conv2d 3x3 + output_tensor = ( + ops.conv2d( + weight_A.permute(1, 0, 2, 3), + weight_B, + ).permute(1, 0, 2, 3) + * self.scaling[adapter] + ) + + return output_tensor + + def _get_weight_norm(self, weight, lora_weight, scaling) -> mindspore.Tensor: + # calculate L2 norm of weight matrix, channel-wise + weight = weight + scaling * lora_weight + # the following is needed to have compatibility with the 4D weight tensors of Conv2D + weight_norm = weight.norm(p=2, dim=(1, 2, 3), keepdim=True).swapaxes(1, 0) + return weight_norm + + def _apply_dora(self, x, lora_A, lora_B, scaling, active_adapter): + """ + For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer + output. + """ + base_layer = self.get_base_layer() + weight = base_layer.weight + lora_weight = ops.mm(lora_B.weight.flatten(start_dim=1), lora_A.weight.flatten(start_dim=1)) + lora_weight = lora_weight.reshape(weight.shape) + magnitude = self.lora_magnitude_vector[active_adapter] + weight_norm = self._get_weight_norm(weight, lora_weight, scaling) + # see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353) + # "[...] we suggest treating ||V +∆V ||_c in + # Eq. (5) as a constant, thereby detaching it from the gradient + # graph. This means that while ||V + ∆V ||_c dynamically + # reflects the updates of ∆V , it won’t receive any gradient + # during backpropagation" + mag_norm_scale = magnitude / weight_norm + result_dora = (mag_norm_scale - 1) * ( + ops.conv2d( + x, + weight, + bias=None, + stride=base_layer.stride, + padding=base_layer.padding, + dilation=base_layer.dilation, + groups=base_layer.groups, + ) + ) + mag_norm_scale * lora_B(lora_A(x)) * scaling + + return result_dora + + def construct(self, x: mindspore.Tensor, *args, **kwargs) -> mindspore.Tensor: + self._check_forward_args(x, *args, **kwargs) + adapter_names = kwargs.pop("adapter_names", None) + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif adapter_names is not None: + result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + torch_result_dtype = result.dtype + + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + x = x.to(lora_A.weight.dtype) + + if not self.use_dora[active_adapter]: + result = result + lora_B(lora_A(dropout(x))) * scaling + else: + x = dropout(x) + result = result + self._apply_dora(x, lora_A, lora_B, scaling, active_adapter) + + result = result.to(torch_result_dtype) + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "lora." 
+ rep + + +def dispatch_default( + target: nn.Cell, + adapter_name: str, + lora_config: LoraConfig, + **kwargs, +) -> Optional[nn.Cell]: + new_module = None + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if isinstance(target_base_layer, nn.Embedding): + embedding_kwargs = kwargs.copy() + embedding_kwargs.pop("fan_in_fan_out", None) + embedding_kwargs.update(lora_config.loftq_config) + new_module = Embedding(target, adapter_name, **embedding_kwargs) + elif isinstance(target_base_layer, nn.Conv2d): + kwargs.update(lora_config.loftq_config) + new_module = Conv2d(target, adapter_name, **kwargs) + elif isinstance(target_base_layer, nn.Dense): + if kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to True but the target module is `torch.nn.Dense`. " + "Setting fan_in_fan_out to False." + ) + kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False + kwargs.update(lora_config.loftq_config) + new_module = Linear(target, adapter_name, **kwargs) + elif isinstance(target_base_layer, Conv1D): + if not kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to False but the target module is `Conv1D`. Setting fan_in_fan_out to True." + ) + kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True + kwargs.update(lora_config.loftq_config) + new_module = Linear(target, adapter_name, is_target_conv_1d_layer=True, **kwargs) + + return new_module diff --git a/mindnlp/peft/tuners/lora/model.py b/mindnlp/peft/tuners/lora/model.py new file mode 100644 index 000000000..b35ae355a --- /dev/null +++ b/mindnlp/peft/tuners/lora/model.py @@ -0,0 +1,807 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""lora model""" +from __future__ import annotations + +import math +import operator +import re +import warnings +from contextlib import contextmanager +from dataclasses import asdict, replace +from enum import Enum +from functools import partial, reduce +from itertools import chain +from typing import Literal, Optional + +import mindspore +from mindspore import nn, ops +from tqdm import tqdm + +from ..tuners_utils import ( + BaseTuner, + BaseTunerLayer, + check_target_module_exists, + # onload_layer, + replicate_layers, +) +from ...utils import ( + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, + ModulesToSaveWrapper, + _freeze_adapter, + _get_submodules, +) +from ...utils.merge_utils import dare_linear, dare_ties, magnitude_prune, task_arithmetic, ties + +from .config import LoraConfig +from .layer import Conv2d, LoraLayer, dispatch_default + + +def _adapter_names_pre_forward_hook(target, args, kwargs, adapter_names): + # pre-forward hook to inject the adapter_names argument when using mixed adapter batches inference + kwargs["adapter_names"] = adapter_names + return args, kwargs + + +class LoraModel(BaseTuner): + """ + Creates Low Rank Adapter (LoRA) model from a pretrained transformers model. 
+ + The method is described in detail in https://arxiv.org/abs/2106.09685. + + Args: + model ([`nn.Cell`]): The model to be adapted. + config ([`LoraConfig`]): The configuration of the Lora model. + adapter_name (`str`): The name of the adapter, defaults to `"default"`. + + Returns: + `nn.Cell`: The Lora model. + + Example: + + ```py + >>> from transformers import AutoModelForSeq2SeqLM + >>> from peft import LoraModel, LoraConfig + + >>> config = LoraConfig( + ... task_type="SEQ_2_SEQ_LM", + ... r=8, + ... lora_alpha=32, + ... target_modules=["q", "v"], + ... lora_dropout=0.01, + ... ) + + >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") + >>> lora_model = LoraModel(model, config, "default") + ``` + + ```py + >>> import torch + >>> import transformers + >>> from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training + + >>> rank = ... + >>> target_modules = ["q_proj", "k_proj", "v_proj", "out_proj", "fc_in", "fc_out", "wte"] + >>> config = LoraConfig( + ... r=4, lora_alpha=16, target_modules=target_modules, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM" + ... ) + >>> quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True) + + >>> tokenizer = transformers.AutoTokenizer.from_pretrained( + ... "kakaobrain/kogpt", + ... revision="KoGPT6B-ryan1.5b-float16", # or float32 version: revision=KoGPT6B-ryan1.5b + ... bos_token="[BOS]", + ... eos_token="[EOS]", + ... unk_token="[UNK]", + ... pad_token="[PAD]", + ... mask_token="[MASK]", + ... ) + >>> model = transformers.GPTJForCausalLM.from_pretrained( + ... "kakaobrain/kogpt", + ... revision="KoGPT6B-ryan1.5b-float16", # or float32 version: revision=KoGPT6B-ryan1.5b + ... pad_token_id=tokenizer.eos_token_id, + ... use_cache=False, + ... device_map={"": rank}, + ... torch_dtype=torch.float16, + ... quantization_config=quantization_config, + ... ) + >>> model = prepare_model_for_kbit_training(model) + >>> lora_model = get_peft_model(model, config) + ``` + + **Attributes**: + - **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted. + - **peft_config** ([`LoraConfig`]): The configuration of the Lora model. + """ + + prefix: str = "lora_" + + def _check_new_adapter_config(self, config: LoraConfig) -> None: + """ + A helper method to check the config when a new adapter is being added. + + Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters. + + """ + # TODO: there should be a check if any of the existing adapters actually has bias != "none", or else the check + # does not fully correspond to the error message. + if (len(self.peft_config) > 1) and (config.bias != "none"): + raise ValueError( + f"{self.__class__.__name__} supports only 1 adapter with bias. When using multiple adapters, " + "set bias to 'none' for all adapters." + ) + + @staticmethod + def _check_target_module_exists(lora_config, key): + return check_target_module_exists(lora_config, key) + + def _prepare_model(self, peft_config: LoraConfig, model: nn.Module): + r""" + A private method to modify the model structure before adapter is applied. + + Args: + peft_config (`PeftConfig`): + The prepared adapter config. + model (`nn.Module`): + The model that is going to be adapted. 
+ """ + if peft_config.layer_replication: + replicate_layers(model, peft_config.layer_replication) + + def _create_and_replace( + self, + lora_config, + adapter_name, + target, + target_name, + parent, + current_key, + ): + if current_key is None: + raise ValueError("Current Key shouldn't be `None`") + + # Regexp matching - Find key which matches current target_name in patterns provided + pattern_keys = list(chain(lora_config.rank_pattern.keys(), lora_config.alpha_pattern.keys())) + target_name_key = next(filter(lambda key: re.match(rf".*\.{key}$", current_key), pattern_keys), current_key) + r = lora_config.rank_pattern.get(target_name_key, lora_config.r) + alpha = lora_config.alpha_pattern.get(target_name_key, lora_config.lora_alpha) + + kwargs = { + "r": r, + "lora_alpha": alpha, + "lora_dropout": lora_config.lora_dropout, + "fan_in_fan_out": lora_config.fan_in_fan_out, + "init_lora_weights": lora_config.init_lora_weights, + "use_rslora": lora_config.use_rslora, + "use_dora": lora_config.use_dora, + "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False), + "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False), + } + + # quant_methods = ["gptq", "aqlm", "awq"] + # for quant_method in quant_methods: + # quantization_config = get_quantization_config(self.model, method=quant_method) + # if quantization_config is not None: + # kwargs[f"{quant_method}_quantization_config"] = quantization_config + + # note: AdaLoraLayer is a subclass of LoraLayer, we need to exclude it + from ..adalora import AdaLoraLayer + + if isinstance(target, LoraLayer) and not isinstance(target, AdaLoraLayer): + target.update_layer( + adapter_name, + r, + lora_alpha=alpha, + lora_dropout=lora_config.lora_dropout, + init_lora_weights=lora_config.init_lora_weights, + use_rslora=lora_config.use_rslora, + use_dora=lora_config.use_dora, + ) + else: + new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs) + if adapter_name not in self.active_adapters: + # adding an additional adapter: it is not automatically trainable + new_module.requires_grad = False + self._replace_module(parent, target_name, new_module, target) + + def _replace_module(self, parent, child_name, new_module, child): + setattr(parent, child_name, new_module) + # It's not necessary to set requires_grad here, as that is handled by + # _mark_only_adapters_as_trainable + + # child layer wraps the original module, unpack it + if hasattr(child, "base_layer"): + child = child.base_layer + + if not hasattr(new_module, "base_layer"): + new_module.weight = child.weight + if hasattr(child, "bias"): + new_module.bias = child.bias + + if getattr(child, "state", None) is not None: + if hasattr(new_module, "base_layer"): + new_module.base_layer.state = child.state + else: + new_module.state = child.state + new_module.to(child.weight.device) + + # dispatch to correct device + for name, module in new_module.cells_and_names(): + if (self.prefix in name) or ("ranknum" in name): + weight = ( + child.qweight + if hasattr(child, "qweight") + else child.W_q + if hasattr(child, "W_q") + else child.weight + ) + module.to(weight.device) + + def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None: + for n, p in model.parameters_and_names(): + if self.prefix not in n: + p.requires_grad = False + + for active_adapter in self.active_adapters: + bias = self.peft_config[active_adapter].bias + if bias == "none": + continue + + if bias == "all": + for n, p in model.parameters_and_names(): + if "bias" in n: + p.requires_grad = True + 
elif bias == "lora_only": + for m in model.modules(): + if isinstance(m, LoraLayer) and hasattr(m, "bias") and m.bias is not None: + m.bias.requires_grad = True + else: + raise NotImplementedError(f"Requested bias: {bias}, is not implemented.") + + @staticmethod + def _create_new_module(lora_config, adapter_name, target, **kwargs): + # Collect dispatcher functions to decide what backend to use for the replaced LoRA layer. The order matters, + # because the first match is always used. Therefore, the default layers should be checked last. + dispatchers = [dispatch_default] + + new_module = None + for dispatcher in dispatchers: + new_module = dispatcher(target, adapter_name, lora_config=lora_config, **kwargs) + if new_module is not None: # first match wins + break + + if new_module is None: + # no module could be matched + raise ValueError( + f"Target module {target} is not supported. Currently, only the following modules are supported: " + "`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `transformers.pytorch_utils.Conv1D`." + ) + + return new_module + + + def __getattr__(self, name: str): + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self.model, name) + + def get_peft_config_as_dict(self, inference: bool = False): + config_dict = {} + for key, value in self.peft_config.items(): + config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()} + if inference: + config["inference_mode"] = True + config_dict[key] = config # pylint: disable=undefined-loop-variable + return config + + def _set_adapter_layers(self, enabled: bool = True) -> None: + for module in self.model.modules(): + if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)): + module.enable_adapters(enabled) + + def enable_adapter_layers(self) -> None: + """Enable all adapters. + + Call this if you have previously disabled all adapters and want to re-enable them. + """ + self._set_adapter_layers(enabled=True) + + def disable_adapter_layers(self) -> None: + """Disable all adapters. + + When disabling all adapters, the model output corresponds to the output of the base model. + """ + for active_adapter in self.active_adapters: + val = self.peft_config[active_adapter].bias + if val != "none": + msg = ( + f"Careful, disabling adapter layers with bias configured to be '{val}' does not produce the same " + "output as the the base model would without adaption." + ) + warnings.warn(msg) + self._set_adapter_layers(enabled=False) + + def set_adapter(self, adapter_name: str | list[str]) -> None: + """Set the active adapter(s). + + Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is + not desired, use the following code. + + ```py + >>> for name, param in model_peft.parameters_and_names(): + ... if ...: # some check on name (ex. if 'lora' in name) + ... param.requires_grad = False + ``` + + Args: + adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated. + """ + for module in self.model.modules(): + if isinstance(module, LoraLayer): + if module.merged: + warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") + module.unmerge() + module.set_adapter(adapter_name) + self.active_adapter = adapter_name + + @contextmanager + def _enable_peft_forward_hooks(self, *args, **kwargs): + # If adapter_names is passed as an argument, we inject it into the forward arguments. 
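To make the `adapter_names` mechanism concrete: the hook registered below injects the list into each LoRA layer's forward call, and `_mixed_batch_forward` then groups rows by adapter name. A self-contained sketch of that grouping, with invented adapter names:

```py
# Rows sharing an adapter name form one sub-batch; rows marked "__base__" keep
# only the base layer's output and skip the LoRA update entirely.
adapter_names = ["default", "__base__", "default", "french"]

unique_adapters = set(adapter_names)
sub_batch_indices = {
    adapter: [i for i, name in enumerate(adapter_names) if name == adapter]
    for adapter in unique_adapters
}
print(sub_batch_indices)  # e.g. {'default': [0, 2], '__base__': [1], 'french': [3]}
```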
+ adapter_names = kwargs.pop("adapter_names", None) + if adapter_names is None: + # nothing to do + yield + return + + if self.training: + raise ValueError("Cannot pass `adapter_names` when the model is in training mode.") + + hook_handles = [] + for module in self.modules(): + if isinstance(module, LoraLayer): + pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=adapter_names) + handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True) + hook_handles.append(handle) + + yield + + for handle in hook_handles: + handle.remove() + + def _check_merge_allowed(self): + """Verify that the configuration supports merging. + + Currently gptq quantization and replicated layers do not support merging. + """ + if getattr(self.model, "quantization_method", None) == "gptq": + raise ValueError("Cannot merge LORA layers when the model is gptq quantized") + if self.peft_config.get("layer_replication"): + raise ValueError("Cannot merge LORA layers when base model layers are replicated") + + @staticmethod + def _prepare_adapter_config(peft_config, model_config): + if peft_config.target_modules is None: + if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING: + raise ValueError("Please specify `target_modules` in `peft_config`") + peft_config.target_modules = set( + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]] + ) + return peft_config + + def _unload_and_optionally_merge( + self, + merge=True, + progressbar: bool = False, + safe_merge: bool = False, + adapter_names: Optional[list[str]] = None, + ): + if merge: + self._check_merge_allowed() + + key_list = [key for key, _ in self.model.cells_and_names() if self.prefix not in key] + desc = "Unloading " + ("and merging " if merge else "") + "model" + for key in tqdm(key_list, disable=not progressbar, desc=desc): + try: + parent, target, target_name = _get_submodules(self.model, key) + except AttributeError: + continue + # with onload_layer(target): + # if hasattr(target, "base_layer"): + # if merge: + # target.merge(safe_merge=safe_merge, adapter_names=adapter_names) + # self._replace_module(parent, target_name, target.get_base_layer(), target) + # elif isinstance(target, ModulesToSaveWrapper): + # # save any additional trainable modules part of `modules_to_save` + # new_module = target.modules_to_save[target.active_adapter] + # if hasattr(new_module, "base_layer"): + # # check if the module is itself a tuner layer + # if merge: + # new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names) + # new_module = new_module.get_base_layer() + # setattr(parent, target_name, new_module) + + return self.model + + def _check_add_weighted_adapter( + self, adapters: list[str], combination_type: str, svd_rank: int | None + ) -> tuple[str, int, str]: + """ + Helper function to check if the arguments to add_weighted_adapter are valid and compatible with the underlying + model. + """ + for adapter in adapters: + if adapter not in list(self.peft_config.keys()): + raise ValueError(f"Adapter {adapter} does not exist") + + # If more than one of the adapters targets the same module with modules_to_save, raise an error, as these + # modules cannot be merged. First, find the ModulesToSaveWrapper instances in the model, then check if they + # have modules for the adapters to be merged. 
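The rank bookkeeping that `_check_add_weighted_adapter` performs below can be summarized in a few lines; this is a stand-alone sketch with invented ranks, not code from the patch:

```py
def new_rank_for(combination_type, adapters_ranks, svd_rank=None):
    # Mirrors the branching below: equal ranks for element-wise merges, summed
    # ranks for concatenation, and svd_rank (or the max) for SVD-based merges.
    if combination_type in ("linear", "ties", "dare_ties", "dare_linear", "magnitude_prune"):
        if len(set(adapters_ranks)) != 1:
            raise ValueError("all adapters must share the same r for this combination_type")
        return adapters_ranks[0]
    if combination_type == "cat":
        return sum(adapters_ranks)  # can become large when mixing many adapters
    if combination_type.endswith("svd"):
        return svd_rank or max(adapters_ranks)
    raise ValueError(f"Invalid combination_type: {combination_type}")

print(new_rank_for("linear", [8, 8, 8]))     # 8
print(new_rank_for("cat", [8, 8, 8]))        # 24
print(new_rank_for("ties_svd", [4, 8, 16]))  # 16
```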
+ modules_to_save_wrappers = [module for module in self.modules() if isinstance(module, ModulesToSaveWrapper)] + problematic_wrappers = [ + wrapper + for wrapper in modules_to_save_wrappers + if sum(adapter in wrapper.modules_to_save for adapter in adapters) > 1 + ] + if problematic_wrappers: + raise ValueError( + "Cannot add weighted adapters if they target the same module with modules_to_save, but found " + f"{len(problematic_wrappers)} such instance(s)." + ) + + # if there is only one adapter, we can only use linear merging + combination_type = "linear" if len(adapters) == 1 else combination_type + + adapters_ranks = [self.peft_config[adapter].r for adapter in adapters] + if combination_type in ("linear", "ties", "dare_ties", "dare_linear", "magnitude_prune"): + # all adapters ranks should be same, new rank is just this value + if len(set(adapters_ranks)) != 1: + raise ValueError( + "All adapters must have the same r value when using combination_type linear, ties, dare_ties or " + "dare_linear." + ) + new_rank = adapters_ranks[0] + elif combination_type == "cat": + # adapters ranks may be different, new rank is sum of all ranks + # be careful, because output adapter rank may be really big if mixing a lot of adapters + new_rank = sum(adapters_ranks) + elif combination_type.endswith("svd"): + # new rank is the max of all ranks of the adapters if not provided + new_rank = svd_rank or max(adapters_ranks) + else: + raise ValueError(f"Invalid combination_type: {combination_type}") + + target_module_types = [type(self.peft_config[adapter].target_modules) for adapter in adapters] + if not target_module_types: + raise ValueError(f"Found no adapter matching the names in {adapters}") + if len(set(target_module_types)) > 1: + raise ValueError( + "all adapter configs should follow the same target modules type. " + "Combining adapters with `target_modules` type being a mix of list/set and string is not supported." + ) + + if target_module_types[0] == str: + new_target_modules = "|".join(f"({self.peft_config[adapter].target_modules})" for adapter in adapters) + elif target_module_types[0] == set: + new_target_modules = reduce( + operator.or_, (self.peft_config[adapter].target_modules for adapter in adapters) + ) + else: + raise TypeError(f"Invalid type {target_module_types[0]} found in target_modules") + + return combination_type, new_rank, new_target_modules + + def add_weighted_adapter( + self, + adapters: list[str], + weights: list[float], + adapter_name: str, + combination_type: str = "svd", + svd_rank: int | None = None, + svd_clamp: int | None = None, + svd_full_matrices: bool = True, + density: float | None = None, + majority_sign_method: Literal["total", "frequency"] = "total", + ) -> None: + """ + This method adds a new adapter by merging the given adapters with the given weights. + + When using the `cat` combination_type you should be aware that rank of the resulting adapter will be equal to + the sum of all adapters ranks. So it's possible that the mixed adapter may become too big and result in OOM + errors. + + Args: + adapters (`list`): + List of adapter names to be merged. + weights (`list`): + List of weights for each adapter. + adapter_name (`str`): + Name of the new adapter. + combination_type (`str`): + The merging type can be one of [`svd`, `linear`, `cat`, `ties`, `ties_svd`, `dare_ties`, `dare_linear`, + `dare_ties_svd`, `dare_linear_svd`, `magnitude_prune`, `magnitude_prune_svd`]. 
When using the `cat` + combination_type, the rank of the resulting adapter is equal to the sum of all adapters ranks (the + mixed adapter may be too big and result in OOM errors). + svd_rank (`int`, *optional*): + Rank of output adapter for svd. If None provided, will use max rank of merging adapters. + svd_clamp (`float`, *optional*): + A quantile threshold for clamping SVD decomposition output. If None is provided, do not perform + clamping. Defaults to None. + svd_full_matrices (`bool`, *optional*): + Controls whether to compute the full or reduced SVD, and consequently, the shape of the returned + tensors U and Vh. Defaults to True. + density (`float`, *optional*): + Value between 0 and 1. 0 means all values are pruned and 1 means no values are pruned. Should be used + with [`ties`, `ties_svd`, `dare_ties`, `dare_linear`, `dare_ties_svd`, `dare_linear_svd`, + `magnintude_prune`, `magnitude_prune_svd`] + majority_sign_method (`str`): + The method, should be one of ["total", "frequency"], to use to get the magnitude of the sign values. + Should be used with [`ties`, `ties_svd`, `dare_ties`, `dare_ties_svd`] + """ + + if adapter_name in list(self.peft_config.keys()): + return + for adapter in adapters: + if adapter not in list(self.peft_config.keys()): + raise ValueError(f"Adapter {adapter} does not exist") + + combination_type, new_rank, new_target_modules = self._check_add_weighted_adapter( + adapters=adapters, + combination_type=combination_type, + svd_rank=svd_rank, + ) + + self.peft_config[adapter_name] = replace( + self.peft_config[adapters[0]], + r=new_rank, + lora_alpha=new_rank, + target_modules=new_target_modules, + ) + self.inject_adapter(self.model, adapter_name) + + # Do we really need that? + _freeze_adapter(self.model, adapter_name) + + key_list = [key for key, _ in self.model.cells_and_names() if self.prefix not in key] + for key in key_list: + _, target, _ = _get_submodules(self.model, key) + if isinstance(target, LoraLayer): + if adapter_name in target.lora_A: + target_lora_A = target.lora_A[adapter_name].weight + target_lora_B = target.lora_B[adapter_name].weight + elif adapter_name in target.lora_embedding_A: + target_lora_A = target.lora_embedding_A[adapter_name] + target_lora_B = target.lora_embedding_B[adapter_name] + else: + continue + + target_lora_A.data = target_lora_A.data * 0.0 + target_lora_B.data = target_lora_B.data * 0.0 + if combination_type == "cat": + loras_A, loras_B = [], [] + for adapter, weight in zip(adapters, weights): + if adapter in target.lora_A: + current_adapter_lora_A = target.lora_A[adapter].weight + current_adapter_lora_B = target.lora_B[adapter].weight + elif adapter in target.lora_embedding_A: + current_adapter_lora_A = target.lora_embedding_A[adapter] + current_adapter_lora_B = target.lora_embedding_B[adapter] + else: + continue + loras_A.append(current_adapter_lora_A.data * weight * target.scaling[adapter]) + loras_B.append(current_adapter_lora_B.data) + + if len(loras_A) == 0: + raise ValueError("No matching LoRAs found. 
Please raise an issue on GitHub.") + loras_A = ops.cat(loras_A, axis=0) + loras_B = ops.cat(loras_B, axis=1) + target_lora_A.data[: loras_A.shape[0], :] = loras_A + target_lora_B.data[:, : loras_B.shape[1]] = loras_B + elif combination_type in [ + "svd", + "ties_svd", + "dare_linear_svd", + "dare_ties_svd", + "magnitude_prune_svd", + ]: + target_lora_A.data, target_lora_B.data = self._svd_generalized_task_arithmetic_weighted_adapter( + combination_type, + adapters, + weights, + new_rank, + target, + target_lora_A, + target_lora_B, + density, + majority_sign_method, + svd_clamp, + full_matrices=svd_full_matrices, + ) + elif combination_type in ["linear", "ties", "dare_linear", "dare_ties", "magnitude_prune"]: + target_lora_A.data, target_lora_B.data = self._generalized_task_arithmetic_weighted_adapter( + combination_type, adapters, weights, target, density, majority_sign_method + ) + + def _svd_generalized_task_arithmetic_weighted_adapter( + self, + combination_type, + adapters, + weights, + new_rank, + target, + target_lora_A, + target_lora_B, + density, + majority_sign_method, + clamp=None, + full_matrices=True, + ): + valid_adapters = [] + valid_weights = [] + is_embedding = any(adapter in target.lora_embedding_A for adapter in adapters) + for adapter, weight in zip(adapters, weights): + if adapter in target.lora_A or adapter in target.lora_embedding_A: + valid_adapters.append(adapter) + valid_weights.append(weight * target.scaling[adapter]) + + # if no valid adapter, nothing to do + if len(valid_adapters) == 0: + raise ValueError("No matching LoRAs found. Please raise an issue on Github.") + delta_weight = [target.get_delta_weight(adapter) for adapter in valid_adapters] + valid_weights = mindspore.tensor(valid_weights) + if combination_type == "svd": + delta_weight = task_arithmetic(delta_weight, valid_weights) + elif combination_type == "ties_svd": + delta_weight = ties(delta_weight, valid_weights, density, majority_sign_method) + elif combination_type == "dare_linear_svd": + delta_weight = dare_linear(delta_weight, valid_weights, density) + elif combination_type == "dare_ties_svd": + delta_weight = dare_ties(delta_weight, valid_weights, density, majority_sign_method) + elif combination_type == "magnitude_prune_svd": + delta_weight = magnitude_prune(delta_weight, valid_weights, density) + else: + raise ValueError(f"Invalid value passed to combination type: {combination_type}") + + conv2d = isinstance(target, Conv2d) + if conv2d: + conv2d_1x1 = target.weight.shape[2:4] == (1, 1) + if not conv2d_1x1: + delta_weight = delta_weight.flatten(start_dim=1) + else: + delta_weight = delta_weight.squeeze() + if (hasattr(target, "fan_in_fan_out") and target.fan_in_fan_out) or is_embedding: + delta_weight = delta_weight.T + + # based on https://github.com/kohya-ss/sd-scripts/blob/main/networks/svd_merge_lora.py#L114-L131 + U, S, Vh = ops.svd(delta_weight, full_matrices=full_matrices) + U = U[:, :new_rank] + S = S[:new_rank] + U = U @ ops.diag(S) + Vh = Vh[:new_rank, :] + if clamp is not None: + dist = ops.cat([U.flatten(), Vh.flatten()]) + hi_val = ops.quantile(dist, clamp) + low_val = -hi_val + U = U.clamp(low_val, hi_val) + Vh = Vh.clamp(low_val, hi_val) + if conv2d: + U = U.reshape(target_lora_B.data.shape) + Vh = Vh.reshape(target_lora_A.data.shape) + return Vh, U + + def _generalized_task_arithmetic_weighted_adapter( + self, + combination_type, + adapters, + weights, + target, + density, + majority_sign_method, + ): + # account weights for LoRA A and B layers. 
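The square root in `math.sqrt(weight * target.scaling[adapter])` below is applied to both LoRA factors so that their product carries the full `weight * scaling` exactly once; a quick check with made-up matrices:

```py
import math
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 16))   # stand-in for a lora_A weight
B = rng.normal(size=(8, 4))    # stand-in for a lora_B weight
weight, scaling = 0.5, 2.0     # made-up adapter weight and LoRA scaling

s = math.sqrt(weight * scaling)
# Scaling each factor by sqrt(weight * scaling) scales the product B @ A by weight * scaling.
np.testing.assert_allclose((s * B) @ (s * A), (weight * scaling) * (B @ A), atol=1e-12)
```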
+ valid_weights = [] + lora_A_deltas = [] + lora_B_deltas = [] + for adapter, weight in zip(adapters, weights): + if adapter in target.lora_A: + current_adapter_lora_A = target.lora_A[adapter].weight + current_adapter_lora_B = target.lora_B[adapter].weight + elif adapter in target.lora_embedding_A: + current_adapter_lora_A = target.lora_embedding_A[adapter] + current_adapter_lora_B = target.lora_embedding_B[adapter] + else: + continue + valid_weights.append(math.sqrt(weight * target.scaling[adapter])) + lora_A_deltas.append(current_adapter_lora_A.data) + lora_B_deltas.append(current_adapter_lora_B.data) + valid_weights = mindspore.tensor(valid_weights).to(lora_A_deltas[0].device) + lora_deltas = [lora_A_deltas, lora_B_deltas] + dtype = lora_A_deltas[0].dtype + for i, task_tensors in enumerate(lora_deltas): + if combination_type == "linear": + lora_deltas[i] = task_arithmetic(task_tensors, valid_weights) + elif combination_type == "ties": + lora_deltas[i] = ties(task_tensors, valid_weights, density, majority_sign_method) + elif combination_type == "dare_linear": + lora_deltas[i] = dare_linear(task_tensors, valid_weights, density) + elif combination_type == "dare_ties": + lora_deltas[i] = dare_ties(task_tensors, valid_weights, density, majority_sign_method) + elif combination_type == "magnitude_prune": + lora_deltas[i] = magnitude_prune(task_tensors, valid_weights, density) + else: + raise ValueError("Invalid combination type") + lora_deltas = [delta.to(dtype) for delta in lora_deltas] + return lora_deltas + + def delete_adapter(self, adapter_name: str) -> None: + """ + Deletes an existing adapter. + + Args: + adapter_name (str): Name of the adapter to be deleted. + """ + if adapter_name not in list(self.peft_config.keys()): + raise ValueError(f"Adapter {adapter_name} does not exist") + del self.peft_config[adapter_name] + + key_list = [key for key, _ in self.model.cells_and_names() if self.prefix not in key] + new_adapter = None + for key in key_list: + _, target, _ = _get_submodules(self.model, key) + if isinstance(target, LoraLayer): + target.delete_adapter(adapter_name) + if new_adapter is None: + new_adapter = target.active_adapters[:] + + self.active_adapter = new_adapter or [] + + def merge_and_unload( + self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None + ) -> nn.Cell: + r""" + This method merges the LoRa layers into the base model. This is needed if someone wants to use the base model + as a standalone model. + + Args: + progressbar (`bool`): + whether to show a progressbar indicating the unload and merge process + safe_merge (`bool`): + whether to activate the safe merging check to check if there is any potential Nan in the adapter + weights + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + Example: + + ```py + >>> from transformers import AutoModelForCausalLM + >>> from peft import PeftModel + + >>> base_model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-40b") + >>> peft_model_id = "smangrul/falcon-40B-int4-peft-lora-sfttrainer-sample" + >>> model = PeftModel.from_pretrained(base_model, peft_model_id) + >>> merged_model = model.merge_and_unload() + ``` + """ + return self._unload_and_optionally_merge( + progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names + ) + + def unload(self) -> nn.Cell: + """ + Gets back the base model by removing all the lora modules without merging. 
This gives back the original base + model. + """ + return self._unload_and_optionally_merge(merge=False) diff --git a/mindnlp/peft/tuners/prompt_tuning/__init__.py b/mindnlp/peft/tuners/prompt_tuning/__init__.py new file mode 100644 index 000000000..320cd9c02 --- /dev/null +++ b/mindnlp/peft/tuners/prompt_tuning/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""prompt tuning""" +from .config import PromptTuningConfig, PromptTuningInit +from .model import PromptEmbedding + + +__all__ = ["PromptTuningConfig", "PromptEmbedding", "PromptTuningInit"] diff --git a/mindnlp/peft/tuners/prompt_tuning/config.py b/mindnlp/peft/tuners/prompt_tuning/config.py new file mode 100644 index 000000000..6aa2e4a85 --- /dev/null +++ b/mindnlp/peft/tuners/prompt_tuning/config.py @@ -0,0 +1,86 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""prompt tuning config.""" +import enum +from dataclasses import dataclass, field +from typing import Optional, Union + +from ...config import PromptLearningConfig +from ...utils import PeftType + + +class PromptTuningInit(str, enum.Enum): + TEXT = "TEXT" + RANDOM = "RANDOM" + + +@dataclass +class PromptTuningConfig(PromptLearningConfig): + """ + This is the configuration class to store the configuration of a [`PromptEmbedding`]. + + Args: + prompt_tuning_init (Union[[`PromptTuningInit`], `str`]): The initialization of the prompt embedding. + prompt_tuning_init_text (`str`, *optional*): + The text to initialize the prompt embedding. Only used if `prompt_tuning_init` is `TEXT`. + tokenizer_name_or_path (`str`, *optional*): + The name or path of the tokenizer. Only used if `prompt_tuning_init` is `TEXT`. + tokenizer_kwargs (`dict`, *optional*): + The keyword arguments to pass to `AutoTokenizer.from_pretrained`. Only used if `prompt_tuning_init` is + `TEXT`. + """ + + prompt_tuning_init: Union[PromptTuningInit, str] = field( + default=PromptTuningInit.RANDOM, + metadata={"help": "How to initialize the prompt tuning parameters"}, + ) + prompt_tuning_init_text: Optional[str] = field( + default=None, + metadata={ + "help": "The text to use for prompt tuning initialization. Only used if prompt_tuning_init is `TEXT`" + }, + ) + tokenizer_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": "The tokenizer to use for prompt tuning initialization. 
Only used if prompt_tuning_init is `TEXT`" + }, + ) + + tokenizer_kwargs: Optional[dict] = field( + default=None, + metadata={ + "help": ( + "The keyword arguments to pass to `AutoTokenizer.from_pretrained`. Only used if prompt_tuning_init is " + "`TEXT`" + ), + }, + ) + + def __post_init__(self): + self.peft_type = PeftType.PROMPT_TUNING + if (self.prompt_tuning_init == PromptTuningInit.TEXT) and not self.tokenizer_name_or_path: + raise ValueError( + f"When prompt_tuning_init='{PromptTuningInit.TEXT.value}', " + f"tokenizer_name_or_path can't be {self.tokenizer_name_or_path}." + ) + if (self.prompt_tuning_init == PromptTuningInit.TEXT) and self.prompt_tuning_init_text is None: + raise ValueError( + f"When prompt_tuning_init='{PromptTuningInit.TEXT.value}', " + f"prompt_tuning_init_text can't be {self.prompt_tuning_init_text}." + ) + if self.tokenizer_kwargs and (self.prompt_tuning_init != PromptTuningInit.TEXT): + raise ValueError( + f"tokenizer_kwargs only valid when using prompt_tuning_init='{PromptTuningInit.TEXT.value}'." + ) diff --git a/mindnlp/peft/tuners/prompt_tuning/model.py b/mindnlp/peft/tuners/prompt_tuning/model.py new file mode 100644 index 000000000..8fb59c854 --- /dev/null +++ b/mindnlp/peft/tuners/prompt_tuning/model.py @@ -0,0 +1,87 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""prompt tuning model""" +import math + +import mindspore +from mindspore import nn, Parameter +from .config import PromptTuningInit + + +class PromptEmbedding(nn.Cell): + """ + The model to encode virtual tokens into prompt embeddings. + + Args: + config ([`PromptTuningConfig`]): The configuration of the prompt embedding. + word_embeddings (`nn.Module`): The word embeddings of the base transformer model. + + **Attributes**: + - **embedding** (`nn.Embedding`) -- The embedding layer of the prompt embedding. + + Example: + + ```py + >>> from peft import PromptEmbedding, PromptTuningConfig + + >>> config = PromptTuningConfig( + ... peft_type="PROMPT_TUNING", + ... task_type="SEQ_2_SEQ_LM", + ... num_virtual_tokens=20, + ... token_dim=768, + ... num_transformer_submodules=1, + ... num_attention_heads=12, + ... num_layers=12, + ... prompt_tuning_init="TEXT", + ... prompt_tuning_init_text="Predict if sentiment of this review is positive, negative or neutral", + ... tokenizer_name_or_path="t5-base", + ... 
) + + >>> # t5_model.shared is the word embeddings of the base model + >>> prompt_embedding = PromptEmbedding(config, t5_model.shared) + ``` + + Input Shape: (`batch_size`, `total_virtual_tokens`) + + Output Shape: (`batch_size`, `total_virtual_tokens`, `token_dim`) + """ + + def __init__(self, config, word_embeddings): + super().__init__() + + total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules + self.embedding = nn.Embedding(total_virtual_tokens, config.token_dim) + if config.prompt_tuning_init == PromptTuningInit.TEXT and not config.inference_mode: + from ....transformers import AutoTokenizer + + tokenizer_kwargs = config.tokenizer_kwargs or {} + tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path, **tokenizer_kwargs) + init_text = config.prompt_tuning_init_text + init_token_ids = tokenizer(init_text)["input_ids"] + # Trim or iterate until num_text_tokens matches total_virtual_tokens + num_text_tokens = len(init_token_ids) + if num_text_tokens > total_virtual_tokens: + init_token_ids = init_token_ids[:total_virtual_tokens] + elif num_text_tokens < total_virtual_tokens: + num_reps = math.ceil(total_virtual_tokens / num_text_tokens) + init_token_ids = init_token_ids * num_reps + init_token_ids = init_token_ids[:total_virtual_tokens] + init_token_ids = mindspore.tensor(init_token_ids) + # look up the initial prompt weights from the base model's word embeddings + word_embedding_weights = word_embeddings(init_token_ids) + word_embedding_weights = word_embedding_weights.to(mindspore.float32) + self.embedding.weight = Parameter(word_embedding_weights) + + def construct(self, indices): + # Just get embeddings + prompt_embeddings = self.embedding(indices) + return prompt_embeddings diff --git a/mindnlp/peft/tuners/tuners_utils.py b/mindnlp/peft/tuners/tuners_utils.py index cc0796de6..49f83015c 100644 --- a/mindnlp/peft/tuners/tuners_utils.py +++ b/mindnlp/peft/tuners/tuners_utils.py @@ -15,19 +15,14 @@ """ BaseTuner class and BaseTunerLayer class. """ -# pylint: disable=line-too-long -# pylint: disable=missing-function-docstring -# pylint: disable=unused-argument -# pylint: disable=too-many-arguments -# pylint: disable=too-many-function-args -# pylint: disable=no-member -# pylint: disable=function-redefined from __future__ import annotations import re import logging import warnings +import copy from typing import Any, Optional, Union from abc import ABC +from contextlib import contextmanager from mindspore import nn, Tensor from ..config import PeftConfig @@ -36,6 +31,75 @@ logger = logging.getLogger(__name__) +@contextmanager +def onload_layer(layer): + r""" + A utility for modifying a module containing one or more tuners and a base layer, any of which are offloaded to the + CPU or disk. Moves a module's sub-modules to the execution device before some action is performed, after that the + base layer state dictionary is re-assigned (if that layer was offloaded to the disk) and finally the parameters are + offloaded. + + If the module has no offloaded sub-modules, this function does nothing.

+ + Args: + layer ('mindspore.nn.Cell'): + layer with tuners to be merged + """ + + offloaded_modules = [] + for name, module in layer.cells_and_names(): + if name in ["", "base_layer"]: + continue + # if hasattr(module, "_hf_hook") and isinstance(module._hf_hook, AlignDevicesHook) and module._hf_hook.offload: + # module._hf_hook.pre_forward(module) + # offloaded_modules.append(module) + + # base_layer_offload = False + # if hasattr(layer, "base_layer") and ( + # hasattr(layer.base_layer, "_hf_hook") + # and isinstance(layer.base_layer._hf_hook, AlignDevicesHook) + # and layer.base_layer._hf_hook.offload + # ): + # # check if the base layer is disk-offloaded (must contain a 'dataset' and an offload index) + # if torch.device("meta") in layer.base_layer._hf_hook.original_devices.values() and hasattr( + # layer.base_layer._hf_hook.weights_map, "dataset" + # ): + # # find the disk-offload index (maps modules to safetensors) from the `dataset` (OffloadedWeightsLoader object) + # index = layer.base_layer._hf_hook.weights_map.dataset.index + # module_name = list(dict(layer.base_layer._hf_hook.weights_map.dataset).keys())[0] # any module will do + # file_name = index[module_name]["safetensors_file"] + # base_name_arr = [] + # # get effective dir name + # for i in os.path.split(file_name): + # if "--" in i: + # base_name_arr.append(i) + # break + # base_name_arr.append(i) + # base_name = os.path.join(*base_name_arr) + # safetensors_filename = base_name + "-merged" + # layer.base_layer._hf_hook.pre_forward(layer.base_layer) + # base_layer_offload = True + + # yield + + # for module in offloaded_modules: + # module._hf_hook.post_forward(module, torch.tensor([])) + + # if base_layer_offload: + # # re-make weights map (must be on cpu to send params to the disk via memmap if disk offload) + # layer.base_layer._hf_hook.weights_map = { + # name: param.to("cpu") for name, param in named_module_tensors(layer.base_layer) + # } + # # offload weights map to disk if original device is the disk + # if torch.device("meta") in layer.base_layer._hf_hook.original_devices.values() and hasattr( + # layer.base_layer._hf_hook.weights_map, "dataset" + # ): + # # rewrite directory with merged weights + # offload_state_dict(safetensors_filename, layer.base_layer._hf_hook.weights_map) + # layer.base_layer._hf_hook.post_forward(layer.base_layer, torch.tensor([])) + + + class BaseTuner(nn.Cell): r""" A base tuner model that provides the common methods and attributes for all tuners that are injectable into a @@ -172,7 +236,7 @@ def _create_and_replace( The optional keyword arguments to pass to deal with particular cases (e.g. 8bit, 4bit quantization) """ - def _mark_only_adapters_as_trainable(self): + def _mark_only_adapters_as_trainable(self, model): r""" A helper method to mark only the adapter layers as trainable (i.e. module.requires_grad = False) This needs to be overriden for all tuner classes to match the correct key names. @@ -290,8 +354,6 @@ class BaseTunerLayer(ABC): The name of the active adapter. """ - active_adapter = None - # All names of layers that may contain adapter (trainable) weights adapter_layer_names: tuple[str] = () # All names of other parameters that may contain adapter-related parameters @@ -529,3 +591,73 @@ def check_target_module_exists(config, key: str) -> bool | re.Match[str] | None: target_module_found = layer_index in layer_indexes return target_module_found + + +def clone_module(module: nn.Cell, share_weights=False): + """Clone a module in a pytorch model. 
+ + Clones a module of a model, optionally sharing all the parameters between the original and the clone. Simplifies + reusing a module when manipulating the architecture of a model. + """ + clone = copy.deepcopy(module) + + def _share_weights(src: nn.Cell, dst: nn.Cell): + for name, param in src.parameters_and_names(expand=False): + setattr(dst, name, param) + + if share_weights: + for name, submodule in module.parameters_and_names(): + _share_weights(submodule, clone.get_submodule(name)) + + return clone + +def replicate_layers(model: nn.Module, layer_map: list[tuple[int, int]]): + """Replicate layers in a transfomer model with weight sharing. + + This function looks for a module list attribute at model[(.model)*].layers and replicates the layers in the module + list according to the layer map. For example the map `[[0, 4], [2, 5]]` will take the set of layers `[0, 1, 2, 3, + 4]` and replace them with a module list containing `[0, 1, 2, 3, 2, 3, 4]`. + """ + while hasattr(model, "model"): + model = model.model + # Some variants of the bert model nest the main model under the bert attribute. + if hasattr(model, "bert"): + model = model.bert + + model_type = None + layers: nn.ModuleList = None + if hasattr(model, "layers"): + model_type = "llama" + layers = model.layers + elif hasattr(model, "encoder") and hasattr(model.encoder, "layer"): + model_type = "bert" + layers = model.encoder.layer + elif hasattr(model, "h"): + model_type = "falcon" + layers = model.h + if not model_type or not isinstance(layers, nn.CellList): + raise ValueError( + "Could not locate the layers attribute in the model. " + "Expected Llama, Bert or Falcon compatible architectures." + ) + + new_layers = [] + for start, end in layer_map: + for i in range(start, end): + current_idx = len(new_layers) + new_layers.append(clone_module(layers[i], share_weights=True)) + # This is a hack needed to work around the layer_idx introduced in HF transformers. + for submodule in new_layers[-1].modules(): + if hasattr(submodule, "layer_idx"): + submodule.layer_idx = current_idx + layers = nn.CellList(new_layers) + if model_type == "llama": + model.layers = layers + elif model_type == "bert": + model.encoder.layer = layers + elif model_type == "falcon": + model.h = layers + else: + raise ValueError("Unexpected model type, need to handle post-processing of layers.") + if hasattr(model.config, "num_hidden_layers"): # Common to Llama, Bert, Falcon. 
+ model.config.num_hidden_layers = len(new_layers) diff --git a/mindnlp/peft/utils/__init__.py b/mindnlp/peft/utils/__init__.py index 8f4131542..e4f6c81d2 100644 --- a/mindnlp/peft/utils/__init__.py +++ b/mindnlp/peft/utils/__init__.py @@ -15,19 +15,8 @@ """utils for peft""" from .peft_types import PeftType, TaskType from .other import ( - # TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, - TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, - TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING, - TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING, - TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING, - COMMON_LAYERS_PATTERN, - CONFIG_NAME, - WEIGHTS_NAME, - # SAFETENSORS_WEIGHTS_NAME, - CLAMP_QUANTILE, _set_trainable, # add_library_to_model_card, - # bloom_model_postprocess_past_key_value, # prepare_model_for_int8_training, # prepare_model_for_kbit_training, shift_tokens_right, @@ -43,3 +32,14 @@ ) # from .hub_utils import hub_file_exists from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict, load_peft_weights +from .constants import ( + bloom_model_postprocess_past_key_value, + TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING, + CONFIG_NAME, + WEIGHTS_NAME, + SAFETENSORS_WEIGHTS_NAME, +) diff --git a/mindnlp/peft/utils/constants.py b/mindnlp/peft/utils/constants.py new file mode 100644 index 000000000..8ca8edcf9 --- /dev/null +++ b/mindnlp/peft/utils/constants.py @@ -0,0 +1,215 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
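The `layer_map` semantics documented in `replicate_layers` above can be sanity-checked with an index-only sketch (the real implementation clones the modules with shared weights rather than copying indices):

```py
# layer_map [[0, 4], [2, 5]] over layers [0, 1, 2, 3, 4] yields [0, 1, 2, 3, 2, 3, 4].
layer_map = [(0, 4), (2, 5)]
new_layer_indices = [i for start, end in layer_map for i in range(start, end)]
print(new_layer_indices)  # [0, 1, 2, 3, 2, 3, 4]
```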
+"""constants for peft""" +from mindspore import ops + +# needed for prefix-tuning of bloom model +def bloom_model_postprocess_past_key_value(past_key_values): + past_key_values = ops.cat(past_key_values) + total_layers, batch_size, num_attention_heads, num_virtual_tokens, head_dim = past_key_values.shape + keys = past_key_values[: total_layers // 2] + keys = keys.swapaxes(2, 3).reshape( + total_layers // 2, batch_size * num_attention_heads, head_dim, num_virtual_tokens + ) + values = past_key_values[total_layers // 2 :] + values = values.reshape(total_layers // 2, batch_size * num_attention_heads, num_virtual_tokens, head_dim) + + return tuple(zip(keys, values)) + + +# needed for prefix-tuning of StarCoder models +def starcoder_model_postprocess_past_key_value(past_key_values): + result = [] + for k in past_key_values: + k = k[:, :, 0] + k = k.permute([1, 2, 0, 3]) + k = k.reshape(*k.shape[:-2], -1) + result.append(k) + return tuple(result) + + +TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING = { + "bloom": bloom_model_postprocess_past_key_value, + "gpt_bigcode": starcoder_model_postprocess_past_key_value, +} + +TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING = { + "llama": ["input_layernorm", "post_attention_layernorm", "norm"], + "bloom": ["input_layernorm", "post_attention_layernorm", "ln_f"], + "llava": [ + "multi_modal_projector", + "input_layernorm", + "post_attention_layernorm", + "norm", + "embed_tokens", + "lm_head", + ], + "t5": ["layer_norm", "final_layer_norm"], + "mt5": ["layer_norm", "final_layer_norm"], + "bart": ["self_attn_layer_norm", "encoder_attn_layer_norm", "final_layer_norm"], + "gpt2": ["ln_1", "ln_2", "ln_f"], + "blip-2": ["layernorm", "LayerNorm", "final_layer_norm", "self_attn_layer_norm"], + "gptj": ["ln_1", "ln_f"], + "falcon": ["input_layernorm", "post_attention_layernorm", "ln_f"], + "mistral": ["input_layernorm", "post_attention_layernorm", "norm"], + "phi": ["input_layernorm", "final_layernorm"], + "gemma": ["input_layernorm", "post_attention_layernorm", "norm"], +} + +TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING = { + "t5": ["q", "v"], + "mt5": ["q", "v"], + "bart": ["q_proj", "v_proj"], + "gpt2": ["c_attn"], + "bloom": ["query_key_value"], + "blip-2": ["q", "v", "q_proj", "v_proj"], + "opt": ["q_proj", "v_proj"], + "gptj": ["q_proj", "v_proj"], + "gpt_neox": ["query_key_value"], + "gpt_neo": ["q_proj", "v_proj"], + "bert": ["query", "value"], + "roberta": ["query", "value"], + "xlm-roberta": ["query", "value"], + "electra": ["query", "value"], + "deberta-v2": ["query_proj", "value_proj"], + "deberta": ["in_proj"], + "layoutlm": ["query", "value"], + "llama": ["q_proj", "v_proj"], + "chatglm": ["query_key_value"], + "gpt_bigcode": ["c_attn"], + "mpt": ["Wqkv"], + "RefinedWebModel": ["query_key_value"], + "RefinedWeb": ["query_key_value"], + "falcon": ["query_key_value"], + "btlm": ["c_proj", "c_attn"], + "codegen": ["qkv_proj"], + "mistral": ["q_proj", "v_proj"], + "mixtral": ["q_proj", "v_proj"], + "stablelm": ["q_proj", "v_proj"], + "phi": ["q_proj", "v_proj", "fc1", "fc2"], + "gemma": ["q_proj", "v_proj"], +} + +TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING = { + "t5": ["k", "v", "wo"], + "mt5": ["k", "v", "wi_1"], + "gpt2": ["c_attn", "mlp.c_proj"], + "bloom": ["query_key_value", "mlp.dense_4h_to_h"], + "roberta": ["key", "value", "output.dense"], + "opt": ["q_proj", "k_proj", "fc2"], + "gptj": ["q_proj", "v_proj", "fc_out"], + "gpt_neox": ["query_key_value", "dense_4h_to_h"], + "gpt_neo": ["q_proj", "v_proj", "c_proj"], + 
"bart": ["q_proj", "v_proj", "fc2"], + "gpt_bigcode": ["c_attn", "mlp.c_proj"], + "llama": ["k_proj", "v_proj", "down_proj"], + "mistral": ["k_proj", "v_proj", "down_proj"], + "mixtral": ["k_proj", "v_proj", "w2"], + "bert": ["key", "value", "output.dense"], + "deberta-v2": ["key_proj", "value_proj", "output.dense"], + "deberta": ["in_proj", "output.dense"], + "RefinedWebModel": ["query_key_value", "dense_4h_to_h"], + "RefinedWeb": ["query_key_value", "dense_4h_to_h"], + "falcon": ["query_key_value", "dense_4h_to_h"], + "phi": ["q_proj", "v_proj", "fc2"], + "gemma": ["q_proj", "v_proj", "down_proj"], +} + +TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING = { + "t5": ["wo"], + "mt5": [], + "gpt2": ["mlp.c_proj"], + "bloom": ["mlp.dense_4h_to_h"], + "roberta": ["output.dense"], + "opt": ["fc2"], + "gptj": ["fc_out"], + "gpt_neox": ["dense_4h_to_h"], + "gpt_neo": ["c_proj"], + "bart": ["fc2"], + "gpt_bigcode": ["mlp.c_proj"], + "llama": ["down_proj"], + "mistral": ["down_proj"], + "mixtral": ["w2"], + "bert": ["output.dense"], + "deberta-v2": ["output.dense"], + "deberta": ["output.dense"], + "RefinedWeb": ["dense_4h_to_h"], + "RefinedWebModel": ["dense_4h_to_h"], + "falcon": ["dense_4h_to_h"], + "phi": ["fc2"], + "gemma": ["down_proj"], +} + +TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING = { + "t5": ["q", "k", "v", "o", "wi", "wo"], + "mt5": ["q", "k", "v", "o", "wi_0", "wi_1", "wo"], + "bart": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"], + "gpt2": ["c_attn"], + "bloom": ["query_key_value"], + "opt": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"], + "gptj": ["q_proj", "v_proj"], + "gpt_neox": ["query_key_value"], + "gpt_neo": ["q_proj", "v_proj"], + "llama": ["q_proj", "v_proj"], + "bert": ["query", "value"], + "roberta": ["query", "key", "value", "dense"], + # "xlm-roberta": ["query", "value"], + # "electra": ["query", "value"], + "deberta-v2": ["query_proj", "key_proj", "value_proj", "dense"], + "gpt_bigcode": ["c_attn"], + "deberta": ["in_proj"], + # "layoutlm": ["query", "value"], +} + +TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING = { + "t5": ["q", "v"], + "mt5": ["q", "v"], + "bart": ["q_proj", "v_proj"], + "gpt2": ["c_attn"], + "bloom": ["query_key_value"], + "blip-2": ["q", "v", "q_proj", "v_proj"], + "opt": ["q_proj", "v_proj"], + "gptj": ["q_proj", "v_proj"], + "gpt_neox": ["query_key_value"], + "gpt_neo": ["q_proj", "v_proj"], + "bert": ["query", "value"], + "roberta": ["query", "value"], + "xlm-roberta": ["query", "value"], + "electra": ["query", "value"], + "deberta-v2": ["query_proj", "value_proj"], + "deberta": ["in_proj"], + "layoutlm": ["query", "value"], + "llama": ["q_proj", "v_proj"], + "chatglm": ["query_key_value"], + "gpt_bigcode": ["c_attn"], + "mpt": ["Wqkv"], + "RefinedWebModel": ["query_key_value"], + "RefinedWeb": ["query_key_value"], + "falcon": ["query_key_value"], + # "btlm": ["c_proj", "c_attn"], # tested, does not work because of different shapes + "codegen": ["qkv_proj"], + # "mistral": ["q_proj", "v_proj"], # tested, does not work because of different shapes + # "mixtral": ["q_proj", "v_proj"], # tested, does not work because of different shapes + "stablelm": ["q_proj", "v_proj"], + # "phi": ["q_proj", "v_proj", "fc1", "fc2"], # tested, does not work because of different shapes + "phi": ["q_proj", "v_proj"], + # "gemma": ["q_proj", "v_proj"], # tested, does not work because of different shapes +} + +WEIGHTS_NAME = "adapter_model.bin" +SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors" +CONFIG_NAME = 
"adapter_config.json" +EMBEDDING_LAYER_NAMES = ["embed_tokens", "lm_head"] +INCLUDE_LINEAR_LAYERS_SHORTHAND = "all-linear" +TOKENIZER_CONFIG_NAME = "tokenizer_config.json" diff --git a/mindnlp/peft/utils/merge_utils.py b/mindnlp/peft/utils/merge_utils.py new file mode 100644 index 000000000..9f2c7c77f --- /dev/null +++ b/mindnlp/peft/utils/merge_utils.py @@ -0,0 +1,269 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""merge utils""" +import warnings +from typing import List, Literal + +import mindspore +from mindspore import ops + + +def reshape_weight_task_tensors(task_tensors, weights): + """ + Reshapes `weights` to match the shape of `task_tensors` by unsqeezing in the remaining dimenions. + + Args: + task_tensors (`mindspore.Tensor`): The tensors that will be used to reshape `weights`. + weights (`mindspore.Tensor`): The tensor to be reshaped. + + Returns: + `mindspore.Tensor`: The reshaped tensor. + """ + new_shape = weights.shape + (1,) * (task_tensors.ndim - weights.ndim) + weights = weights.view(new_shape) + return weights + + +def magnitude_based_pruning(tensor: mindspore.Tensor, density: float) -> mindspore.Tensor: + """ + Prune the smallest values of the task tensors and retain the top-k values based on the specified fraction + `density`. + + Args: + tensor (`mindspore.Tensor`):The tensor to prune. + density (`float`):The fraction of values to preserve. Should be in [0,1]. + + Returns: + `mindspore.Tensor`: The tensor with the pruned weights. + """ + mask = ops.zeros_like(tensor).reshape(-1) + k = int(density * tensor.numel()) + top_k = ops.topk(tensor.abs().reshape(-1), k=k, largest=True) + mask[top_k[1]] = 1 + return tensor * mask.reshape(tensor.shape) + + +def random_pruning(tensor: mindspore.Tensor, density: float, rescale: bool) -> mindspore.Tensor: + """ + Prune random values based on the specified fraction `density`. + + Args: + tensor (`mindspore.Tensor`):The tensor to prune. + density (`float`):The fraction of values to preserve. Should be in [0,1]. + rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor. + + Returns: + `mindspore.Tensor`: The pruned tensor. + """ + mask = ops.bernoulli(ops.full_like(input=tensor, fill_value=density)) + pruned_tensor = tensor * mask + if rescale: + ops.div(input=pruned_tensor, other=density) + return pruned_tensor + + +def prune( + tensor: mindspore.Tensor, density: float, method: Literal["magnitude", "random"], rescale: bool = False +) -> mindspore.Tensor: + """ + Prune the values of task tensors based on the `method`. + + Args: + tensor (`mindspore.Tensor`):The tensor to prune. + density (`float`):The fraction of values to preserve. Should be in [0,1]. + method (`str`):The method to use to prune. Should be one of ["magnitude", "random"]. + rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor. + + Returns: + `mindspore.Tensor`: The pruned tensor. 
+ """ + if density >= 1: + warnings.warn(f"The density {density} is greater than or equal to 1, no pruning will be performed.") + return tensor + elif density < 0: + raise ValueError(f"Density should be >= 0, got {density}") + if method == "magnitude": + return magnitude_based_pruning(tensor, density) + elif method == "random": + return random_pruning(tensor, density, rescale=rescale) + else: + raise ValueError(f"Unknown method {method}") + + +def calculate_majority_sign_mask( + tensor: mindspore.Tensor, method: Literal["total", "frequency"] = "total" +) -> mindspore.Tensor: + """ + Get the mask of the majority sign across the task tensors. Task tensors are stacked on dimension 0. + + Args: + tensor (`mindspore.Tensor`):The tensor to get the mask from. + method (`str`):The method to use to get the mask. Should be one of ["total", "frequency"]. + + Returns: + `mindspore.Tensor`: The majority sign mask. + """ + + sign = tensor.sign() + if method == "total": + sign_magnitude = tensor.sum(dim=0) + elif method == "frequency": + sign_magnitude = sign.sum(dim=0) + else: + raise RuntimeError(f'Unimplemented mask method "{method}"') + majority_sign = ops.where(sign_magnitude >= 0, 1, -1) + return sign == majority_sign + + +def disjoint_merge(task_tensors: mindspore.Tensor, majority_sign_mask: mindspore.Tensor) -> mindspore.Tensor: + """ + Merge the task tensors using disjoint merge. + + Args: + task_tensors (`mindspore.Tensor`):The task tensors to merge. + majority_sign_mask (`mindspore.Tensor`):The mask of the majority sign across the task tensors. + + Returns: + `mindspore.Tensor`: The merged tensor. + """ + mixed_task_tensors = (task_tensors * majority_sign_mask).sum(dim=0) + num_params_preserved = majority_sign_mask.sum(dim=0) + return mixed_task_tensors / ops.clamp(num_params_preserved, min=1.0) + + +def task_arithmetic(task_tensors: List[mindspore.Tensor], weights: mindspore.Tensor) -> mindspore.Tensor: + """ + Merge the task tensors using `task arithmetic`. + + Args: + task_tensors(`List[mindspore.Tensor]`):The task tensors to merge. + weights (`mindspore.Tensor`):The weights of the task tensors. + + Returns: + `mindspore.Tensor`: The merged tensor. + """ + task_tensors = ops.stack(task_tensors, axis=0) + # weighted task tensors + weights = reshape_weight_task_tensors(task_tensors, weights) + weighted_task_tensors = task_tensors * weights + mixed_task_tensors = weighted_task_tensors.sum(dim=0) + return mixed_task_tensors + + +def magnitude_prune(task_tensors: List[mindspore.Tensor], weights: mindspore.Tensor, density: float) -> mindspore.Tensor: + """ + Merge the task tensors using `task arithmetic`. + + Args: + task_tensors(`List[mindspore.Tensor]`):The task tensors to merge. + weights (`mindspore.Tensor`):The weights of the task tensors. + density (`float`): The fraction of values to preserve. Should be in [0,1]. + + Returns: + `mindspore.Tensor`: The merged tensor. + """ + # sparsify + task_tensors = [prune(tensor, density, method="magnitude") for tensor in task_tensors] + task_tensors = ops.stack(task_tensors, axis=0) + # weighted task tensors + weights = reshape_weight_task_tensors(task_tensors, weights) + weighted_task_tensors = task_tensors * weights + mixed_task_tensors = weighted_task_tensors.sum(dim=0) + return mixed_task_tensors + + +def ties( + task_tensors: List[mindspore.Tensor], + weights: mindspore.Tensor, + density: float, + majority_sign_method: Literal["total", "frequency"] = "total", +) -> mindspore.Tensor: + """ + Merge the task tensors using `ties`. 
+ + Args: + task_tensors(`List[mindspore.Tensor]`):The task tensors to merge. + weights (`mindspore.Tensor`):The weights of the task tensors. + density (`float`):The fraction of values to preserve. Should be in [0,1]. + majority_sign_method (`str`): + The method to use to get the majority sign mask. Should be one of ["total", "frequency"]. + + Returns: + `mindspore.Tensor`: The merged tensor. + """ + # sparsify + task_tensors = [prune(tensor, density, method="magnitude") for tensor in task_tensors] + task_tensors = ops.stack(task_tensors, axis=0) + # Elect Sign + majority_sign_mask = calculate_majority_sign_mask(task_tensors, method=majority_sign_method) + # weighted task tensors + weights = reshape_weight_task_tensors(task_tensors, weights) + weighted_task_tensors = task_tensors * weights + # Disjoint Merge + mixed_task_tensors = disjoint_merge(weighted_task_tensors, majority_sign_mask) + return mixed_task_tensors + + +def dare_linear(task_tensors: List[mindspore.Tensor], weights: mindspore.Tensor, density: float) -> mindspore.Tensor: + """ + Merge the task tensors using `dare linear`. + + Args: + task_tensors(`List[mindspore.Tensor]`):The task tensors to merge. + weights (`mindspore.Tensor`):The weights of the task tensors. + density (`float`):The fraction of values to preserve. Should be in [0,1]. + + Returns: + `mindspore.Tensor`: The merged tensor. + """ + # sparsify + task_tensors = [prune(tensor, density, method="random", rescale=True) for tensor in task_tensors] + task_tensors = ops.stack(task_tensors, axis=0) + # weighted task tensors + weights = reshape_weight_task_tensors(task_tensors, weights) + weighted_task_tensors = task_tensors * weights + mixed_task_tensors = weighted_task_tensors.sum(dim=0) + return mixed_task_tensors + + +def dare_ties( + task_tensors: List[mindspore.Tensor], + weights: mindspore.Tensor, + density: float, + majority_sign_method: Literal["total", "frequency"] = "total", +) -> mindspore.Tensor: + """ + Merge the task tensors using `dare ties`. + + Args: + task_tensors(`List[mindspore.Tensor]`):The task tensors to merge. + weights (`mindspore.Tensor`):The weights of the task tensors. + density (`float`):The fraction of values to preserve. Should be in [0,1]. + majority_sign_method (`str`): + The method to use to get the majority sign mask. Should be one of ["total", "frequency"]. + + Returns: + `mindspore.Tensor`: The merged tensor. 
+ """ + # sparsify + task_tensors = [prune(tensor, density, method="random", rescale=True) for tensor in task_tensors] + task_tensors = ops.stack(task_tensors, axis=0) + # Elect Sign + majority_sign_mask = calculate_majority_sign_mask(task_tensors, method=majority_sign_method) + # weighted task tensors + weights = reshape_weight_task_tensors(task_tensors, weights) + weighted_task_tensors = task_tensors * weights + # Disjoint Merge + mixed_task_tensors = disjoint_merge(weighted_task_tensors, majority_sign_mask) + return mixed_task_tensors diff --git a/mindnlp/peft/utils/other.py b/mindnlp/peft/utils/other.py index b0bb5b509..fb0f43d32 100644 --- a/mindnlp/peft/utils/other.py +++ b/mindnlp/peft/utils/other.py @@ -21,7 +21,6 @@ from mindspore.common.initializer import initializer, Normal from mindnlp._legacy.nn import Matmul -from mindnlp._legacy.abc import CellDict def _get_batch_size(input_ids: Optional[Tensor], inputs_embeds: Optional[Tensor]) -> int: """Get the batch size based on either input_ids or input_embeds @@ -45,7 +44,7 @@ class ModulesToSaveWrapper(mindspore.nn.Cell): def __init__(self, module_to_save, adapter_name): super().__init__() self.original_module = module_to_save - self.modules_to_save = CellDict() + self.modules_to_save = nn.CellDict() self.update(adapter_name) self.active_adapter = adapter_name self.disable_adapters = False @@ -127,7 +126,6 @@ def _set_trainable(model, adapter_name): parent, target, target_name = _get_submodules(model, key) if isinstance(target, ModulesToSaveWrapper): - # 判断是否是此数据类型 target.update(adapter_name) else: for _, param in target.parameters_and_names(): @@ -142,6 +140,9 @@ def _set_trainable(model, adapter_name): if isinstance(parent, nn.SequentialCell): parent.cell_list = list(parent._cells.values()) + for n, p in model.parameters_and_names(): + if n != p.name: + p.name = n def _freeze_adapter(model, adapter_name): @@ -250,113 +251,3 @@ def construct(self, x): x = self.matmul(x.view(-1, x.shape[-1]), self.weight) + self.bias x = x.view(size_out) return x - - -# Target_modules mapping (model -> qkv), which is highly related to **Mindnlp model** implementation. 
-# lora -TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING = { - "t5": ["q", "v"], - # "mt5": ["q", "v"], - "bart": ["q_proj", "v_proj"], - "gpt2": ["c_attn"], - "bloom": ["query_key_value"], - # "blip-2": ["q", "v", "q_proj", "v_proj"], - "opt": ["q_proj", "v_proj"], - # "gptj": ["q_proj", "v_proj"], - # "gpt_neox": ["query_key_value"], - "gpt_neo": ["q_proj", "v_proj"], - "bert": ["query", "value"], - "roberta": ["query", "value"], - # "xlm-roberta": ["query", "value"], - # "electra": ["query", "value"], - # "deberta-v2": ["query_proj", "value_proj"], - "deberta": ["in_proj"], - # "layoutlm": ["query", "value"], - "llama": ["q_proj", "v_proj"], - "chatglm": ["query_key_value"], - # "gpt_bigcode": ["c_attn"], - # "mpt": ["Wqkv"], - # "RefinedWebModel": ["query_key_value"], - # "RefinedWeb": ["query_key_value"], - "falcon": ["query_key_value"], - # "btlm": ["c_proj", "c_attn"], - # "codegen": ["qkv_proj"], -} - -TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING = { - "t5": ["k", "v", "wo"], - "mt5": ["k", "v", "wi_1"], - "gpt2": ["c_attn", "mlp.c_proj"], - "bloom": ["query_key_value", "mlp.dense_4h_to_h"], - "roberta": ["key", "value", "output.dense"], - "opt": ["q_proj", "k_proj", "fc2"], - "gptj": ["q_proj", "v_proj", "fc_out"], - "gpt_neox": ["query_key_value", "dense_4h_to_h"], - "gpt_neo": ["q_proj", "v_proj", "c_proj"], - "bart": ["q_proj", "v_proj", "fc2"], - "gpt_bigcode": ["c_attn", "mlp.c_proj"], - "llama": ["k_proj", "v_proj", "down_proj"], - "mistral": ["k_proj", "v_proj", "down_proj"], - "mixtral": ["k_proj", "v_proj", "w2"], - "bert": ["key", "value", "output.dense"], - "deberta-v2": ["key_proj", "value_proj", "output.dense"], - "deberta": ["in_proj", "output.dense"], - "RefinedWebModel": ["query_key_value", "dense_4h_to_h"], - "RefinedWeb": ["query_key_value", "dense_4h_to_h"], - "falcon": ["query_key_value", "dense_4h_to_h"], - "phi": ["q_proj", "v_proj", "fc2"], - "gemma": ["q_proj", "v_proj", "down_proj"], -} - -TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING = { - "t5": ["wo"], - "mt5": [], - "gpt2": ["mlp.c_proj"], - "bloom": ["mlp.dense_4h_to_h"], - "roberta": ["output.dense"], - "opt": ["fc2"], - "gptj": ["fc_out"], - "gpt_neox": ["dense_4h_to_h"], - "gpt_neo": ["c_proj"], - "bart": ["fc2"], - "gpt_bigcode": ["mlp.c_proj"], - "llama": ["down_proj"], - "mistral": ["down_proj"], - "mixtral": ["w2"], - "bert": ["output.dense"], - "deberta-v2": ["output.dense"], - "deberta": ["output.dense"], - "RefinedWeb": ["dense_4h_to_h"], - "RefinedWebModel": ["dense_4h_to_h"], - "falcon": ["dense_4h_to_h"], - "phi": ["fc2"], - "gemma": ["down_proj"], -} -COMMON_LAYERS_PATTERN = ["layers", "h", "block", "blocks", "layer"] -TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING = { - "t5": ["q", "k", "v", "o", "wi", "wo"], - # "mt5": ["q", "k", "v", "o", "wi_0", "wi_1", "wo"], - "bart": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"], - "gpt2": ["c_attn"], - "bloom": ["query_key_value"], - "opt": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"], - # "gptj": ["q_proj", "v_proj"], - # "gpt_neox": ["query_key_value"], - "gpt_neo": ["q_proj", "v_proj"], - "llama": ["q_proj", "v_proj"], - "bert": ["query", "value"], - "roberta": ["query", "key", "value", "dense"], - # "xlm-roberta": ["query", "value"], - # "electra": ["query", "value"], - "deberta-v2": ["query_proj", "key_proj", "value_proj", "dense"], - "gpt_bigcode": ["c_attn"], - "deberta": ["in_proj"], - # "layoutlm": ["query", "value"], - "falcon": ["query_key_value"], -} - - -WEIGHTS_NAME = 
"adapter_model.ckpt" -CONFIG_NAME = "adapter_config.json" - -CLAMP_QUANTILE = 0.99 diff --git a/mindnlp/peft/utils/peft_types.py b/mindnlp/peft/utils/peft_types.py index 87119ae9b..e5ae73f5f 100644 --- a/mindnlp/peft/utils/peft_types.py +++ b/mindnlp/peft/utils/peft_types.py @@ -19,16 +19,40 @@ class PeftType(str, enum.Enum): """ - PeftType. + Enum class for the different types of adapters in PEFT. + + Supported PEFT types: + - PROMPT_TUNING + - MULTITASK_PROMPT_TUNING + - P_TUNING + - PREFIX_TUNING + - LORA + - ADALORA + - BOFT + - ADAPTION_PROMPT + - IA3 + - LOHA + - LOKR + - OFT + - POLY + - LN_TUNING """ - # PROMPT_TUNING = "PROMPT_TUNING" - # P_TUNING = "P_TUNING" - # PREFIX_TUNING = "PREFIX_TUNING" + + PROMPT_TUNING = "PROMPT_TUNING" + MULTITASK_PROMPT_TUNING = "MULTITASK_PROMPT_TUNING" + P_TUNING = "P_TUNING" + PREFIX_TUNING = "PREFIX_TUNING" LORA = "LORA" ADALORA = "ADALORA" + BOFT = "BOFT" ADAPTION_PROMPT = "ADAPTION_PROMPT" IA3 = "IA3" - LOKR= "LOKR" + LOHA = "LOHA" + LOKR = "LOKR" + OFT = "OFT" + POLY = "POLY" + LN_TUNING = "LN_TUNING" + VERA = "VERA" class TaskType(str, enum.Enum): diff --git a/mindnlp/peft/utils/save_and_load.py b/mindnlp/peft/utils/save_and_load.py index 822f83f42..1dc679f6a 100644 --- a/mindnlp/peft/utils/save_and_load.py +++ b/mindnlp/peft/utils/save_and_load.py @@ -19,7 +19,7 @@ import mindspore from .peft_types import PeftType -from .other import WEIGHTS_NAME +from .constants import WEIGHTS_NAME def get_data_list(model: mindspore.nn.Cell): """Get state dict of the Peft model for saving.""" diff --git a/mindnlp/transformers/models/auto/modeling_auto.py b/mindnlp/transformers/models/auto/modeling_auto.py index 0cb6272f7..94ba91ea9 100644 --- a/mindnlp/transformers/models/auto/modeling_auto.py +++ b/mindnlp/transformers/models/auto/modeling_auto.py @@ -398,9 +398,10 @@ ("qwen2", "Qwen2ForSequenceClassification"), ("qwen2_moe", "Qwen2MoeForSequenceClassification"), ("reformer", "ReformerForSequenceClassification"), + ("roberta", "RobertaForSequenceClassification"), ("starcoder2", "Starcoder2ForSequenceClassification"), ("xlm-roberta", "XLMRobertaForSequenceClassification"), - ("xlnet", "XLNetForSequenceClassification"), + ("xlnet", "XLNetForSequenceClassification"), ] ) diff --git a/mindnlp/utils/download.py b/mindnlp/utils/download.py index e0a0132df..45d3d5eca 100644 --- a/mindnlp/utils/download.py +++ b/mindnlp/utils/download.py @@ -873,6 +873,8 @@ def build_download_url( raise ValueError('The mirror name not support, please use one of the mirror website below: ' '["huggingface", "modelscope", "wisemodel", "gitee", "aifast"]') if mirror in ('huggingface', 'gitee', 'modelscope', 'wisemodel'): + if mirror == 'modelscope' and revision == 'main': + revision = 'master' return MIRROR_MAP[mirror].format(repo_id, revision, filename) if revision is not None and revision != 'main': logger.warning(f'`revision` is not support when use "{mirror}" website. 
' diff --git a/setup.py b/setup.py index 3c2ca3ada..bb9bbbecc 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ from setuptools.command.build_py import build_py -version = '0.3.0' +version = '0.3.1' cur_dir = os.path.dirname(os.path.realpath(__file__)) pkg_dir = os.path.join(cur_dir, 'build') diff --git a/tests/ut/dataset/transforms/test_lookup.py b/tests/ut/dataset/transforms/test_lookup.py index e1ff26849..f112875a8 100644 --- a/tests/ut/dataset/transforms/test_lookup.py +++ b/tests/ut/dataset/transforms/test_lookup.py @@ -15,7 +15,7 @@ """Test Transform""" from mindspore.dataset.text import Vocab as msVocab -from mindnlp import Vocab +from mindnlp.vocab import Vocab from mindnlp.dataset.transforms import Lookup def test_lookup(): diff --git a/tests/ut/vocab/test_vocab.py b/tests/ut/vocab/test_vocab.py index ee142e4de..0054572e5 100644 --- a/tests/ut/vocab/test_vocab.py +++ b/tests/ut/vocab/test_vocab.py @@ -14,7 +14,7 @@ # ============================================================================ """Test Vocab""" -from mindnlp import Vocab +from mindnlp.vocab import Vocab def test_vocab_from_list(): """test vocab from list.""" diff --git a/tests/ut/vocab/test_vocab_fasttext.py b/tests/ut/vocab/test_vocab_fasttext.py index 7490a3eb1..0d4f88cf0 100644 --- a/tests/ut/vocab/test_vocab_fasttext.py +++ b/tests/ut/vocab/test_vocab_fasttext.py @@ -14,7 +14,7 @@ # ============================================================================ """Test Fasttext Vocab""" import pytest -from mindnlp import Vocab +from mindnlp.vocab import Vocab @pytest.mark.download diff --git a/tests/ut/vocab/test_vocab_glove.py b/tests/ut/vocab/test_vocab_glove.py index ac5dbbf83..d5a6eb41f 100644 --- a/tests/ut/vocab/test_vocab_glove.py +++ b/tests/ut/vocab/test_vocab_glove.py @@ -15,7 +15,7 @@ """Test Glove Vocab""" import pytest -from mindnlp import Vocab +from mindnlp.vocab import Vocab @pytest.mark.download
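# A small sketch of the expanded PeftType enum from
# mindnlp/peft/utils/peft_types.py above. PeftType subclasses str, so members
# compare equal to their string values; TaskType is re-exported alongside it.
from mindnlp.peft.utils import PeftType, TaskType

assert PeftType.PROMPT_TUNING == "PROMPT_TUNING"
assert PeftType.LORA == "LORA"
# the enum now also carries BOFT, LOHA, OFT, POLY, LN_TUNING and VERA members
print([member.value for member in PeftType])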
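# The test updates above switch Vocab to its concrete module path; only the
# import location changes, not the class itself. A one-line sketch:
from mindnlp.vocab import Vocab   # previously imported as `from mindnlp import Vocab`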