From e14b1e4de98bd2ecb6fcf74a090e9c36247eff96 Mon Sep 17 00:00:00 2001 From: lvyufeng Date: Tue, 24 Dec 2024 20:36:23 +0800 Subject: [PATCH] fix ia3 --- llm/peft/ia3/sequence_classification.ipynb | 346 ++------------------- mindnlp/core/ops/random.py | 2 + mindnlp/peft/tuners/ia3/layer.py | 294 ++++------------- 3 files changed, 85 insertions(+), 557 deletions(-) diff --git a/llm/peft/ia3/sequence_classification.ipynb b/llm/peft/ia3/sequence_classification.ipynb index 28551bb78..f0d2beabf 100644 --- a/llm/peft/ia3/sequence_classification.ipynb +++ b/llm/peft/ia3/sequence_classification.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "a9935ae2", "metadata": { "ExecuteTime": { @@ -11,34 +11,8 @@ }, "id": "a9935ae2" }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/numpy/core/getlimits.py:518: UserWarning: The value of the smallest subnormal for type is zero.\n", - " setattr(self, word, getattr(machar, word).flat[0])\n", - "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for type is zero.\n", - " return self._float_to_str(self.smallest_subnormal)\n", - "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/numpy/core/getlimits.py:518: UserWarning: The value of the smallest subnormal for type is zero.\n", - " setattr(self, word, getattr(machar, word).flat[0])\n", - "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for type is zero.\n", - " return self._float_to_str(self.smallest_subnormal)\n", - "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "Building prefix dict from the default dictionary ...\n", - "Loading model from cache /tmp/jieba.cache\n", - "Loading model cost 1.306 seconds.\n", - "Prefix dict has been built successfully.\n" - ] - } - ], + "outputs": [], "source": [ - "import argparse\n", - "import os\n", - "if \"RANK_TABLE_FILE\" in os.environ:\n", - " del os.environ[\"RANK_TABLE_FILE\"]\n", - "os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\n", "from tqdm import tqdm\n", "import mindspore\n", "from mindnlp.core.optim import AdamW\n", @@ -53,7 +27,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "1cbf9c55", "metadata": {}, "outputs": [], @@ -67,7 +41,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "0526f571", "metadata": { "ExecuteTime": { @@ -78,14 +52,13 @@ }, "outputs": [], "source": [ - "# peft_config = LoraConfig(task_type=\"SEQ_CLS\", inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.1)\n", "peft_config = peft.IA3Config(task_type=\"SEQ_CLS\", inference_mode=False)\n", "lr = 1e-3" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "4b78935c5491c50e", "metadata": { "ExecuteTime": { @@ -93,16 +66,7 @@ "start_time": "2024-09-09T13:57:05.069696Z" } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/ma-user/work/mindnlp/mindnlp/transformers/tokenization_utils_base.py:1526: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted, and will be then set to `False` by default. \n", - " warnings.warn(\n" - ] - } - ], + "outputs": [], "source": [ "if any(k in model_name_or_path for k in (\"gpt\", \"opt\", \"bloom\")):\n", " padding_side = \"left\"\n", @@ -116,7 +80,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "f9f8b73ed0af54cd", "metadata": { "ExecuteTime": { @@ -124,15 +88,7 @@ "start_time": "2024-09-09T13:57:06.606526Z" } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'sentence1': Tensor(shape=[], dtype=String, value= 'Amrozi accused his brother , whom he called \" the witness \" , of deliberately distorting his evidence .'), 'sentence2': Tensor(shape=[], dtype=String, value= 'Referring to him as only \" the witness \" , Amrozi accused his brother of deliberately distorting his evidence .'), 'label': Tensor(shape=[], dtype=Int64, value= 1), 'idx': Tensor(shape=[], dtype=Int64, value= 0)}\n" - ] - } - ], + "outputs": [], "source": [ "datasets = load_dataset(\"glue\", task)\n", "print(next(datasets['train'].create_dict_iterator()))" @@ -140,7 +96,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "3077837a34694587", "metadata": { "ExecuteTime": { @@ -174,7 +130,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "a6b2c5aa959ef429", "metadata": { "ExecuteTime": { @@ -182,36 +138,14 @@ "start_time": "2024-09-09T13:57:23.773086Z" } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'input_ids': Tensor(shape=[8, 67], dtype=Int64, value=\n", - "[[ 0, 10127, 1001 ... 1, 1, 1],\n", - " [ 0, 975, 26802 ... 1, 1, 1],\n", - " [ 0, 1213, 56 ... 1, 1, 1],\n", - " ...\n", - " [ 0, 9064, 32497 ... 1, 1, 1],\n", - " [ 0, 133, 4417 ... 1, 1, 1],\n", - " [ 0, 133, 19888 ... 1, 1, 1]]), 'attention_mask': Tensor(shape=[8, 67], dtype=Int64, value=\n", - "[[1, 1, 1 ... 0, 0, 0],\n", - " [1, 1, 1 ... 0, 0, 0],\n", - " [1, 1, 1 ... 0, 0, 0],\n", - " ...\n", - " [1, 1, 1 ... 0, 0, 0],\n", - " [1, 1, 1 ... 0, 0, 0],\n", - " [1, 1, 1 ... 0, 0, 0]]), 'labels': Tensor(shape=[8], dtype=Int64, value= [1, 0, 1, 0, 1, 1, 0, 1])}\n" - ] - } - ], + "outputs": [], "source": [ "print(next(train_dataset.create_dict_iterator()))" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "c2697d07", "metadata": { "ExecuteTime": { @@ -438,13 +372,12 @@ }, "outputs": [], "source": [ - "\n", - "metric = evaluate.load(\"glue\", task)\n" + "metric = evaluate.load(\"glue\", task)" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "2ed5ac74", "metadata": { "ExecuteTime": { @@ -471,30 +404,7 @@ "id": "2ed5ac74", "outputId": "18ea15ac-ed8d-4d80-b166-706681ee49ab" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[MS_ALLOC_CONF]Runtime config: enable_vmm:True vmm_align_size:2MB\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n", - "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "trainable params: 1,223,682 || all params: 356,585,476 || trainable%: 0.34316652874555104\n" - ] - } - ], + "outputs": [], "source": [ "model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True)\n", "model = peft.get_peft_model(model, peft_config)\n", @@ -503,7 +413,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "0d2d0381", "metadata": { "ExecuteTime": { @@ -514,9 +424,7 @@ }, "outputs": [], "source": [ - "\n", - "optimizer = AdamW(params=model.parameters(), lr=lr)\n", - "\n", + "optimizer = AdamW(params=model.trainable_params(), lr=lr)\n", "# Instantiate scheduler\n", "lr_scheduler = get_linear_schedule_with_warmup(\n", " optimizer=optimizer,\n", @@ -542,209 +450,7 @@ }, "outputId": "bb17c146-8acc-477d-8f9f-65b8be794abb" }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 459/459 [05:46<00:00, 1.32it/s]\n", - " 0%| | 0/51 [00:00 None: - Initializes an instance of the IA3Layer class. - - update_layer(self, adapter_name, init_ia3_weights): - Updates the IA3Layer with the specified adapter name and initializes IA3 weights. - - reset_ia3_parameters(self, adapter_name): - Resets the IA3Layer parameters for the specified adapter name. - - """ # All names of layers that may contain adapter weights adapter_layer_names = ("ia3_l",) def __init__(self, base_layer: nn.Module, is_feedforward: bool, **kwargs) -> None: - r""" - Initialize the IA3Layer class. - - Args: - self: The instance of the IA3Layer class. - base_layer (nn.Module): The base layer used in the IA3Layer. - This parameter specifies the base layer (e.g., nn.Linear, nn.Conv2d, nn.Embedding, Conv1D) to be used in the IA3Layer. - is_feedforward (bool): A boolean flag indicating whether the IA3Layer is feedforward or not. - Set to True if the IA3Layer is feedforward, False otherwise. - - Returns: - None: This method does not return any value. - - Raises: - ValueError: If the provided base_layer is not supported or of an unsupported type. - """ self.base_layer = base_layer - self.ia3_l = ParameterDict({}) + self.ia3_l = nn.ParameterDict({}) # Mark the weight as unmerged self._disable_adapters = False self.merged_adapters = [] self.is_feedforward = is_feedforward + base_layer = self.get_base_layer() if isinstance(base_layer, nn.Linear): in_features, out_features = base_layer.in_features, base_layer.out_features - elif isinstance(base_layer, nn.Conv2d): + elif isinstance(base_layer, (nn.Conv2d, nn.Conv3d)): in_features, out_features = base_layer.in_channels, base_layer.out_channels elif isinstance(base_layer, nn.Embedding): in_features, out_features = base_layer.num_embeddings, base_layer.embedding_dim @@ -102,94 +53,24 @@ def __init__(self, base_layer: nn.Module, is_feedforward: bool, **kwargs) -> Non self.out_features = out_features def update_layer(self, adapter_name, init_ia3_weights): - r""" - Updates the IA3 layer with the given adapter name and initializes its weights if specified. - - Args: - self (IA3Layer): The IA3Layer instance. - adapter_name (str): The name of the adapter to update. - init_ia3_weights (bool): Flag indicating whether to initialize the IA3 weights. - - Returns: - None - - Raises: - None - """ # This code works for linear layers, override for other layer types # Actual trainable parameters if self.is_feedforward: weight = ops.randn((1, self.in_features)) else: weight = ops.randn((self.out_features, 1)) - self.ia3_l[adapter_name] = Parameter(weight) + self.ia3_l[adapter_name] = nn.Parameter(weight) if init_ia3_weights: self.reset_ia3_parameters(adapter_name) self.set_adapter(self.active_adapters) def reset_ia3_parameters(self, adapter_name): - r""" - Resets the IA3 parameters for a given adapter in the IA3Layer. - - Args: - self: The instance of the IA3Layer class. - adapter_name (str): The name of the adapter whose parameters need to be reset. - - Returns: - None. This method does not return any value. - - Raises: - None. - - This method resets the IA3 parameters for the specified adapter by setting its data to a constant value of 1.0 using the initializer function. The adapter_name parameter is used to identify the adapter -in the ia3_l dictionary. If the adapter_name is not found in the dictionary, no action is taken. - """ if adapter_name in self.ia3_l.keys(): - # initialize learned vector with torch.ones - self.ia3_l[adapter_name].assign_value(initializer( - Constant(1.0), - self.ia3_l[adapter_name].shape, - self.ia3_l[adapter_name].dtype - )) + # initialize learned vector with ops.ones + nn.init.constant_(self.ia3_l[adapter_name], 1.0) class Linear(nn.Module, IA3Layer): - - r""" - The `Linear` class represents a linear layer that inherits from `nn.Module` and `IA3Layer`. - - Summary: - This class implements a linear layer that can merge and unmerge adapter weights into the base weights. - - Attributes: - - `base_layer`: An instance of `nn.Module` representing the base layer. - - `adapter_name`: A string specifying the active adapter name. - - `fan_in_fan_out`: A boolean indicating whether to transpose the adapter weights. - - `is_feedforward`: A boolean indicating whether the layer is feedforward. - - `is_target_conv_1d_layer`: A boolean indicating whether the layer is a target convolutional 1D layer. - - `init_ia3_weights`: A boolean indicating whether to initialize IA3 weights. - - `merged_adapters`: A list of merged adapter names. - - Methods: - - `__init__(self, base_layer: nn.Module, adapter_name: str, fan_in_fan_out: bool = False, is_feedforward: bool = False, is_target_conv_1d_layer: bool = False, init_ia3_weights: bool = True, **kwargs) -> -None`: - Initializes a `Linear` instance with the given parameters. - - - `merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None`: - Merges the active adapter weights into the base weights. - - - `unmerge(self) -> None`: - Unmerges all merged adapter layers from the base weights. - - - `forward(self, x: Tensor, *args, **kwargs) -> Tensor`: - Constructs the linear layer with the given input tensor. - - Note: - - The `merge` method merges the active adapter weights into the base weights, allowing for adaptation. - - The `unmerge` method unmerges all merged adapter layers from the base weights. - - The `forward` method forwards the linear layer, taking into account adapter weights if applicable. - - """ # (IA)^3 implemented in a dense layer def __init__( self, @@ -197,29 +78,10 @@ def __init__( adapter_name: str, fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) is_feedforward: bool = False, # Set to True if the layer is treated as a feedforward layer - is_target_conv_1d_layer: bool = False, # whether target cell is a conv1d layer. useful while unloading later + is_target_conv_1d_layer: bool = False, # whether target module is a conv1d layer. useful while unloading later init_ia3_weights: bool = True, # whether to initialize IA3 weights **kwargs, ) -> None: - r""" - Initializes a Linear object. - - Args: - self: The instance of the Linear class. - base_layer (nn.Module): The base layer to be used for the Linear layer. - adapter_name (str): The name of the adapter. - fan_in_fan_out (bool): A flag indicating whether to use fan-in/fan-out weights. - is_feedforward (bool): A flag indicating whether the layer is feedforward. - is_target_conv_1d_layer (bool): A flag indicating whether the layer is a 1D convolutional layer. - init_ia3_weights (bool): A flag indicating whether to initialize IA3 weights. - **kwargs: Additional keyword arguments. - - Returns: - None. This method does not return any value. - - Raises: - None. - """ super().__init__() IA3Layer.__init__(self, base_layer, is_feedforward=is_feedforward) self.fan_in_fan_out = fan_in_fan_out @@ -249,6 +111,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N if active_adapter in self.ia3_l.keys(): base_layer = self.get_base_layer() ia3_l = transpose(self.ia3_l[active_adapter].data, self.fan_in_fan_out) + orig_dtype = base_layer.weight.data.dtype if safe_merge: orig_weights = base_layer.weight.data orig_weights = ops.mul(orig_weights, ia3_l) @@ -257,13 +120,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N raise ValueError( f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" ) - base_layer.weight.data = orig_weights + base_layer.weight.data = orig_weights.to(orig_dtype) else: - base_layer.weight.data = ops.mul(base_layer.weight.data, ia3_l) + base_layer.weight.data = ops.mul(base_layer.weight.data, ia3_l).to(orig_dtype) if not self.is_feedforward and (base_layer.bias is not None): scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape) - base_layer.bias.data = ops.mul(base_layer.bias.data, scaling.data) + orig_dtype = base_layer.bias.data.dtype + base_layer.bias.data = ops.mul(base_layer.bias.data, scaling.data).to(orig_dtype) self.merged_adapters.append(active_adapter) @@ -282,28 +146,15 @@ def unmerge(self) -> None: base_layer = self.get_base_layer() # Add tolerace to avoid division by zero ia3_l = transpose(self.ia3_l[active_adapter].data, self.fan_in_fan_out) + 1e-8 - base_layer.weight.data = ops.div(base_layer.weight.data, ia3_l) + orig_dtype = base_layer.weight.data.dtype + base_layer.weight.data = ops.div(base_layer.weight.data, ia3_l).to(orig_dtype) if not self.is_feedforward and (base_layer.bias is not None): scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape) - base_layer.bias.data = ops.div(base_layer.bias.data, scaling.data + 1e-8) + orig_dtype = base_layer.bias.data.dtype + base_layer.bias.data = ops.div(base_layer.bias.data, scaling.data + 1e-8).to(orig_dtype) - def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Tensor: - r""" - This method forwards a tensor using the input tensor 'x' and additional arguments and keyword arguments. It adapts the input tensor based on the configuration of the Linear class, including the use -of adapters and merging layers. - - Args: - x (Tensor): The input tensor to be processed. It should be of the type Tensor. - *args: Additional positional arguments that can be passed to the method. - **kwargs: Additional keyword arguments that can be passed to the method. - - Returns: - Tensor: The forwarded tensor based on the input 'x' and the configuration of the Linear class. - - Raises: - None: This method does not explicitly raise any exceptions. - """ + def forward(self, x: mindspore.Tensor, *args: Any, **kwargs: Any) -> mindspore.Tensor: dtype = previous_dtype = x.dtype if self.disable_adapters: if self.merged: @@ -321,32 +172,19 @@ def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Tensor: if self.is_feedforward: x = x.to(dtype) - interm = (x * ia3_scaling).to(self.get_base_layer().weight.dtype) + # TODO: weight.dtype can be != self.ia3_l[self.active_adapters].dtype + # e.g. bf16 vs fp32. Is that okay? + interm = (x * ia3_scaling).to(previous_dtype) result = self.base_layer(interm, *args, **kwargs) else: result = self.base_layer(x, *args, **kwargs) - result = result.to(dtype) * ia3_scaling + result_dtype = result.dtype + result = (result * ia3_scaling).to(result_dtype) - result = result.to(previous_dtype) return result -class Conv2d(nn.Module, IA3Layer): - - r""" - The Conv2d class represents a convolutional neural network layer with adaptive scaling capabilities for adapter layers. - This class inherits from nn.Module and IA3Layer, allowing for flexible integration with existing neural network architectures. - The class provides methods for updating, merging, and unmerging adapter layers, as well as forwarding the final output based on the input tensor. - - Methods: - - __init__: Initialize the Conv2d layer with specified parameters and adapter settings. - - update_layer: Update the adapter layer with new weights based on the provided adapter name. - - merge: Merge active adapter weights into the base weights with optional safe merge checks. - - unmerge: Unmerge all previously merged adapter layers from the base weights. - - forward: Construct the output tensor based on the input tensor, considering adapter scaling and merging configurations. - - Note: The Conv2d class is designed to enhance neural network models with adaptive scaling functionality for improved performance and flexibility. - """ +class _ConvNd(nn.Module, IA3Layer): def __init__( self, base_layer: nn.Module, @@ -356,52 +194,20 @@ def __init__( init_ia3_weights: bool = True, **kwargs, ) -> None: - r""" - Initializes a new instance of the Conv2d class. - - Args: - self (Conv2d): The current instance of the Conv2d class. - base_layer (nn.Module): The base layer for the Conv2d operation. - adapter_name (str): The name of the adapter. - fan_in_fan_out (bool, optional): Flag indicating whether to use fan-in/fan-out initialization. Defaults to False. - is_feedforward (bool, optional): Flag indicating whether the Conv2d operation is feedforward. Defaults to False. - init_ia3_weights (bool, optional): Flag indicating whether to initialize IA3 weights. Defaults to True. - **kwargs: Additional keyword arguments. - - Returns: - None - - Raises: - None - """ super().__init__() IA3Layer.__init__(self, base_layer, is_feedforward=is_feedforward) self.fan_in_fan_out = fan_in_fan_out self._active_adapter = adapter_name + self._kernel_dim = base_layer.weight.dim() self.update_layer(adapter_name, init_ia3_weights) def update_layer(self, adapter_name, init_ia3_weights): - r""" - Updates the layer of the Conv2d class with the specified adapter name and initialization of IA3 weights. - - Args: - self (Conv2d): The instance of the Conv2d class. - adapter_name (str): The name of the adapter to be updated. - init_ia3_weights (bool): Indicates whether to initialize IA3 weights or not. - - Returns: - None - - Raises: - None - """ # Actual trainable parameters - if self.is_feedforward: - weight = ops.randn((1, self.in_features, 1, 1)) - else: - weight = ops.randn((1, self.out_features, 1, 1)) - self.ia3_l[adapter_name] = Parameter(weight) + num_features = self.in_features if self.is_feedforward else self.out_features + weights_size = (1, num_features) + (1,) * (self._kernel_dim - 2) + weight = ops.randn(weights_size) + self.ia3_l[adapter_name] = nn.Parameter(weight) if init_ia3_weights: self.reset_ia3_parameters(adapter_name) self.set_adapter(self.active_adapters) @@ -429,7 +235,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N base_layer = self.get_base_layer() ia3_scaling = self.ia3_l[active_adapter].data if not self.is_feedforward: - ia3_scaling = ia3_scaling.permute(1, 0, 2, 3) + ia3_scaling = ia3_scaling.transpose(0, 1) if safe_merge: output_weight = ops.mul(base_layer.weight.data, ia3_scaling).clone() @@ -465,27 +271,14 @@ def unmerge(self) -> None: # divide by (IA)^3 vector. Add tolerace to avoid division by zero ia3_scaling = self.ia3_l[active_adapter].data if not self.is_feedforward: - ia3_scaling = ia3_scaling.permute(1, 0, 2, 3) + ia3_scaling = ia3_scaling.transpose(0, 1) base_layer.weight.data = ops.div(base_layer.weight.data, ia3_scaling + 1e-8) if not self.is_feedforward and (base_layer.bias is not None): scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape) base_layer.bias.data = ops.mul(base_layer.bias.data, scaling.data) - def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Tensor: - r""" - Construct method for the Conv2d class. - - Args: - self: The instance of the Conv2d class. - x (Tensor): The input tensor representing the input data. It is the primary input to the forward method. - - Returns: - Tensor: The output tensor after processing the input data through the forward method. The type and shape of the tensor is dependent on the operation performed within the method. - - Raises: - N/A - """ + def forward(self, x: mindspore.Tensor, *args: Any, **kwargs: Any) -> mindspore.Tensor: dtype = previous_dtype = x.dtype if self.disable_adapters: @@ -504,6 +297,8 @@ def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Tensor: if self.is_feedforward: x = x.to(dtype) + # TODO: weight.dtype can be != self.ia3_l[self.active_adapters].dtype + # e.g. bf16 vs fp32. Is that okay? interm = (x * ia3_scaling).to(self.get_base_layer().weight.dtype) result = self.base_layer(interm, *args, **kwargs) else: @@ -512,3 +307,20 @@ def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Tensor: result = result.to(previous_dtype) return result + + +class Conv2d(_ConvNd): + # IA3 implemented in a 2D convolutional layer + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if not self._kernel_dim == 4: + raise ValueError(f"Conv2d layer kernel must have 4 dimensions, not {self._kernel_dim}") + + +class Conv3d(_ConvNd): + # IA3 implemented in a 3D convolutional layer + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if not self._kernel_dim == 5: + raise ValueError(f"Conv2d layer kernel must have 5 dimensions, not {self._kernel_dim}")