diff --git a/src/transformers/models/mistral/modeling_flax_mistral.py b/src/transformers/models/mistral/modeling_flax_mistral.py
index 41f03a182035..c09ad6082ee0 100644
--- a/src/transformers/models/mistral/modeling_flax_mistral.py
+++ b/src/transformers/models/mistral/modeling_flax_mistral.py
@@ -19,7 +19,7 @@
 # limitations under the License.
 """ Flax Mistral model."""
 import math
-from typing import Any, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import flax.linen as nn
 import jax
@@ -29,12 +29,15 @@
 from flax.linen.initializers import ones
 from flax.traverse_util import flatten_dict, unflatten_dict
 
-from ...modeling_flax_outputs import (FlaxBaseModelOutputWithPast,
-                                      FlaxCausalLMOutputWithCrossAttentions,
-                                      FlaxSequenceClassifierOutput)
+from ...modeling_flax_outputs import (
+    FlaxBaseModelOutputWithPast,
+    FlaxCausalLMOutputWithCrossAttentions,
+    FlaxSequenceClassifierOutput,
+)
 from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, logging
 from .configuration_mistral import MistralConfig
 
+
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "MistralConfig"
@@ -506,54 +509,7 @@ def __call__(
 
         return outputs
 
 
-class FlaxMistralLayerCollection(nn.Module):
-
-    config: MistralConfig
-    dtype: jnp.dtype = jnp.float32
-
-    def setup(self):
-        self.blocks = [
-            FlaxMistralDecoderLayer(self.config, dtype=self.dtype, name=str(i)) for i in range(self.config.num_hidden_layers)
-        ]
-
-    def __call__(self,
-                 hidden_states: jnp.ndarray = None,
-                 attention_mask: Optional[jnp.ndarray] = None,
-                 position_ids: Optional[jnp.ndarray] = None,
-                 past_key_values: Optional[List[jnp.ndarray]] = None,
-                 use_cache: Optional[bool] = None,
-                 output_attentions: Optional[bool] = None,
-                 output_hidden_states: Optional[bool] = None,
-                 return_dict: Optional[bool] = None,) -> Any:
-
-        all_hidden_states = () if output_hidden_states else None
-        all_attentions = () if output_attentions else None
-        next_decoder_cache = () if use_cache else None
-
-        for idx, decoder_layer in enumerate(self.blocks):
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-            past_key_value = past_key_values[idx] if past_key_values is not None else None
-            layer_outputs = decoder_layer(
-                hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_value,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-            )
-
-            hidden_states = layer_outputs[0]
-
-            if use_cache:
-                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
-            if output_attentions:
-                all_attentions += (layer_outputs[1],)
-
-        outputs = (hidden_states, next_decoder_cache, all_hidden_states, all_attentions)
-
-        return outputs
 
 
 
@@ -572,10 +528,9 @@ def setup(self):
         self.padding_idx = self.config.pad_token_id
         self.vocab_size = self.config.vocab_size
         self.embed_tokens = nn.Embed(self.config.vocab_size, self.config.hidden_size, dtype=self.dtype)
-        self.layers = FlaxMistralLayerCollection(self.config, dtype=self.dtype)
-        # self.layers = [
-        #     FlaxMistralDecoderLayer(self.config, dtype=self.dtype) for _ in range(self.config.num_hidden_layers)
-        # ]
+        self.layers = [
+            FlaxMistralDecoderLayer(self.config, dtype=self.dtype) for _ in range(self.config.num_hidden_layers)
+        ]
         self.norm = FlaxMistralRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps, dtype=self.dtype)
 
     def __call__(
@@ -646,52 +601,38 @@ def __call__(
 
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None
-
-        outputs = self.layers(
-            hidden_states=hidden_states,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        # for idx, decoder_layer in enumerate(self.layers):
-        #     if output_hidden_states:
-        #         all_hidden_states += (hidden_states,)
-        #     past_key_value = past_key_values[idx] if past_key_values is not None else None
-        #     layer_outputs = decoder_layer(
-        #         hidden_states,
-        #         attention_mask=attention_mask,
-        #         position_ids=position_ids,
-        #         past_key_value=past_key_value,
-        #         output_attentions=output_attentions,
-        #         use_cache=use_cache,
-        #         padding_mask=padding_mask,
-        #     )
-
-        #     hidden_states = layer_outputs[0]
-
-        #     if use_cache:
-        #         next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
-
-        #     if output_attentions:
-        #         all_self_attns += (layer_outputs[1],)
-
-        hidden_states = outputs[0]
+        for idx, decoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                padding_mask=padding_mask,
+            )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
         hidden_states = self.norm(hidden_states)
 
         # add hidden states from the last decoder layer
         if output_hidden_states:
-            all_hidden_states = outputs[2] + (hidden_states,)
-            outputs = (hidden_states, all_hidden_states) + outputs[2:]
+            all_hidden_states += (hidden_states,)
 
-        next_cache = outputs[1] if use_cache else None
+        next_cache = next_decoder_cache if use_cache else None
         if not return_dict:
-            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, outputs[-1]] if v is not None)
-
+            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
         return FlaxBaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=next_cache,
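
For reference, the pattern the model module moves to in the hunks above (holding the decoder blocks in a plain Python list built in `setup()` and iterating them directly in `__call__`, instead of wrapping them in a separate `FlaxMistralLayerCollection` module) can be sketched in isolation roughly as follows. This is an illustrative toy sketch, not part of the patch: `ToyBlock`, `ToyDecoder`, and their fields are hypothetical stand-ins for `FlaxMistralDecoderLayer`, `FlaxMistralModule`, and the relevant config values.

# Minimal sketch of the "plain list of submodules" pattern in Flax.
# Not part of the patch; names here are illustrative only.
import jax
import jax.numpy as jnp
import flax.linen as nn


class ToyBlock(nn.Module):
    hidden_size: int

    @nn.compact
    def __call__(self, hidden_states):
        # A single dense layer stands in for a full decoder block.
        return nn.Dense(self.hidden_size)(hidden_states)


class ToyDecoder(nn.Module):
    hidden_size: int
    num_layers: int

    def setup(self):
        # Flax auto-names list elements ("layers_0", "layers_1", ...),
        # so a plain list behaves like the removed collection wrapper.
        self.layers = [ToyBlock(self.hidden_size) for _ in range(self.num_layers)]

    def __call__(self, hidden_states, output_hidden_states=False):
        all_hidden_states = () if output_hidden_states else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            hidden_states = layer(hidden_states)
        if output_hidden_states:
            # Also record the final hidden state, as the patched loop does.
            all_hidden_states += (hidden_states,)
        return hidden_states, all_hidden_states


# Usage example: initialize parameters and run a forward pass.
model = ToyDecoder(hidden_size=8, num_layers=2)
x = jnp.ones((1, 4, 8))
params = model.init(jax.random.PRNGKey(0), x)
y, states = model.apply(params, x, output_hidden_states=True)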