Revert "added layer collection approach"
This reverts commit 0e2905b.
kiansierra committed Oct 17, 2023
1 parent 0e2905b commit fb17b61
Showing 1 changed file with 35 additions and 94 deletions.
src/transformers/models/mistral/modeling_flax_mistral.py (129 changes: 35 additions & 94 deletions)
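For orientation before the diff: the change being reverted had wrapped the decoder blocks in a dedicated FlaxMistralLayerCollection module, and this revert restores the plain Python list of FlaxMistralDecoderLayer built in setup(). Below is a minimal, self-contained sketch of the two Flax patterns involved; ToyBlock, ToyCollection, and ToyModel are hypothetical stand-ins, not the real Mistral classes.

```python
# Hypothetical toy modules (not the real FlaxMistral classes) illustrating
# the two layer-stacking patterns this commit toggles between.
import flax.linen as nn
import jax
import jax.numpy as jnp


class ToyBlock(nn.Module):
    """Stand-in for a single decoder layer."""
    features: int

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.features)(x)


class ToyCollection(nn.Module):
    """Pattern removed by this revert: blocks wrapped in a child module."""
    num_layers: int
    features: int

    def setup(self):
        # Explicit names mirror the reverted code's name=str(i).
        self.blocks = [ToyBlock(self.features, name=str(i)) for i in range(self.num_layers)]

    def __call__(self, x):
        for block in self.blocks:
            x = block(x)
        return x


class ToyModel(nn.Module):
    """Pattern restored by this revert: a plain Python list built in setup()."""
    num_layers: int
    features: int

    def setup(self):
        self.layers = [ToyBlock(self.features) for _ in range(self.num_layers)]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


x = jnp.ones((1, 4))
list_tree = ToyModel(num_layers=2, features=4).init(jax.random.PRNGKey(0), x)
coll_tree = ToyCollection(num_layers=2, features=4).init(jax.random.PRNGKey(0), x)
print(list(list_tree["params"].keys()))  # e.g. ['layers_0', 'layers_1']
print(list(coll_tree["params"].keys()))  # e.g. ['0', '1']
```

Hedged observation: the two layouts produce differently named parameter trees (flat names like layers_0 for the list form versus blocks nested under the collection module), which is the kind of detail that affects how checkpoint weights are keyed.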
@@ -19,7 +19,7 @@
# limitations under the License.
""" Flax Mistral model."""
import math
from typing import Any, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union

import flax.linen as nn
import jax
@@ -29,12 +29,15 @@
from flax.linen.initializers import ones
from flax.traverse_util import flatten_dict, unflatten_dict

from ...modeling_flax_outputs import (FlaxBaseModelOutputWithPast,
FlaxCausalLMOutputWithCrossAttentions,
FlaxSequenceClassifierOutput)
from ...modeling_flax_outputs import (
FlaxBaseModelOutputWithPast,
FlaxCausalLMOutputWithCrossAttentions,
FlaxSequenceClassifierOutput,
)
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, logging
from .configuration_mistral import MistralConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "MistralConfig"
@@ -506,54 +509,7 @@ def __call__(
return outputs


class FlaxMistralLayerCollection(nn.Module):

config: MistralConfig
dtype: jnp.dtype = jnp.float32

def setup(self):
self.blocks = [
FlaxMistralDecoderLayer(self.config, dtype=self.dtype, name=str(i)) for i in range(self.config.num_hidden_layers)
]

def __call__(self,
hidden_states: jnp.ndarray = None,
attention_mask: Optional[jnp.ndarray] = None,
position_ids: Optional[jnp.ndarray] = None,
past_key_values: Optional[List[jnp.ndarray]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,) -> Any:

all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
next_decoder_cache = () if use_cache else None

for idx, decoder_layer in enumerate(self.blocks):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)

hidden_states = layer_outputs[0]

if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

if output_attentions:
all_attentions += (layer_outputs[1],)

outputs = (hidden_states, next_decoder_cache, all_hidden_states, all_attentions)

return outputs



@@ -572,10 +528,9 @@ def setup(self):
self.padding_idx = self.config.pad_token_id
self.vocab_size = self.config.vocab_size
self.embed_tokens = nn.Embed(self.config.vocab_size, self.config.hidden_size, dtype=self.dtype)
self.layers = FlaxMistralLayerCollection(self.config, dtype=self.dtype)
# self.layers = [
# FlaxMistralDecoderLayer(self.config, dtype=self.dtype) for _ in range(self.config.num_hidden_layers)
# ]
self.layers = [
FlaxMistralDecoderLayer(self.config, dtype=self.dtype) for _ in range(self.config.num_hidden_layers)
]
self.norm = FlaxMistralRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps, dtype=self.dtype)

def __call__(
@@ -646,52 +601,38 @@ def __call__(
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None

outputs = self.layers(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

# for idx, decoder_layer in enumerate(self.layers):
# if output_hidden_states:
# all_hidden_states += (hidden_states,)
# past_key_value = past_key_values[idx] if past_key_values is not None else None
# layer_outputs = decoder_layer(
# hidden_states,
# attention_mask=attention_mask,
# position_ids=position_ids,
# past_key_value=past_key_value,
# output_attentions=output_attentions,
# use_cache=use_cache,
# padding_mask=padding_mask,
# )

# hidden_states = layer_outputs[0]

# if use_cache:
# next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

# if output_attentions:
# all_self_attns += (layer_outputs[1],)

hidden_states = outputs[0]
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)

hidden_states = layer_outputs[0]

if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

if output_attentions:
all_self_attns += (layer_outputs[1],)

hidden_states = self.norm(hidden_states)

# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states = outputs[2] + (hidden_states,)
outputs = (hidden_states, all_hidden_states) + outputs[2:]
all_hidden_states += (hidden_states,)

next_cache = outputs[1] if use_cache else None
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, outputs[-1]] if v is not None)

return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return FlaxBaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,

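A side note on the restored return path visible in the last hunk: when return_dict is False, only the non-None optional outputs are packed into a tuple, so the tuple's layout depends on which output flags the caller set. A tiny illustration of that idiom, using made-up dummy arrays rather than real model outputs:

```python
import jax.numpy as jnp

hidden_states = jnp.zeros((1, 4, 8))         # dummy final hidden states
next_cache = None                            # use_cache was False
all_hidden_states = (jnp.zeros((1, 4, 8)),)  # output_hidden_states was True
all_self_attns = None                        # output_attentions was False

# Only the non-None entries survive, so callers index the tuple
# positionally based on the flags they passed.
outputs = tuple(
    v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None
)
assert len(outputs) == 2
```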