From 558541a11fa652c0e6b9e7989a236240ab2d6488 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 21 Nov 2023 20:13:17 +0000 Subject: [PATCH] Cache class working with generate (#1) * Draft version of new KV Caching This should allow Attention Sinks (https://github.com/tomaarsen/attention_sinks) / StreamingLLM (https://arxiv.org/abs/2309.17453) to be easily implemented in a third-party or in transformers directly * Address numerous PR suggestions 1. Move layer_idx from cache to ...Attention. Removes confusing set_layer_idx magic. 2. Always convert past_key_values to Cache instance at the start of ...Attention, removes all other isinstance calls. 3. Remove __bool__ and __getitem__ magic as they're confusing. 4. past_key_values.update(key, value, idx) now returns key, value. 5. Add use_legacy_cache flag, defaults to None, i.e. Falsey. This breaks generate for now, until 1) the cache is used is generate() or 2) use_legacy_cache is defaulted to True in generate() until we change it in another PR. 6. Separate key_cache and value_cache. Some work is still needed to see if the SinkCache can conveniently be implemented with just one update method. * Integrate (Sink)Cache with Llama FA2 * Move from/to_legacy_cache to ...Model class * Undo unnecessary newline change * Match import style * working generate * Add tests; Simplify code; Apply changes to Mistral and Persimmon * fix rebase mess * a few more manual fixes * last manual fix * propagate changes to phi * upgrade test * add use_legacy_cache docstring; beef up tests * reintroduce unwanted deletes --------- Co-authored-by: Tom Aarsen --- src/transformers/__init__.py | 1 + src/transformers/cache_utils.py | 60 +++++++++++------ src/transformers/generation/utils.py | 9 +++ .../models/llama/modeling_llama.py | 23 +++---- .../models/mistral/modeling_mistral.py | 67 +++++++++++-------- .../models/persimmon/modeling_persimmon.py | 56 ++++++++++------ src/transformers/models/phi/modeling_phi.py | 57 +++++++++------- tests/generation/test_utils.py | 66 +++++++++++++++++- 8 files changed, 232 insertions(+), 107 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index f88ed87b68b8..d9f36d32f725 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1303,6 +1303,7 @@ _import_structure["activations"] = [] _import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"] _import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"] + _import_structure["cache_utils"] = [] _import_structure["data.datasets"] = [ "GlueDataset", "GlueDataTrainingArguments", diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 832cc90c8473..07ba4ee02d21 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1,18 +1,40 @@ -from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Tuple, TypeVar +from typing import List, Optional, Tuple import torch -T = TypeVar("T") - - -class Cache(ABC): +class Cache: def __init__(self) -> None: - self.key_cache: Dict[int, Tuple[torch.Tensor]] = {} - self.value_cache: Dict[int, Tuple[torch.Tensor]] = {} + self.key_cache: List[Tuple[torch.Tensor]] = [] + self.value_cache: List[Tuple[torch.Tensor]] = [] + + def __getitem__(self, key: int) -> List[Tuple[torch.Tensor]]: + """ + Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the + sequence length. 
+ """ + if key == 0: + return self.key_cache + elif key == 1: + return self.value_cache + else: + raise KeyError(f"Cache only supports 0 (key) and 1 (value) indexing, got {key}") + + def __iter__(self): + """ + Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over + keys and values + """ + yield self.key_cache + yield self.value_cache + + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. + """ + return len(self.key_cache) - @abstractmethod def update( self, key_states: torch.Tensor, @@ -21,10 +43,10 @@ def update( cos: Optional[torch.Tensor] = None, sin: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - pass + raise NotImplementedError("Make sure to implement `update` in a subclass.") - def get_seq_length(self, layer_idx: int = 0) -> int: - if layer_idx not in self.key_cache: + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + if len(self.key_cache) <= layer_idx: return 0 return self.key_cache[layer_idx].shape[-2] @@ -53,9 +75,9 @@ def update( cos: Optional[torch.Tensor] = None, sin: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - if layer_idx not in self.key_cache: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states + if len(self.key_cache) <= layer_idx: + self.key_cache.append(key_states) + self.value_cache.append(value_states) else: self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) @@ -109,7 +131,7 @@ def get_rerotation_cos_sin( ) return self.cos_sin_cache[key_states.shape[-2]] - def get_seq_length(self, layer_idx: int = 0) -> int: + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length return min(super().get_seq_length(layer_idx), self.window_length - 1) @@ -122,10 +144,10 @@ def update( sin: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # [bsz, num_heads, seq_len, head_dim] - if layer_idx not in self.key_cache: + if len(self.key_cache) <= layer_idx: # Empty cache - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states + self.key_cache.append(key_states) + self.value_cache.append(value_states) elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length: # Growing cache diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 7040b98dd91c..4baed320c2fc 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -24,6 +24,7 @@ import torch.distributed as dist from torch import nn +from ..cache_utils import DynamicCache from ..integrations.deepspeed import is_deepspeed_zero3_enabled from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput from ..models.auto import ( @@ -3226,6 +3227,8 @@ def beam_search( ) if model_kwargs["past_key_values"] is not None: model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx) + if not model_kwargs.get("use_legacy_cache"): + model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(model_kwargs["past_key_values"]) if return_dict_in_generate and output_scores: beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in 
range(len(beam_indices)))) @@ -3561,6 +3564,8 @@ def beam_sample( ) if model_kwargs["past_key_values"] is not None: model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx) + if not model_kwargs.get("use_legacy_cache"): + model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(model_kwargs["past_key_values"]) if return_dict_in_generate and output_scores: beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) @@ -3948,6 +3953,8 @@ def group_beam_search( model_kwargs["past_key_values"] = self._reorder_cache( model_kwargs["past_key_values"], reordering_indices ) + if not model_kwargs.get("use_legacy_cache"): + model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(model_kwargs["past_key_values"]) # increase cur_len cur_len = cur_len + 1 @@ -4288,6 +4295,8 @@ def constrained_beam_search( ) if model_kwargs["past_key_values"] is not None: model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx) + if not model_kwargs.get("use_legacy_cache"): + model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(model_kwargs["past_key_values"]) if return_dict_in_generate and output_scores: beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index ad72f1296dbd..4e9076481c0a 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -284,11 +284,11 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: class LlamaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + def __init__(self, config: LlamaConfig, layer_idx: int): super().__init__() self.config = config - self.attention_dropout = config.attention_dropout self.layer_idx = layer_idx + self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads @@ -435,7 +435,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, (past_key_value if use_cache else None) + return attn_output, attn_weights, past_key_value class LlamaFlashAttention2(LlamaAttention): @@ -539,7 +539,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, (past_key_value if use_cache else None) + return attn_output, attn_weights, past_key_value def _flash_attention_forward( self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None @@ -640,7 +640,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query class LlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + def __init__(self, config: LlamaConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = ( @@ -816,6 +816,9 @@ def _init_weights(self, module): more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + use_legacy_cache (`bool`, *optional*): + If set to `True` (default), will return `past_key_values` as described input above. 
Otherwise, will return + a subclass of `Cache` """ @@ -887,7 +890,7 @@ def forward( past_key_values_length = 0 if use_cache: if not isinstance(past_key_values, Cache): - past_key_values = self.from_legacy_cache(past_key_values) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_seq_length() if position_ids is None: @@ -964,7 +967,7 @@ def forward( next_cache = None if use_cache: - next_cache = self.to_legacy_cache(next_decoder_cache) if use_legacy_cache else next_decoder_cache + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( @@ -974,12 +977,6 @@ def forward( attentions=all_self_attns, ) - def from_legacy_cache(self, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]]) -> Cache: - return DynamicCache.from_legacy_cache(past_key_values) - - def to_legacy_cache(self, past_key_values: Cache) -> Tuple[Tuple[torch.Tensor]]: - return past_key_values.to_legacy_cache() - class LlamaForCausalLM(LlamaPreTrainedModel): _tied_weights_keys = ["lm_head.weight"] diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 0b23303d5ef3..4eb42cad9000 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -30,6 +30,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel @@ -195,9 +196,10 @@ class MistralAttention(nn.Module): and "Generating Long Sequences with Sparse Transformers". 
""" - def __init__(self, config: MistralConfig): + def __init__(self, config: MistralConfig, layer_idx: int): super().__init__() self.config = config + self.layer_idx = layer_idx self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads @@ -232,7 +234,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, **kwargs, @@ -253,16 +255,12 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + kv_seq_len += past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cos, sin) # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -327,7 +325,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, **kwargs, @@ -351,7 +349,7 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + kv_seq_len += past_key_value.get_seq_length(self.layer_idx) # Because the input can be padded, the absolute sequence length depends on the max position id. rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 @@ -394,10 +392,7 @@ def forward( attention_mask = attention_mask[:, slicing_tokens:] attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cos, sin) # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -592,13 +587,13 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query class MistralDecoderLayer(nn.Module): - def __init__(self, config: MistralConfig): + def __init__(self, config: MistralConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = ( - MistralAttention(config=config) + MistralAttention(config=config, layer_idx=layer_idx) if not getattr(config, "_flash_attn_2_enabled", False) - else MistralFlashAttention2(config) + else MistralFlashAttention2(config, layer_idx=layer_idx) ) self.mlp = MistralMLP(config) self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -766,6 +761,9 @@ def _init_weights(self, module): more detail. 
return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + use_legacy_cache (`bool`, *optional*): + If set to `True` (default), will return `past_key_values` as described input above. Otherwise, will return + a subclass of `Cache` """ @@ -787,7 +785,9 @@ def __init__(self, config: MistralConfig): self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList([MistralDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList( + [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -812,6 +812,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + use_legacy_cache: Optional[bool] = True, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -834,8 +835,10 @@ def forward( seq_length_with_past = seq_length past_key_values_length = 0 - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] + if use_cache: + if not isinstance(past_key_values, Cache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_seq_length() seq_length_with_past = seq_length_with_past + past_key_values_length if position_ids is None: @@ -889,21 +892,19 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None - for idx, decoder_layer in enumerate(self.layers): + for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, position_ids, - past_key_value, + past_key_values, output_attentions, use_cache, ) @@ -912,7 +913,7 @@ def forward( hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, ) @@ -920,7 +921,7 @@ def forward( hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -931,7 +932,10 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( @@ -986,6 +990,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + use_legacy_cache: Optional[bool] = True, ) -> Union[Tuple, 
CausalLMOutputWithPast]: r""" Args: @@ -1030,6 +1035,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + use_legacy_cache=use_legacy_cache, ) hidden_states = outputs[0] @@ -1062,7 +1068,7 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, use_legacy_cache=True, **kwargs ): # Omit tokens covered by past_key_values if past_key_values: @@ -1097,6 +1103,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, + "use_legacy_cache": use_legacy_cache, } ) return model_inputs @@ -1156,6 +1163,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + use_legacy_cache: Optional[bool] = True, ) -> Union[Tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1175,6 +1183,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + use_legacy_cache=use_legacy_cache, ) hidden_states = transformer_outputs[0] logits = self.score(hidden_states) diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 6a2535998da5..ad9296a9f16c 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -27,6 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel @@ -178,9 +179,10 @@ def forward(self, hidden_states): class PersimmonAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: PersimmonConfig): + def __init__(self, config: PersimmonConfig, layer_idx: int): super().__init__() self.config = config + self.layer_idx = layer_idx self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads @@ -257,7 +259,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -280,7 +282,7 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + kv_seq_len += past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) # Partial rotary embedding @@ -300,11 +302,7 @@ def forward( key_states = torch.cat((key_rot, key_pass), dim=-1) if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = 
(key_states, value_states) if use_cache else None + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cos, sin) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) @@ -345,10 +343,10 @@ def forward( class PersimmonDecoderLayer(nn.Module): - def __init__(self, config: PersimmonConfig): + def __init__(self, config: PersimmonConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = PersimmonAttention(config=config) + self.self_attn = PersimmonAttention(config=config, layer_idx=layer_idx) self.mlp = PersimmonMLP(config) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -518,6 +516,9 @@ def _init_weights(self, module): more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + use_legacy_cache (`bool`, *optional*): + If set to `True` (default), will return `past_key_values` as described input above. Otherwise, will return + a subclass of `Cache` """ @@ -539,7 +540,9 @@ def __init__(self, config: PersimmonConfig): self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList([PersimmonDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList( + [PersimmonDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.gradient_checkpointing = False @@ -564,6 +567,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + use_legacy_cache: Optional[bool] = True, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -586,8 +590,10 @@ def forward( seq_length_with_past = seq_length past_key_values_length = 0 - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] + if use_cache: + if not isinstance(past_key_values, Cache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_seq_length() seq_length_with_past = seq_length_with_past + past_key_values_length if position_ids is None: @@ -620,21 +626,19 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None - for idx, decoder_layer in enumerate(self.layers): + for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, position_ids, - past_key_value, + past_key_values, output_attentions, ) else: @@ -642,7 +646,7 @@ def forward( hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, ) @@ -650,7 +654,7 @@ def forward( hidden_states = 
layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -661,7 +665,10 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( @@ -723,6 +730,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + use_legacy_cache: Optional[bool] = True, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -767,6 +775,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + use_legacy_cache=use_legacy_cache, ) hidden_states = outputs[0] @@ -799,7 +808,7 @@ def forward( # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, use_legacy_cache=True, **kwargs ): if past_key_values is not None: past_length = past_key_values[0][0].shape[2] @@ -833,6 +842,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, + "use_legacy_cache": use_legacy_cache, } ) return model_inputs @@ -892,6 +902,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + use_legacy_cache: Optional[bool] = True, ) -> Union[Tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -911,6 +922,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + use_legacy_cache=use_legacy_cache, ) hidden_states = transformer_outputs[0] logits = self.score(hidden_states) diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 44be9c749f0e..eb6b6cf3c525 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -217,9 +218,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class PhiAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: PhiConfig): + def __init__(self, config: PhiConfig, layer_idx: int): super().__init__() self.config = config + self.layer_idx = layer_idx self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads @@ -296,7 +298,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: 
Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -319,7 +321,7 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + kv_seq_len += past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) # Partial rotary embedding @@ -339,11 +341,7 @@ def forward( key_states = torch.cat((key_rot, key_pass), dim=-1) if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cos, sin) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) @@ -603,12 +601,12 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query class PhiDecoderLayer(nn.Module): - def __init__(self, config: PhiConfig): + def __init__(self, config: PhiConfig, layer_idx: int): super().__init__() self.self_attn = ( - PhiAttention(config=config) + PhiAttention(config=config, layer_idx=layer_idx) if not getattr(config, "_flash_attn_2_enabled", False) - else PhiFlashAttention2(config=config) + else PhiFlashAttention2(config=config, layer_idx=layer_idx) ) self.mlp = PhiMLP(config) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -770,6 +768,9 @@ def _init_weights(self, module): more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + use_legacy_cache (`bool`, *optional*): + If set to `True` (default), will return `past_key_values` as described input above. 
Otherwise, will return + a subclass of `Cache` """ @@ -792,7 +793,9 @@ def __init__(self, config: PhiConfig): self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.embed_dropout = nn.Dropout(config.embd_pdrop) - self.layers = nn.ModuleList([PhiDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList( + [PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.gradient_checkpointing = False @@ -817,6 +820,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + use_legacy_cache: Optional[bool] = True, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -839,8 +843,10 @@ def forward( seq_length_with_past = seq_length past_key_values_length = 0 - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] + if use_cache: + if not isinstance(past_key_values, Cache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_seq_length() seq_length_with_past = seq_length_with_past + past_key_values_length if position_ids is None: @@ -877,21 +883,19 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None - for idx, decoder_layer in enumerate(self.layers): + for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, position_ids, - past_key_value, + past_key_values, output_attentions, ) else: @@ -899,7 +903,7 @@ def forward( hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, ) @@ -907,7 +911,7 @@ def forward( hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -918,7 +922,9 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( @@ -980,6 +986,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + use_legacy_cache: Optional[bool] = True, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1024,6 +1031,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + use_legacy_cache=use_legacy_cache, ) hidden_states = outputs[0] 
@@ -1057,7 +1065,7 @@ def forward( # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, use_legacy_cache=True, **kwargs ): if past_key_values is not None: past_length = past_key_values[0][0].shape[2] @@ -1091,6 +1099,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, + "use_legacy_cache": use_legacy_cache, } ) return model_inputs @@ -1151,6 +1160,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + use_legacy_cache: Optional[bool] = True, ) -> Union[Tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1170,6 +1180,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + use_legacy_cache=use_legacy_cache, ) hidden_states = model_outputs[0] logits = self.score(hidden_states) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index f4050c582b8f..b218612418a0 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -20,8 +20,10 @@ import warnings import numpy as np +from parameterized import parameterized -from transformers import is_torch_available, pipeline +from transformers import is_torch_available, pipeline, set_seed +from transformers.cache_utils import DynamicCache from transformers.testing_utils import ( is_flaky, require_accelerate, @@ -1904,6 +1906,68 @@ def test_generate_continue_from_past_key_values(self): ) ) + @parameterized.expand([(1, False), (1, True), (4, False)]) + def test_new_cache_format(self, num_beams, do_sample): + # Tests that generating with the new format is exactly the same as the legacy one (for models that support it). + # 👉 tests with and without beam search so that we can test with and without cache reordering. + # 👉 tests with and without sampling so we can cover the most common use cases. 
+ for model_class in self.all_generative_model_classes: + if "use_legacy_cache" not in inspect.signature(model_class.forward).parameters: + self.skipTest("This model does not support the new cache format") + + config, input_ids, attention_mask, _ = self._get_input_ids_and_config() + config.use_cache = True + config.is_decoder = True + + model = model_class(config).to(torch_device).eval() + generation_kwargs = { + "max_new_tokens": 5, + "do_sample": do_sample, + "num_beams": num_beams, + "num_return_sequences": num_beams, + "return_dict_in_generate": True, # Required to return `past_key_values` + } + + # Sets seed before calling `generate` for the case with do_sample=True + seed = torch.randint(0, 1000000, (1,)).item() + set_seed(seed) + legacy_results = model.generate( + input_ids, attention_mask=attention_mask, use_legacy_cache=True, **generation_kwargs + ) + set_seed(seed) + new_results = model.generate( + input_ids, attention_mask=attention_mask, use_legacy_cache=False, **generation_kwargs + ) + + # The two sets of generated sequences must match, despite the cache format between forward passes being + # different + self.assertListEqual(legacy_results.sequences.tolist(), new_results.sequences.tolist()) + self.assertTrue(isinstance(legacy_results.past_key_values, tuple)) + self.assertTrue(isinstance(new_results.past_key_values, DynamicCache)) + + # The contents of the two caches, when converted to the same format (in both directions!), must match + legacy_cache = legacy_results.past_key_values + new_cache_converted = new_results.past_key_values.to_legacy_cache() + for layer_idx in range(len(legacy_cache)): + for kv_idx in range(len(legacy_cache[layer_idx])): + self.assertTrue( + torch.allclose( + legacy_cache[layer_idx][kv_idx], + new_cache_converted[layer_idx][kv_idx], + ) + ) + + new_cache = new_results.past_key_values + legacy_cache_converted = DynamicCache.from_legacy_cache(legacy_results.past_key_values) + for layer_idx in range(len(new_cache)): + for kv_idx in range(len(new_cache[layer_idx])): + self.assertTrue( + torch.allclose( + new_cache[layer_idx][kv_idx], + legacy_cache_converted[layer_idx][kv_idx], + ) + ) + def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1): batch_size, seq_length = input_ids.shape num_sequences_in_output = batch_size * num_return_sequences
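
---

Editorial note: the cache contract introduced in `cache_utils.py` reduces to per-layer lists of key/value tensors that grow along the sequence dimension, with `update()` returning the full keys/values for that layer. Below is a minimal, self-contained sketch of that `DynamicCache.update` behaviour; it deliberately re-implements the logic shown in the diff instead of importing `transformers`, and it drops the optional `cos`/`sin` arguments that only `SinkCache` uses, so the class name and tensor shapes are illustrative only.

```python
from typing import List, Optional, Tuple

import torch


class MiniDynamicCache:
    """Minimal re-implementation of the DynamicCache contract from the diff above."""

    def __init__(self) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        # An empty cache reports length 0, matching Cache.get_seq_length in the patch.
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def update(
        self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # First call for a layer stores the states; later calls concatenate on the sequence dim.
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]


# Toy decoding loop: one layer, batch of 1, 2 heads, head_dim 4, one new token per step.
cache = MiniDynamicCache()
for step in range(3):
    k = torch.randn(1, 2, 1, 4)
    v = torch.randn(1, 2, 1, 4)
    k_full, v_full = cache.update(k, v, layer_idx=0)
    # The attention layer would attend over k_full/v_full, whose length grows by 1 each step,
    # exactly like `past_key_value.update(key_states, value_states, self.layer_idx, cos, sin)`
    # in the modified ...Attention.forward methods.
    assert k_full.shape[-2] == step + 1 == cache.get_seq_length(0)
```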
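The beam-search branches in `generation/utils.py` convert back to `DynamicCache` right after `self._reorder_cache(...)` because reordering operates on, and returns, the legacy per-layer tuples. The sketch below shows what that reordering does to one layer; the `index_select`-on-batch implementation of `_reorder_cache` is not part of this diff and is assumed here to match the standard one in these models.

```python
import torch


def reorder_legacy_cache(past_key_values, beam_idx):
    # Assumed _reorder_cache logic: pick, along the batch dim, the beams selected at this step.
    return tuple(
        tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
        for layer_past in past_key_values
    )


# One layer, 4 beams, 2 heads, 3 cached tokens, head_dim 8.
legacy = ((torch.randn(4, 2, 3, 8), torch.randn(4, 2, 3, 8)),)
beam_idx = torch.tensor([2, 2, 0, 1])  # beams chosen by beam search at this step
reordered = reorder_legacy_cache(legacy, beam_idx)
assert reordered[0][0].shape == legacy[0][0].shape

# In the patch, the reordered tuples are then wrapped again with
# DynamicCache.from_legacy_cache(...) whenever `use_legacy_cache` is falsy.
```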
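Finally, mirroring `test_new_cache_format`, here is a usage sketch of the new flag on a checkout that contains this patch. The checkpoint name is only a placeholder (any of the models touched here — Llama, Mistral, Persimmon, Phi — should behave the same), and `use_legacy_cache=False` is the opt-in added by this PR; the legacy tuple format remains the default.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

checkpoint = "mistralai/Mistral-7B-v0.1"  # illustrative checkpoint, not prescribed by the patch
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

inputs = tokenizer("The key/value cache is", return_tensors="pt").to(model.device)

# New cache format: past_key_values in the generate output is a DynamicCache instance.
new = model.generate(**inputs, max_new_tokens=5, use_legacy_cache=False, return_dict_in_generate=True)
assert isinstance(new.past_key_values, DynamicCache)

# Legacy format (the default in this PR): a tuple of per-layer (key, value) tuples.
legacy = model.generate(**inputs, max_new_tokens=5, use_legacy_cache=True, return_dict_in_generate=True)
assert isinstance(legacy.past_key_values, tuple)

# The two formats convert losslessly in both directions, as the new test checks.
roundtrip = DynamicCache.from_legacy_cache(new.past_key_values.to_legacy_cache())
assert roundtrip.get_seq_length() == new.past_key_values.get_seq_length()
```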