diff --git a/mindnlp/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/mindnlp/transformers/models/qwen2_moe/modeling_qwen2_moe.py
index 3cbf820e7..c3a440799 100644
--- a/mindnlp/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+++ b/mindnlp/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -17,17 +17,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""MindSpore Qwen2MoE model."""
+"""MindSpore Qwen2Moe model."""
 import math
 from typing import List, Optional, Tuple, Union
 
 import mindspore
+from mindspore import mint
 import mindnlp.core.nn.functional as F
-from mindnlp.core import nn, ops
+from mindnlp.core import nn, ops, get_default_dtype
 from mindnlp.core.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-
-from ....common.activations import ACT2FN
+from mindnlp.configs import USE_PYBOOST
+from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...modeling_outputs import (
     MoeCausalLMOutputWithPast,
@@ -35,7 +36,7 @@
     SequenceClassifierOutputWithPast,
     TokenClassifierOutput,
 )
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import *
 from ....utils import logging
 from .configuration_qwen2_moe import Qwen2MoeConfig
 
@@ -76,14 +77,14 @@
         batch_size (`mindspore.Tensor`):
             Batch size.
     """
-    if attention_mask is not None and attention_mask.ndim == 4:
+    if attention_mask is not None and attention_mask.dim() == 4:
         # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
         causal_mask = attention_mask
     else:
         causal_mask = ops.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
         if sequence_length != 1:
             causal_mask = ops.triu(causal_mask, diagonal=1)
-        causal_mask *= ops.arange(target_length) > cache_position.reshape(-1, 1)
+        causal_mask *= mint.arange(target_length) > cache_position.reshape(-1, 1)
         causal_mask = causal_mask[None, None, :, :].broadcast_to((batch_size, 1, -1, -1))
         if attention_mask is not None:
             causal_mask = causal_mask.copy()  # copy to contiguous memory for in-place edit
@@ -97,12 +98,11 @@
     return causal_mask
 
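As a reference for the masking logic above, here is a runnable NumPy sketch of how the relative (triu) and absolute (cache_position) constraints combine; it is illustrative only and not part of the patch, with made-up shapes:

    # Illustrative sketch (NumPy): assembling the 4D causal mask above.
    import numpy as np

    min_dtype = np.finfo(np.float32).min
    sequence_length, target_length = 3, 5          # 3 new tokens, 5 total cache slots
    cache_position = np.arange(2, 5)               # the new tokens sit at absolute positions 2..4

    mask = np.full((sequence_length, target_length), min_dtype, dtype=np.float32)
    mask = np.triu(mask, k=1)                      # relative causal masking within the window
    # re-allow everything up to each token's absolute cache position
    mask *= np.arange(target_length) > cache_position.reshape(-1, 1)
    # rows now read: token at cache slot 2+i may attend to slots 0..2+i
    print((mask == 0).astype(int))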
 
-# Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func
 def load_balancing_loss_func(
     gate_logits: mindspore.Tensor, num_experts: mindspore.Tensor = None, top_k=2, attention_mask: Optional[mindspore.Tensor] = None
 ) -> float:
     r"""
-    Computes auxiliary load balancing loss as in Switch Transformer.
+    Computes auxiliary load balancing loss as in Switch Transformer - implemented in MindSpore.
 
     See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
     function presented in equations (4) - (6) of the paper.
     It aims at penalizing cases where the routing between
@@ -123,22 +123,21 @@
     """
     if gate_logits is None or not isinstance(gate_logits, tuple):
         return 0
 
-    if isinstance(gate_logits, tuple):
-        concatenated_gate_logits = ops.cat(list(gate_logits), dim=0)
+    concatenated_gate_logits = mint.cat(list(gate_logits), dim=0)
 
     routing_weights = nn.functional.softmax(concatenated_gate_logits, dim=-1)
 
-    _, selected_experts = ops.topk(routing_weights, top_k, dim=-1)
+    _, selected_experts = mint.topk(routing_weights, top_k, dim=-1)
 
     expert_mask = nn.functional.one_hot(selected_experts, num_experts)
 
     if attention_mask is None:
         # Compute the percentage of tokens routed to each experts
-        tokens_per_expert = ops.mean(expert_mask.float(), dim=0)
+        tokens_per_expert = mint.mean(expert_mask.float(), dim=0)
 
         # Compute the average probability of routing to these experts
-        router_prob_per_expert = ops.mean(routing_weights, dim=0)
+        router_prob_per_expert = mint.mean(routing_weights, dim=0)
     else:
         batch_size, sequence_length = attention_mask.shape
         num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
@@ -151,7 +150,7 @@
         )
 
         # Compute the percentage of tokens routed to each experts
-        tokens_per_expert = ops.sum(expert_mask.float() * expert_attention_mask, dim=0) / ops.sum(
+        tokens_per_expert = mint.sum(expert_mask.float() * expert_attention_mask, dim=0) / mint.sum(
             expert_attention_mask, dim=0
         )
@@ -163,11 +162,11 @@
         )
 
         # Compute the average probability of routing to these experts
-        router_prob_per_expert = ops.sum(routing_weights * router_per_expert_attention_mask, dim=0) / ops.sum(
+        router_prob_per_expert = mint.sum(routing_weights * router_per_expert_attention_mask, dim=0) / mint.sum(
             router_per_expert_attention_mask, dim=0
         )
 
-    overall_loss = ops.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+    overall_loss = mint.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
     return overall_loss * num_experts
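A minimal NumPy sketch of the no-attention-mask branch of this loss, useful for sanity-checking: with E experts it is E times the sum over (top-k slot, expert) of the routed fraction times the mean router probability, and it comes out near top_k when routing is balanced. Illustrative only, not part of the patch:

    # Illustrative sketch (NumPy): the load-balancing loss above, attention_mask=None case.
    import numpy as np

    def load_balancing_loss(gate_logits, num_experts, top_k=2):
        # gate_logits: (num_tokens, num_experts), concatenated over layers
        e = np.exp(gate_logits - gate_logits.max(-1, keepdims=True))
        probs = e / e.sum(-1, keepdims=True)                      # router softmax
        selected = np.argsort(probs, axis=-1)[:, -top_k:]         # top-k expert ids per token
        expert_mask = np.eye(num_experts)[selected]               # one-hot, (tokens, top_k, experts)
        tokens_per_expert = expert_mask.mean(axis=0)              # fraction routed, per (slot, expert)
        router_prob_per_expert = probs.mean(axis=0)               # mean router prob per expert
        return num_experts * np.sum(tokens_per_expert * router_prob_per_expert[None, :])

    logits = np.random.randn(8, 4)                                # 8 tokens, 4 experts
    print(load_balancing_loss(logits, num_experts=4))             # ~= top_k when balanced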
 
 
@@ -178,21 +177,23 @@
         Qwen2MoeRMSNorm is equivalent to T5LayerNorm
         """
         super().__init__()
-        self.weight = nn.Parameter(ops.ones(hidden_size))
+        self.weight = nn.Parameter(mint.ones(hidden_size))
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
         input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(mindspore.float32)
-        variance = ops.mean(hidden_states.pow(2), -1, keepdim=True)
-        hidden_states = hidden_states * ops.rsqrt(variance + self.variance_epsilon)
+        hidden_states = hidden_states.to(self.weight.dtype)
+        if not self.training and USE_PYBOOST:
+            return F.rms_norm(hidden_states, self.weight, self.variance_epsilon).to(input_dtype)
+        variance = mint.mean(hidden_states.pow(2), -1, keepdim=True)
+        hidden_states = hidden_states * mint.rsqrt(variance + self.variance_epsilon)
         return self.weight * hidden_states.to(input_dtype)
 
     def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
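For comparison with the fused F.rms_norm fast path above, a NumPy sketch of the reference RMSNorm math (note: mean of squares, no mean-centering); illustrative only:

    # Illustrative sketch (NumPy): the slow-path RMSNorm computed above.
    import numpy as np

    def rms_norm(x, weight, eps=1e-6):
        variance = np.mean(x ** 2, axis=-1, keepdims=True)   # mean of squares only
        return weight * (x / np.sqrt(variance + eps))

    x = np.random.randn(2, 4).astype(np.float32)
    w = np.ones(4, dtype=np.float32)
    print(rms_norm(x, w))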
- """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) + cos=F.embedding(position_ids,cos).unsqueeze(unsqueeze_dim) + sin=F.embedding(position_ids,sin).unsqueeze(unsqueeze_dim) + q_embed = mint.add(mint.mul(q, cos), mint.mul(rotate_half(q), sin)) + k_embed = mint.add(mint.mul(k, cos), mint.mul(rotate_half(k), sin)) return q_embed, k_embed -# Modified from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2Moe -class Qwen2MoeMLP(nn.Module): - def __init__(self, config, intermediate_size=None): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - # Copied from transformers.models.llama.modeling_llama.repeat_kv def repeat_kv(hidden_states: mindspore.Tensor, n_rep: int) -> mindspore.Tensor: """ @@ -295,7 +256,7 @@ def repeat_kv(hidden_states: mindspore.Tensor, n_rep: int) -> mindspore.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2Attention with Qwen2->Qwen2Moe +# copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Qwen2Moe class Qwen2MoeAttention(nn.Module): """ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer @@ -308,8 +269,8 @@ def __init__(self, config: Qwen2MoeConfig, layer_idx: Optional[int] = None): self.layer_idx = layer_idx if layer_idx is None: logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " - "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " "when creating this class." 
 
 
-# Modified from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2Moe
-class Qwen2MoeMLP(nn.Module):
-    def __init__(self, config, intermediate_size=None):
-        super().__init__()
-        self.config = config
-        self.hidden_size = config.hidden_size
-        self.intermediate_size = intermediate_size
-        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-        self.act_fn = ACT2FN[config.hidden_act]
-
-    def forward(self, x):
-        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-
-
 # Copied from transformers.models.llama.modeling_llama.repeat_kv
 def repeat_kv(hidden_states: mindspore.Tensor, n_rep: int) -> mindspore.Tensor:
     """
@@ -295,7 +256,7 @@
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
-# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2Attention with Qwen2->Qwen2Moe
+# copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Qwen2Moe
 class Qwen2MoeAttention(nn.Module):
     """
     Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
@@ -308,8 +269,8 @@
         self.layer_idx = layer_idx
         if layer_idx is None:
             logger.warning_once(
-                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
-                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                 "when creating this class."
             )
@@ -339,6 +300,9 @@
             base=self.rope_theta,
         )
 
+    def _shape(self, tensor: mindspore.Tensor, seq_len: int, bsz: int):
+        return ops.transpose(tensor.view(bsz, seq_len, self.num_heads, self.head_dim), 1, 2)
+
     def forward(
         self,
         hidden_states: mindspore.Tensor,
@@ -374,27 +338,29 @@
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+        if query_states.dtype == mindspore.half:
+            attn_output = mindspore.ops.flash_attention_score(query_states, key_states, value_states, query_states.shape[1], input_layout='BNSD', scalar_value=1 / math.sqrt(self.head_dim), attn_mask=mint.narrow(attention_mask, -1, 0, key_states.shape[-2]).bool())
+        else:
+            # repeat k/v heads if n_kv_heads < n_heads
+            key_states = repeat_kv(key_states, self.num_key_value_groups)
+            value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        attn_weights = ops.matmul(query_states, ops.transpose(key_states, 2, 3)) / math.sqrt(self.head_dim)
+            attn_weights = mint.matmul(query_states, ops.transpose(key_states, 2, 3)) / math.sqrt(self.head_dim)
 
-        if attn_weights.shape != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.shape}"
-            )
+            if attn_weights.shape != (bsz, self.num_heads, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                    f" {attn_weights.shape}"
+                )
 
-        if attention_mask is not None:  # no matter the length, we just slice it
-            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-            attn_weights = attn_weights + causal_mask
+            if attention_mask is not None:  # no matter the length, we just slice it
+                causal_mask = mint.narrow(attention_mask, -1, 0, key_states.shape[-2])
+                attn_weights = attn_weights + causal_mask
 
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=mindspore.float32).to(query_states.dtype)
-        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
-        attn_output = ops.matmul(attn_weights, value_states)
+            # upcast attention to fp32
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=mindspore.float32).to(query_states.dtype)
+            attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+            attn_output = mint.matmul(attn_weights, value_states)
 
         if attn_output.shape != (bsz, self.num_heads, q_len, self.head_dim):
             raise ValueError(
@@ -412,12 +378,27 @@
         return attn_output, attn_weights, past_key_value
 
-
-QWEN2MOE_ATTENTION_CLASSES = {
+Qwen2Moe_ATTENTION_CLASSES = {
     "eager": Qwen2MoeAttention,
 }
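A NumPy sketch of the eager branch above (the fp16 branch instead dispatches to mindspore.ops.flash_attention_score): repeat the KV heads for grouped-query attention, then softmax(QK^T / sqrt(d) + mask) V. Shapes are illustrative only:

    # Illustrative sketch (NumPy): the eager grouped-query attention path.
    import numpy as np

    bsz, n_heads, n_kv_heads, q_len, kv_len, head_dim = 1, 4, 2, 3, 3, 8
    n_rep = n_heads // n_kv_heads

    q = np.random.randn(bsz, n_heads, q_len, head_dim)
    k = np.random.randn(bsz, n_kv_heads, kv_len, head_dim)
    v = np.random.randn(bsz, n_kv_heads, kv_len, head_dim)

    k = np.repeat(k, n_rep, axis=1)                       # repeat_kv
    v = np.repeat(v, n_rep, axis=1)

    mask = np.triu(np.full((q_len, kv_len), -1e9), k=1)   # additive causal mask
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim) + mask
    weights = np.exp(scores - scores.max(-1, keepdims=True))
    weights /= weights.sum(-1, keepdims=True)             # softmax in fp32, as above
    out = weights @ v
    print(out.shape)                                      # (1, 4, 3, 8)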
 
 
+# Modified from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2Moe
+class Qwen2MoeMLP(nn.Module):
+    def __init__(self, config, intermediate_size=None):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, x):
+        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
 class Qwen2MoeSparseMoeBlock(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -433,51 +414,39 @@
         self.shared_expert = Qwen2MoeMLP(config, intermediate_size=config.shared_expert_intermediate_size)
         self.shared_expert_gate = nn.Linear(config.hidden_size, 1, bias=False)
-
+        self.copied = False
+        self.act = ACT2FN[config.hidden_act]
     def forward(self, hidden_states: mindspore.Tensor) -> mindspore.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
-        # router_logits: (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states)
-
-        routing_weights = F.softmax(router_logits, dim=1, dtype=mindspore.float32)
-        routing_weights, selected_experts = ops.topk(routing_weights, self.top_k, dim=-1)
+        router_scores = mint.zeros_like(router_logits)
+        routing_weights = F.softmax(router_logits, dim=-1, dtype=mindspore.float32)
+        routing_weights, selected_experts = mint.topk(routing_weights, self.top_k, dim=-1)
         if self.norm_topk_prob:
             routing_weights /= ops.sum(routing_weights, dim=-1, keepdim=True)
         # we cast back to the input dtype
         routing_weights = routing_weights.to(hidden_states.dtype)
-
-        final_hidden_states = ops.zeros(
-            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype
-        )
-
-        # One hot encode the selected experts to create an expert mask
-        # this will be used to easily index which expert is going to be sollicitated
-        expert_mask = nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
-
-        # Loop over all available experts in the model and perform the computation on each expert
-        for expert_idx in range(self.num_experts):
-            expert_layer = self.experts[expert_idx]
-            idx, top_x = ops.nonzero(expert_mask[expert_idx], as_tuple=True)
-
-            # Index the correct hidden states and compute the expert hidden state for
-            # the current expert. We need to make sure to multiply the output hidden
-            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            if 0 not in idx.shape:
-                current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
-                current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
-
-                # However `index_add_` only support torch tensors for indexing so we'll use
-                # the `top_x` tensor here.
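The rewritten forward below replaces this per-expert loop with one batched matmul over stacked expert weights; non-selected experts are zeroed out by the scatter-filled router_scores. A NumPy sketch of that trick (SiLU activation assumed, top-k renormalization omitted; sizes made up, not part of the patch):

    # Illustrative sketch (NumPy): batched experts instead of a per-expert loop.
    import numpy as np

    tokens, hidden, inter, n_experts, top_k = 5, 8, 16, 4, 2
    x = np.random.randn(tokens, hidden)
    w1 = np.random.randn(n_experts, hidden, inter)   # stacked gate_proj weights (transposed)
    w3 = np.random.randn(n_experts, hidden, inter)   # stacked up_proj weights
    w2 = np.random.randn(n_experts, inter, hidden)   # stacked down_proj weights

    logits = np.random.randn(tokens, n_experts)
    probs = np.exp(logits) / np.exp(logits).sum(-1, keepdims=True)
    selected = np.argsort(probs, -1)[:, -top_k:]
    scores = np.zeros_like(probs)                    # scatter_add equivalent
    np.put_along_axis(scores, selected, np.take_along_axis(probs, selected, -1), -1)

    silu = lambda h: h / (1 + np.exp(-h))
    h = silu(x @ w1) * (x @ w3)                      # (n_experts, tokens, inter), broadcast matmul
    h = (h @ w2) * scores.T[:, :, None]              # weight each expert's output per token
    out = h.sum(axis=0)                              # (tokens, hidden)
    print(out.shape)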
-                final_hidden_states = final_hidden_states.index_add(0, top_x.int(), current_hidden_states.to(hidden_states.dtype))
-
+        router_scores = mint.scatter_add(router_scores, -1, selected_experts, routing_weights)
         shared_expert_output = self.shared_expert(hidden_states)
         shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output
+        if not self.copied:
+            self.w1 = nn.Parameter(mint.stack([self.experts[i].gate_proj.weight.T for i in range(self.num_experts)]), requires_grad=False)
+            self.w2 = nn.Parameter(mint.stack([self.experts[i].down_proj.weight.T for i in range(self.num_experts)]), requires_grad=False)
+            self.w3 = nn.Parameter(mint.stack([self.experts[i].up_proj.weight.T for i in range(self.num_experts)]), requires_grad=False)
+            self.copied = True
+            del self.experts
+            gc.collect()
+            self.experts = None
 
-        final_hidden_states = final_hidden_states + shared_expert_output
+        hidden_w1 = mint.matmul(hidden_states, self.w1)
+        hidden_w3 = mint.matmul(hidden_states, self.w3)
+        hidden_states = self.act(hidden_w1) * hidden_w3
+        hidden_states = mint.bmm(hidden_states, self.w2) * ops.transpose(router_scores, 0, 1).unsqueeze(-1)
+        hidden_states = mint.sum(hidden_states, dim=0)
+        hidden_states = hidden_states + shared_expert_output
+        return hidden_states.view(batch_size, sequence_length, hidden_dim), router_logits
 
-        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
-        return final_hidden_states, router_logits
 
 
 class Qwen2MoeDecoderLayer(nn.Module):
@@ -485,7 +454,7 @@
     def __init__(self, config: Qwen2MoeConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
 
-        self.self_attn = QWEN2MOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
+        self.self_attn = Qwen2Moe_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
 
         if (layer_idx not in config.mlp_only_layers) and (
             config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
@@ -514,16 +483,16 @@
             hidden_states (`mindspore.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
             attention_mask (`mindspore.Tensor`, *optional*): attention mask of size
                 `(batch, sequence_length)` where padding elements are indicated by 0.
+            past_key_value (`Tuple(mindspore.Tensor)`, *optional*): cached past key and value projection states
             output_attentions (`bool`, *optional*):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                 returned tensors for more detail.
             output_router_logits (`bool`, *optional*):
-                Whether or not to return the logits of all the routers. They are useful for computing the router loss,
-                and should not be returned during inference.
+                Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
+                should not be returned during inference.
             use_cache (`bool`, *optional*):
                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                 (see `past_key_values`).
-            past_key_value (`Tuple(mindspore.Tensor)`, *optional*): cached past key and value projection states
             cache_position (`mindspore.Tensor` of shape `(sequence_length)`, *optional*):
                 Indices depicting the position of the input sequence tokens in the sequence.
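Usage sketch for the `_add_variant` helper added further below in this patch (see the hunk after this class): the variant tag is spliced in before the file extension. Illustrative only:

    # Given the _add_variant helper from this patch:
    assert _add_variant("model.safetensors", None) == "model.safetensors"
    assert _add_variant("model.safetensors", "fp16") == "model.fp16.safetensors"
    assert _add_variant("pytorch_model.bin", "v2") == "pytorch_model.v2.bin"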
             kwargs (`dict`, *optional*):
@@ -532,7 +501,6 @@
         """
         residual = hidden_states
-
         hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
@@ -550,13 +518,11 @@
         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
-
         hidden_states = self.mlp(hidden_states)
         if isinstance(hidden_states, tuple):
             hidden_states, router_logits = hidden_states
         else:
             router_logits = None
-
         hidden_states = residual + hidden_states
 
         outputs = (hidden_states,)
@@ -572,7 +538,14 @@
         return outputs
 
 
+def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
+    if variant is not None:
+        splits = weights_name.split(".")
+        splits = splits[:-1] + [variant] + splits[-1:]
+        weights_name = ".".join(splits)
+    return weights_name
 
+# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel with Qwen2->Qwen2Moe
 class Qwen2MoePreTrainedModel(PreTrainedModel):
     config_class = Qwen2MoeConfig
     base_model_prefix = "model"
@@ -592,7 +565,245 @@ def _init_weights(self, module):
             if module.padding_idx is not None:
                 module.weight[module.padding_idx] = 0
 
+    def save_pretrained(
+        self,
+        save_directory: Union[str, os.PathLike],
+        is_main_process: bool = True,
+        state_dict: Optional[dict] = None,
+        save_function: Callable = save_checkpoint,
+        push_to_hub: bool = False,
+        max_shard_size: Union[int, str] = "5GB",
+        safe_serialization: bool = True,
+        variant: Optional[str] = None,
+        token: Optional[Union[str, bool]] = None,
+        save_peft_format: bool = True,
+        **kwargs,
+    ):
+        config = self.config
+        for layer in range(config.num_hidden_layers):
+            module = self.layers[layer].mlp if hasattr(self, 'layers') else self.model.layers[layer].mlp
+            if getattr(module, "copied", False):  # mlp_only_layers hold a plain Qwen2MoeMLP without this flag
+                module.experts = nn.ModuleList([Qwen2MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.num_experts)])
+                for i in range(config.num_experts):
+                    module.experts[i].gate_proj.weight = nn.Parameter(module.w1[i].T)
+                    module.experts[i].down_proj.weight = nn.Parameter(module.w2[i].T)
+                    module.experts[i].up_proj.weight = nn.Parameter(module.w3[i].T)
+                del module.w1, module.w2, module.w3
+                gc.collect()
+
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        ignore_metadata_errors = kwargs.pop("ignore_metadata_errors", False)
+
+        if use_auth_token is not None:
+            warnings.warn(
+                "The `use_auth_token` argument is deprecated. Please use `token` instead.",
+                FutureWarning,
+            )
+            if token is not None:
+                raise ValueError(
+                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+                )
+            token = use_auth_token
+
+        if token is not None:
+            kwargs["token"] = token
+
+        _hf_peft_config_loaded = getattr(self, "_hf_peft_config_loaded", False)
+        if "save_config" in kwargs:
+            warnings.warn(
+                "`save_config` is deprecated. Use `is_main_process` instead."
+            )
+            is_main_process = kwargs.pop("save_config")
+        if safe_serialization and not is_safetensors_available():
+            raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
+
+        if os.path.isfile(save_directory):
+            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+            return
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+            repo_id = self._create_repo(repo_id, **kwargs)
+            files_timestamps = self._get_files_timestamps(save_directory)
+
+        # Only save the model itself if we are using distributed training
+        model_to_save = unwrap_model(self)
+
+        # save the string version of dtype to the config, e.g. convert mindspore.float32 => "float32"
+        # we currently don't use this setting automatically, but may start to use with v5
+        dtype = get_parameter_dtype(model_to_save)
+        model_to_save.config.ms_dtype = str(dtype).lower()
+
+        # Attach architecture to the config
+        model_to_save.config.architectures = [model_to_save.__class__.__name__]
+
+        # Save the config
+        if is_main_process:
+            if not _hf_peft_config_loaded:
+                model_to_save.config.save_pretrained(save_directory)
+            if self.can_generate():
+                # generation config built from the model config + the model config holds generation kwargs -> generate
+                # may revert to legacy behavior if the two don't match
+                if (
+                    model_to_save.generation_config._from_model_config
+                    and model_to_save.config._has_non_default_generation_parameters()
+                ):
+                    new_generation_config = GenerationConfig.from_model_config(model_to_save.config)
+                    if new_generation_config != model_to_save.generation_config:
+                        logger.warning(
+                            "Your generation config was originally created from the model config, but the model "
+                            "config has changed since then. Unless you pass the `generation_config` argument to this "
+                            "model's `generate` calls, they will revert to the legacy behavior where the base "
+                            "`generate` parameterization is loaded from the model config instead. "
+                            "To avoid this behavior and this warning, we recommend you to overwrite the generation "
+                            "config model attribute before calling the model's `save_pretrained`, preferably also "
+                            "removing any generation kwargs from the model config. This warning will be raised to an "
+                            "exception in v4.41."
+                        )
+                model_to_save.generation_config.save_pretrained(save_directory)
+
+        if _hf_peft_config_loaded:
+            logger.info(
+                "Detected adapters on the model, saving the model in the PEFT format, only adapter weights will be saved."
+            )
+            state_dict = model_to_save.get_adapter_state_dict()
+
+            if save_peft_format:
+                logger.info(
+                    "To match the expected format of the PEFT library, all keys of the state dict of adapters will be pre-pended with `base_model.model`."
+                )
+                peft_state_dict = {}
+                for key, value in state_dict.items():
+                    peft_state_dict[f"base_model.model.{key}"] = value
+                state_dict = peft_state_dict
+
+            active_adapter = self.active_adapters()
+
+            if len(active_adapter) > 1:
+                raise ValueError(
+                    "Multiple active adapters detected, saving multiple active adapters is not supported yet. You can save adapters separately one by one "
+                    "by iteratively calling `model.set_adapter(adapter_name)` then `model.save_pretrained(...)`"
+                )
+            active_adapter = active_adapter[0]
+
+            current_peft_config = self.peft_config[active_adapter]
+            current_peft_config.save_pretrained(save_directory)
+
+        # for offloaded modules
+        module_map = {}
+
+        # Save the model
+        if state_dict is None:
+            # if any model parameters are offloaded, make module map
+            if (
+                hasattr(self, "hf_device_map")
+                and len(set(self.hf_device_map.values())) > 1
+                and ("cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values())
+            ):
+                warnings.warn(
+                    "Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory exceeds the `shard_size` (5GB default)"
+                )
+                for name, module in model_to_save.named_modules():
+                    if name == "":
+                        continue
+                    module_state_dict = module.state_dict()
+
+                    for key in module_state_dict:
+                        module_map[name + f".{key}"] = module
+            state_dict = model_to_save.state_dict()
+
+        # Handle the case where some state_dict keys shouldn't be saved
+        if self._keys_to_ignore_on_save is not None:
+            for ignore_key in self._keys_to_ignore_on_save:
+                if ignore_key in state_dict.keys():
+                    del state_dict[ignore_key]
+
+        # Shard the model if it is too big.
+        if not _hf_peft_config_loaded:
+            weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+            weights_name = _add_variant(weights_name, variant)
+        else:
+            weights_name = ADAPTER_SAFE_WEIGHTS_NAME if safe_serialization else ADAPTER_WEIGHTS_NAME
+
+        filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
+        state_dict_split = split_state_dict_into_shards(
+            state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
+        )
+        # Save index if sharded
+        index = None
+        if state_dict_split.is_sharded:
+            index = {
+                "metadata": state_dict_split.metadata,
+                "weight_map": state_dict_split.tensor_to_filename,
+            }
+
+        # Clean the folder from a previous save
+        for filename in os.listdir(save_directory):
+            full_filename = os.path.join(save_directory, filename)
+            # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
+            # in distributed settings to avoid race conditions.
+            weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
+
+            # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
+            filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
+            reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")
+
+            if (
+                filename.startswith(weights_no_suffix)
+                and os.path.isfile(full_filename)
+                and filename not in state_dict_split.filename_to_tensors.keys()
+                and is_main_process
+                and reg.fullmatch(filename_no_suffix) is not None
+            ):
+                os.remove(full_filename)
+        # Save the model
+        filename_to_tensors = state_dict_split.filename_to_tensors.items()
+        if module_map:
+            filename_to_tensors = logging.tqdm(filename_to_tensors, desc="Saving checkpoint shards")
+        for shard_file, tensors in filename_to_tensors:
+            shard = {tensor: state_dict[tensor] for tensor in tensors}
+            # remake shard with onloaded parameters if necessary
+            if module_map:
+                # init state_dict for this shard
+                shard_state_dict = {name: "" for name in shard}
+                for module_name in shard:
+                    module = module_map[module_name]
+                    # update state dict with onloaded parameters
+                    # shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict)
+
+                # assign shard to be the completed state dict
+                shard = shard_state_dict
+                del shard_state_dict
+                gc.collect()
+
+            if safe_serialization:
+                # At some point we will need to deal better with save_function (used for TPU and other distributed
+                # joyfulness), but for now this is enough.
+                safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "np"})
+            else:
+                save_function(shard, os.path.join(save_directory, shard_file))
+
+        if index is None:
+            path_to_weights = os.path.join(save_directory, weights_name)
+            logger.info(f"Model weights saved in {path_to_weights}")
+        else:
+            save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+            save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
+            # Save the index as well
+            with open(save_index_file, "w", encoding="utf-8") as f:
+                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+                f.write(content)
+            logger.info(
+                f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
+                f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameter has been saved in the "
+                f"index located at {save_index_file}."
+            )
+
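Illustrative check of the cleanup regex above, which gates deletion of stale shard files (extensions are stripped before matching); not part of the patch:

    # Illustrative sketch: the shard-name pattern used for cleanup.
    import re

    reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")
    print(reg.fullmatch("model-00001-of-00005") is not None)        # True  -> eligible for cleanup
    print(reg.fullmatch("model") is not None)                       # False -> kept
    print(reg.fullmatch("model.fp16-00001-of-00005") is not None)   # True  (variant files too)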
+# copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->Qwen2Moe,Mistral->Qwen2Moe
 class Qwen2MoeModel(Qwen2MoePreTrainedModel):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2MoeDecoderLayer`]
@@ -614,6 +825,7 @@
         self.norm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
         self.gradient_checkpointing = False
+        self.cached_position_ids = None
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -623,6 +835,7 @@
     def get_input_embeddings(self):
         return self.embed_tokens
 
     def set_input_embeddings(self, value):
         self.embed_tokens = value
 
+    # Ignore copy
     def forward(
         self,
         input_ids: mindspore.Tensor = None,
@@ -659,6 +872,18 @@
                 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
             )
             use_cache = False
+        if not past_key_values:
+            self.cached_position_ids = None
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+        if inputs_embeds.shape[1] == 1 and past_key_values:
+            position_ids = mint.narrow(self.cached_position_ids, 1, -1, 1) + 1
+        if position_ids is not None:
+            if position_ids.shape[0] != inputs_embeds.shape[0] or position_ids[0][-1] > cache_position[-1]:
+                position_ids = attention_mask.int().cumsum(-1) - 1
+                position_ids = position_ids.masked_fill(attention_mask == 0, 1)
+                if past_key_values:
+                    position_ids = mint.narrow(position_ids, 1, -input_ids.shape[1], input_ids.shape[1])
 
         use_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache) and not self.training:
@@ -668,21 +893,21 @@
                 "We detected that you are passing `past_key_values` as a tuple and this is deprecated. "
                 "Please use an appropriate `Cache` class"
             )
-
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
-
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-            cache_position = ops.arange(
+            cache_position = mint.arange(
                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1]
             )
+
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
-        )
+        if inputs_embeds.shape[1] > 1:
+            causal_mask = self._update_causal_mask(
+                attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+            )
+        else:
+            # padded slots -> dtype min, attended slots -> 0
+            causal_mask = (1 - attention_mask[:, None, None, :]) * float(ops.finfo(inputs_embeds.dtype).min)
 
         hidden_states = inputs_embeds
 
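A NumPy sketch of that single-token decode fast path: the 2D padding mask becomes an additive bias directly, 0 where attending and dtype-min where padded — which is why the (1 - mask) orientation matters (padded slots must go toward -inf, not +inf). Illustrative only, not part of the patch:

    # Illustrative sketch (NumPy): padding mask -> additive attention bias.
    import numpy as np

    attention_mask = np.array([[1, 1, 1, 0]])                     # one padded slot
    min_dtype = float(np.finfo(np.float32).min)
    bias = (1 - attention_mask[:, None, None, :]) * min_dtype     # (bsz, 1, 1, kv_len)
    print(bias[0, 0, 0])                                          # [0. 0. 0. -3.4e38]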
@@ -740,7 +965,7 @@
         next_cache = None
         if use_cache:
             next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
-
+        self.cached_position_ids = position_ids
         if not return_dict:
             return tuple(
                 v
@@ -764,7 +989,6 @@
         past_key_values: Cache,
         output_attentions: bool,
     ):
-
         # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
         # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
         # to infer the attention mask.
@@ -805,7 +1029,6 @@
         self.model = Qwen2MoeModel(config)
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
         self.router_aux_loss_coef = config.router_aux_loss_coef
         self.num_experts = config.num_experts
         self.num_experts_per_tok = config.num_experts_per_tok
@@ -830,6 +1053,7 @@
     def set_decoder(self, decoder):
         self.model = decoder
 
     def get_decoder(self):
         return self.model
 
+    # Ignore copy
     def forward(
         self,
         input_ids: mindspore.Tensor = None,
@@ -863,7 +1087,7 @@
         >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
 
         >>> prompt = "Hey, are you conscious? Can you talk to me?"
-        >>> inputs = tokenizer(prompt, return_tensors="ms")
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
 
         >>> # Generate
         >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
@@ -875,6 +1099,7 @@
         output_router_logits = (
             output_router_logits if output_router_logits is not None else self.config.output_router_logits
         )
+
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
@@ -894,7 +1119,6 @@
             return_dict=return_dict,
             cache_position=cache_position,
         )
-
         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
         logits = logits.float()
@@ -938,7 +1162,6 @@
             router_logits=outputs.router_logits,
         )
 
-    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
     def prepare_inputs_for_generation(
         self,
         input_ids,
@@ -958,15 +1181,16 @@
             if 0 not in input_ids.shape:
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
-                input_ids = input_ids[:, cache_position]
-
+                input_ids = mint.narrow(input_ids, 1, input_ids.shape[1] - cache_position.shape[0], cache_position.shape[0])
         if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.int().cumsum(-1) - 1
-            position_ids = position_ids.masked_fill(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
-
+            if past_key_values and input_ids.shape[1] == 1:
+                position_ids = None
+            else:
+                # create position_ids on the fly for batch generation
+                position_ids = attention_mask.int().cumsum(-1) - 1
+                position_ids = position_ids.masked_fill(attention_mask == 0, 1)
+                if past_key_values:
+                    position_ids = position_ids[:, -input_ids.shape[1] :]
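A NumPy sketch of the on-the-fly position_ids above — a cumulative sum over attended tokens, with padded slots pinned to 1. Illustrative only, not part of the patch:

    # Illustrative sketch (NumPy): position_ids from the padding mask.
    import numpy as np

    attention_mask = np.array([[0, 0, 1, 1, 1],      # left-padded row
                               [1, 1, 1, 1, 1]])
    position_ids = attention_mask.cumsum(-1) - 1
    position_ids[attention_mask == 0] = 1            # masked_fill(mask == 0, 1)
    print(position_ids)
    # [[1 1 0 1 2]
    #  [0 1 2 3 4]]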
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
@@ -974,7 +1198,6 @@
         else:
             # The clone here is for the same reason as for `position_ids`.
             model_inputs = {"input_ids": input_ids, "inputs_embeds": None}
-
         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
             if model_inputs["inputs_embeds"] is not None:
                 batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
@@ -993,7 +1216,6 @@
                 cache_position=cache_position,
                 batch_size=batch_size,
             )
-
         model_inputs.update(
             {
                 "position_ids": position_ids,
@@ -1006,7 +1228,7 @@
         return model_inputs
 
 
-# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Qwen2Moe, LLAMA->QWEN2MOE
+# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Qwen2Moe, LLAMA->Qwen2Moe
 class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
@@ -1070,12 +1292,12 @@
         else:
             if input_ids is not None:
                 # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
-                sequence_lengths = ops.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = mint.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                 sequence_lengths = sequence_lengths % input_ids.shape[-1]
             else:
                 sequence_lengths = -1
 
-        pooled_logits = logits[ops.arange(batch_size), sequence_lengths]
+        pooled_logits = logits[mint.arange(batch_size), sequence_lengths]
 
         loss = None
         if labels is not None:
@@ -1112,7 +1334,7 @@
         )
 
 
-# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2Moe, LLAMA->QWEN2MOE
+# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2Moe, LLAMA->Qwen2Moe
 class Qwen2MoeForTokenClassification(Qwen2MoePreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
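Illustrative NumPy sketch of the sequence-classification pooling-index trick above: the first pad position minus one is the last real token, and the modulo wraps the no-padding case (argmax of an all-zero row is 0, giving -1) back to the final index. Not part of the patch:

    # Illustrative sketch (NumPy): locating the last non-pad token per row.
    import numpy as np

    pad_token_id = 0
    input_ids = np.array([[5, 6, 7, 0, 0],     # padded: last real token at index 2
                          [5, 6, 7, 8, 9]])    # no padding: 0 - 1 = -1 wraps to 4
    sequence_lengths = (input_ids == pad_token_id).astype(np.int32).argmax(-1) - 1
    sequence_lengths = sequence_lengths % input_ids.shape[-1]
    print(sequence_lengths)                    # [2 4]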