Cache class working with generate (#1)
* Draft version of new KV Caching

This should allow Attention Sinks (https://github.com/tomaarsen/attention_sinks)
/ StreamingLLM (https://arxiv.org/abs/2309.17453) to be easily implemented
in a third-party library or in transformers directly

* Address numerous PR suggestions

1. Move layer_idx from cache to ...Attention. Removes confusing set_layer_idx magic.
2. Always convert past_key_values to Cache instance at the start of ...Attention, removes all other isinstance calls.
3. Remove __bool__ and __getitem__ magic as they're confusing.
4. past_key_values.update(key, value, idx) now returns key, value (see the sketch after this list).
5. Add use_legacy_cache flag, defaults to None, i.e. falsy. This breaks generate for now, until either 1) the cache is used in generate() or 2) use_legacy_cache is defaulted to True in generate() until we change it in another PR.
6. Separate key_cache and value_cache.

Some work is still needed to see if the SinkCache can conveniently be implemented with just one update method.
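
A rough sketch of the call pattern points 2, 4 and 6 describe, assuming the draft `DynamicCache` added by this commit; the tensor shapes, the two-layer loop, and passing the layer index as the third positional argument of `update()` are illustrative assumptions, not part of the diff:

```python
import torch

from transformers.cache_utils import DynamicCache  # module added by this commit

past_key_values = DynamicCache()

# Inside each ...Attention.forward, using the layer_idx stored on the module (point 1):
for layer_idx in range(2):
    key_states = torch.randn(1, 8, 4, 64)    # [bsz, num_heads, seq_len, head_dim]
    value_states = torch.randn(1, 8, 4, 64)
    # update() appends to the separate key/value caches for this layer (point 6) and
    # returns the full key/value tensors to attend over (point 4).
    key_states, value_states = past_key_values.update(key_states, value_states, layer_idx)

print(past_key_values.get_seq_length())  # 4
```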

* Integrate (Sink)Cache with Llama FA2

* Move from/to_legacy_cache to ...Model class

* Undo unnecessary newline change

* Match import style

* working generate

* Add tests; Simplify code; Apply changes to Mistral and Persimmon

* fix rebase mess

* a few more manual fixes

* last manual fix

* propagate changes to phi

* upgrade test

* add use_legacy_cache docstring; beef up tests

* reintroduce unwanted deletes

---------

Co-authored-by: Tom Aarsen <Cubiegamedev@gmail.com>
2 people authored and ydshieh committed Dec 7, 2023
1 parent 46da03d commit 558541a
Showing 8 changed files with 232 additions and 107 deletions.
1 change: 1 addition & 0 deletions src/transformers/__init__.py
@@ -1303,6 +1303,7 @@
_import_structure["activations"] = []
_import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
_import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"]
_import_structure["cache_utils"] = []
_import_structure["data.datasets"] = [
"GlueDataset",
"GlueDataTrainingArguments",
60 changes: 41 additions & 19 deletions src/transformers/cache_utils.py
@@ -1,18 +1,40 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, TypeVar
from typing import List, Optional, Tuple

import torch


T = TypeVar("T")


class Cache(ABC):
class Cache:
def __init__(self) -> None:
self.key_cache: Dict[int, Tuple[torch.Tensor]] = {}
self.value_cache: Dict[int, Tuple[torch.Tensor]] = {}
self.key_cache: List[Tuple[torch.Tensor]] = []
self.value_cache: List[Tuple[torch.Tensor]] = []

def __getitem__(self, key: int) -> List[Tuple[torch.Tensor]]:
"""
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
sequence length.
"""
if key == 0:
return self.key_cache
elif key == 1:
return self.value_cache
else:
raise KeyError(f"Cache only supports 0 (key) and 1 (value) indexing, got {key}")

def __iter__(self):
"""
Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
keys and values
"""
yield self.key_cache
yield self.value_cache

def __len__(self):
"""
Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
to the number of layers in the model.
"""
return len(self.key_cache)

@abstractmethod
def update(
self,
key_states: torch.Tensor,
@@ -21,10 +43,10 @@ def update(
cos: Optional[torch.Tensor] = None,
sin: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
pass
raise NotImplementedError("Make sure to implement `update` in a subclass.")

def get_seq_length(self, layer_idx: int = 0) -> int:
if layer_idx not in self.key_cache:
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
if len(self.key_cache) <= layer_idx:
return 0
return self.key_cache[layer_idx].shape[-2]

@@ -53,9 +75,9 @@ def update(
cos: Optional[torch.Tensor] = None,
sin: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if layer_idx not in self.key_cache:
self.key_cache[layer_idx] = key_states
self.value_cache[layer_idx] = value_states
if len(self.key_cache) <= layer_idx:
self.key_cache.append(key_states)
self.value_cache.append(value_states)
else:
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
@@ -109,7 +131,7 @@ def get_rerotation_cos_sin(
)
return self.cos_sin_cache[key_states.shape[-2]]

def get_seq_length(self, layer_idx: int = 0) -> int:
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
# Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
return min(super().get_seq_length(layer_idx), self.window_length - 1)

@@ -122,10 +144,10 @@ def update(
sin: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# [bsz, num_heads, seq_len, head_dim]
if layer_idx not in self.key_cache:
if len(self.key_cache) <= layer_idx:
# Empty cache
self.key_cache[layer_idx] = key_states
self.value_cache[layer_idx] = value_states
self.key_cache.append(key_states)
self.value_cache.append(value_states)

elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
# Growing cache
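
To illustrate the extension point the commit message mentions for Attention Sinks / StreamingLLM, here is a hypothetical minimal sketch of a third-party cache built on the draft base class above. It is not the `SinkCache` shown in the diff (which additionally keeps sink tokens and re-rotates positions); `SlidingWindowCache` and `window_length` are illustrative names, and the middle parameters of `update()` (`value_states`, `layer_idx`), which are folded in the diff, are assumed:

```python
from typing import Optional, Tuple

import torch

from transformers.cache_utils import Cache  # draft base class added above


class SlidingWindowCache(Cache):
    """Hypothetical third-party cache: keep only the last `window_length` positions per layer."""

    def __init__(self, window_length: int) -> None:
        super().__init__()
        self.window_length = window_length

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cos: Optional[torch.Tensor] = None,
        sin: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if len(self.key_cache) <= layer_idx:
            # First call for this layer: start its cache.
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            # Append new positions along the sequence dimension.
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
        # Drop everything older than the window.
        self.key_cache[layer_idx] = self.key_cache[layer_idx][..., -self.window_length :, :]
        self.value_cache[layer_idx] = self.value_cache[layer_idx][..., -self.window_length :, :]
        return self.key_cache[layer_idx], self.value_cache[layer_idx]
```
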
9 changes: 9 additions & 0 deletions src/transformers/generation/utils.py
@@ -24,6 +24,7 @@
import torch.distributed as dist
from torch import nn

from ..cache_utils import DynamicCache
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from ..models.auto import (
@@ -3226,6 +3227,8 @@ def beam_search(
)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
if not model_kwargs.get("use_legacy_cache"):
model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(model_kwargs["past_key_values"])

if return_dict_in_generate and output_scores:
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
@@ -3561,6 +3564,8 @@ def beam_sample(
)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
if not model_kwargs.get("use_legacy_cache"):
model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(model_kwargs["past_key_values"])

if return_dict_in_generate and output_scores:
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
@@ -3948,6 +3953,8 @@ def group_beam_search(
model_kwargs["past_key_values"] = self._reorder_cache(
model_kwargs["past_key_values"], reordering_indices
)
if not model_kwargs.get("use_legacy_cache"):
model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(model_kwargs["past_key_values"])

# increase cur_len
cur_len = cur_len + 1
@@ -4288,6 +4295,8 @@ def constrained_beam_search(
)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
if not model_kwargs.get("use_legacy_cache"):
model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(model_kwargs["past_key_values"])

if return_dict_in_generate and output_scores:
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
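
A hypothetical standalone sketch of the pattern added to each beam-search variant above: `_reorder_cache` works on the legacy tuple format, so after reordering the result is wrapped back into a `DynamicCache` unless the caller asked for the legacy cache via `use_legacy_cache`. The `reorder_legacy_cache` helper is a stand-in modeled on the usual `_reorder_cache` implementations, and the shapes are illustrative:

```python
import torch

from transformers.cache_utils import DynamicCache


def reorder_legacy_cache(legacy_past, beam_idx):
    # Stand-in for the model's _reorder_cache: select beams along the batch dimension.
    return tuple(
        tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
        for layer_past in legacy_past
    )


cache = DynamicCache()
cache.update(torch.randn(3, 8, 5, 64), torch.randn(3, 8, 5, 64), 0)  # 3 beams, 1 layer

beam_idx = torch.tensor([2, 0, 1])
reordered = reorder_legacy_cache(cache.to_legacy_cache(), beam_idx)
cache = DynamicCache.from_legacy_cache(reordered)  # back to the new Cache object
```
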
23 changes: 10 additions & 13 deletions src/transformers/models/llama/modeling_llama.py
@@ -284,11 +284,11 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
class LlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
def __init__(self, config: LlamaConfig, layer_idx: int):
super().__init__()
self.config = config
self.attention_dropout = config.attention_dropout
self.layer_idx = layer_idx
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
@@ -435,7 +435,7 @@ def forward(
if not output_attentions:
attn_weights = None

return attn_output, attn_weights, (past_key_value if use_cache else None)
return attn_output, attn_weights, past_key_value


class LlamaFlashAttention2(LlamaAttention):
@@ -539,7 +539,7 @@ def forward(
if not output_attentions:
attn_weights = None

return attn_output, attn_weights, (past_key_value if use_cache else None)
return attn_output, attn_weights, past_key_value

def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
@@ -640,7 +640,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query


class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
def __init__(self, config: LlamaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = (
@@ -816,6 +816,9 @@ def _init_weights(self, module):
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
use_legacy_cache (`bool`, *optional*):
If set to `True` (default), will return `past_key_values` in the legacy format described above. Otherwise, will return
a subclass of `Cache`.
"""


@@ -887,7 +890,7 @@ def forward(
past_key_values_length = 0
if use_cache:
if not isinstance(past_key_values, Cache):
past_key_values = self.from_legacy_cache(past_key_values)
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_seq_length()

if position_ids is None:
@@ -964,7 +967,7 @@

next_cache = None
if use_cache:
next_cache = self.to_legacy_cache(next_decoder_cache) if use_legacy_cache else next_decoder_cache
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
@@ -974,12 +977,6 @@
attentions=all_self_attns,
)

def from_legacy_cache(self, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]]) -> Cache:
return DynamicCache.from_legacy_cache(past_key_values)

def to_legacy_cache(self, past_key_values: Cache) -> Tuple[Tuple[torch.Tensor]]:
return past_key_values.to_legacy_cache()


class LlamaForCausalLM(LlamaPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
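
A condensed sketch of the conversion flow the LlamaModel changes above establish: legacy tuples are wrapped into a `DynamicCache` at the start of the forward pass and converted back at the end when the caller wants the legacy format. It assumes the `from_legacy_cache` / `to_legacy_cache` / `get_seq_length` methods behave as shown in the diff; the tensors are illustrative:

```python
import torch

from transformers.cache_utils import Cache, DynamicCache

# Legacy format: a tuple (one entry per layer) of (key, value) tensors of shape
# [bsz, num_heads, seq_len, head_dim].
legacy_past = (
    (torch.randn(1, 8, 4, 64), torch.randn(1, 8, 4, 64)),
    (torch.randn(1, 8, 4, 64), torch.randn(1, 8, 4, 64)),
)

# Start of ...Model.forward: wrap legacy tuples into the new Cache object.
past_key_values = legacy_past
if not isinstance(past_key_values, Cache):
    past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_seq_length()  # 4

# End of ...Model.forward: convert back if the caller asked for the legacy format.
use_legacy_cache = True
next_cache = past_key_values.to_legacy_cache() if use_legacy_cache else past_key_values
```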
