Generate: New Cache abstraction and Attention Sinks support (#26681)
* Draft version of new KV Caching. This should allow Attention Sinks (https://github.com/tomaarsen/attention_sinks) / StreamingLLM (https://arxiv.org/abs/2309.17453) to be easily implemented in a third-party or in transformers directly
* Address numerous PR suggestions
  1. Move layer_idx from cache to ...Attention. Removes confusing set_layer_idx magic.
  2. Always convert past_key_values to Cache instance at the start of ...Attention, removes all other isinstance calls.
  3. Remove `__bool__` and `__getitem__` magic as they're confusing.
  4. past_key_values.update(key, value, idx) now returns key, value.
  5. Add use_legacy_cache flag, defaults to None, i.e. Falsey. This breaks generate for now, until 1) the cache is used is generate() or 2) use_legacy_cache is defaulted to True in generate() until we change it in another PR.
  6. Separate key_cache and value_cache. Some work is still needed to see if the SinkCache can conveniently be implemented with just one update method.
* Implement the SinkCache through backward+forward rotations
* Integrate (Sink)Cache with Llama FA2
* Set use_legacy_cache=True as default, allows for test passes
* Move from/to_legacy_cache to ...Model class
* Undo unnecessary newline change
* Remove copy utility from deprecated OpenLlama
* Match import style
* manual rebase with main
* Cache class working with generate (#1)
* Draft version of new KV Caching. This should allow Attention Sinks (https://github.com/tomaarsen/attention_sinks) / StreamingLLM (https://arxiv.org/abs/2309.17453) to be easily implemented in a third-party or in transformers directly
* Address numerous PR suggestions
  1. Move layer_idx from cache to ...Attention. Removes confusing set_layer_idx magic.
  2. Always convert past_key_values to Cache instance at the start of ...Attention, removes all other isinstance calls.
  3. Remove `__bool__` and `__getitem__` magic as they're confusing.
  4. past_key_values.update(key, value, idx) now returns key, value.
  5. Add use_legacy_cache flag, defaults to None, i.e. Falsey. This breaks generate for now, until 1) the cache is used is generate() or 2) use_legacy_cache is defaulted to True in generate() until we change it in another PR.
  6. Separate key_cache and value_cache. Some work is still needed to see if the SinkCache can conveniently be implemented with just one update method.
* Integrate (Sink)Cache with Llama FA2
* Move from/to_legacy_cache to ...Model class
* Undo unnecessary newline change
* Match import style
* working generate
* Add tests; Simplify code; Apply changes to Mistral and Persimmon
* fix rebase mess
* a few more manual fixes
* last manual fix
* propagate changes to phi
* upgrade test
* add use_legacy_cache docstring; beef up tests
* reintroduce unwanted deletes

---------

Co-authored-by: Tom Aarsen <Cubiegamedev@gmail.com>

* move import
* add default to model_kwargs.get('use_legacy_cache')
* correct failing test
* Apply suggestions from code review. Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* apply PR suggestions
* fix failing test
* Apply suggestions from code review. Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
* PR comments
* tmp commit
* add docstrings
* more tests, more docstrings, add to docs
* derp
* tmp commit
* tmp dbg
* more dbg
* fix beam search bug
* cache can be a list of tuples in some models
* fix group beam search
* all but sinkcache integration tests
* fix sink cache and add hard integration test
* now also compatible with input_embeds input
* PR comments
* add Cache support to Phi+FA2
* make fixup

---------

Co-authored-by: Joao Gante <joao@huggingface.co>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
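
The "backward+forward rotations" item in the commit message can be read as one combined rotation: a key that was already rotated for its original position is rotated again by the angle difference, so that it lands on its shifted position. The following is a minimal pure-Python sketch on a single 2-D feature pair, not the code from this commit. The helper `rotate`, the angle `theta`, and the chosen positions are illustrative assumptions; only the `rerotation_cos` / `rerotation_sin` combination is taken from the diff.

```python
import math
from typing import Tuple

Vec2 = Tuple[float, float]


def rotate(vec: Vec2, cos: float, sin: float) -> Vec2:
    """Rotate a 2-D pair; same form as `x * cos + rotate_half(x) * sin`."""
    x1, x2 = vec
    return (x1 * cos - x2 * sin, x2 * cos + x1 * sin)


theta = 0.3                       # hypothetical per-position rotation angle
original_pos, shifted_pos = 7, 6  # the key moves one slot earlier in the window
a, b = original_pos * theta, shifted_pos * theta

key = (1.0, 2.0)
key_at_original = rotate(key, math.cos(a), math.sin(a))

# Same combination as `rerotation_cos` / `rerotation_sin` in the diff.
rerotation_cos = math.cos(a) * math.cos(b) + math.sin(a) * math.sin(b)   # cos(b - a)
rerotation_sin = -math.sin(a) * math.cos(b) + math.cos(a) * math.sin(b)  # sin(b - a)

# Re-rotating the cached key should match rotating the raw key for the new position.
rerotated = rotate(key_at_original, rerotation_cos, rerotation_sin)
direct = rotate(key, math.cos(b), math.sin(b))
```

Under these assumptions `rerotated` and `direct` agree up to floating-point error, which is why the cache can shift keys without storing their unrotated form.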
1 parent 0ea42ef, commit 633215b
Showing 14 changed files with 962 additions and 195 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,298 @@ | ||
from typing import Any, Dict, List, Optional, Tuple | ||
|
||
import torch | ||
|
||
|
||
class Cache: | ||
""" | ||
Base, abstract class for all caches. The actual data structure is specific to each subclass. | ||
""" | ||
|
||
def update( | ||
self, | ||
key_states: torch.Tensor, | ||
value_states: torch.Tensor, | ||
layer_idx: int, | ||
cache_kwargs: Optional[Dict[str, Any]] = None, | ||
) -> Tuple[torch.Tensor, torch.Tensor]: | ||
""" | ||
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. | ||
Parameters: | ||
key_states (`torch.Tensor`): | ||
The new key states to cache. | ||
value_states (`torch.Tensor`): | ||
The new value states to cache. | ||
layer_idx (`int`): | ||
The index of the layer to cache the states for. | ||
cache_kwargs (`Dict[str, Any]`, `optional`): | ||
Additional arguments for the cache subclass. These are specific to each subclass and allow new types of | ||
cache to be created. | ||
Return: | ||
A tuple containing the updated key and value states. | ||
""" | ||
raise NotImplementedError("Make sure to implement `update` in a subclass.") | ||
|
||
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: | ||
"""Returns the sequence length of the cached states. A layer index can be optionally passed.""" | ||
raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") | ||
|
||
|
||
class DynamicCache(Cache): | ||
""" | ||
A cache that grows dynamically as more tokens are generated. This is the default for generative models. | ||
It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is | ||
`[batch_size, num_heads, seq_len, head_dim]`. | ||
""" | ||
|
||
def __init__(self) -> None: | ||
self.key_cache: List[torch.Tensor] = [] | ||
self.value_cache: List[torch.Tensor] = [] | ||
self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen | ||
|
||
def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]: | ||
""" | ||
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the | ||
sequence length. | ||
""" | ||
if layer_idx < len(self): | ||
return (self.key_cache[layer_idx], self.value_cache[layer_idx]) | ||
else: | ||
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") | ||
|
||
def __iter__(self): | ||
""" | ||
Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over | ||
keys and values | ||
""" | ||
for layer_idx in range(len(self)): | ||
yield (self.key_cache[layer_idx], self.value_cache[layer_idx]) | ||
|
||
def __len__(self): | ||
""" | ||
Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds | ||
to the number of layers in the model. | ||
""" | ||
return len(self.key_cache) | ||
|
||
def update( | ||
self, | ||
key_states: torch.Tensor, | ||
value_states: torch.Tensor, | ||
layer_idx: int, | ||
cache_kwargs: Optional[Dict[str, Any]] = None, | ||
) -> Tuple[torch.Tensor, torch.Tensor]: | ||
""" | ||
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. | ||
Parameters: | ||
key_states (`torch.Tensor`): | ||
The new key states to cache. | ||
value_states (`torch.Tensor`): | ||
The new value states to cache. | ||
layer_idx (`int`): | ||
The index of the layer to cache the states for. | ||
cache_kwargs (`Dict[str, Any]`, `optional`): | ||
Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. | ||
Return: | ||
A tuple containing the updated key and value states. | ||
""" | ||
# Update the number of seen tokens | ||
if layer_idx == 0: | ||
self.seen_tokens += key_states.shape[-2] | ||
|
||
# Update the cache | ||
if len(self.key_cache) <= layer_idx: | ||
self.key_cache.append(key_states) | ||
self.value_cache.append(value_states) | ||
else: | ||
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) | ||
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) | ||
|
||
return self.key_cache[layer_idx], self.value_cache[layer_idx] | ||
|
||
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: | ||
"""Returns the sequence length of the cached states. A layer index can be optionally passed.""" | ||
if len(self.key_cache) <= layer_idx: | ||
return 0 | ||
return self.key_cache[layer_idx].shape[-2] | ||
|
||
def reorder_cache(self, beam_idx: torch.LongTensor): | ||
"""Reorders the cache for beam search, given the selected beam indices.""" | ||
for layer_idx in range(len(self.key_cache)): | ||
device = self.key_cache[layer_idx].device | ||
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) | ||
device = self.value_cache[layer_idx].device | ||
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) | ||
|
||
def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: | ||
"""Converts the `DynamicCache` instance into its equivalent in the legacy cache format.""" | ||
legacy_cache = () | ||
for layer_idx in range(len(self)): | ||
legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) | ||
return legacy_cache | ||
|
||
@classmethod | ||
def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": | ||
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`.""" | ||
cache = cls() | ||
if past_key_values is not None: | ||
for layer_idx in range(len(past_key_values)): | ||
key_states, value_states = past_key_values[layer_idx] | ||
cache.update(key_states, value_states, layer_idx) | ||
return cache | ||
|
||
|
||
class SinkCache(Cache): | ||
""" | ||
A cache as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to | ||
generate beyond the length of its context window, without losing fluency in the conversation. As it discards past | ||
tokens, the model will lose the ability to generate tokens that depend on the context that was discarded. | ||
It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is | ||
`[batch_size, num_heads, seq_len, head_dim]`. | ||
Parameters: | ||
window_length (`int`): | ||
The length of the context window. | ||
num_sink_tokens (`int`): | ||
The number of sink tokens. See the original paper for more information. | ||
""" | ||
|
||
def __init__(self, window_length: int, num_sink_tokens: int) -> None: | ||
self.key_cache: List[torch.Tensor] = [] | ||
self.value_cache: List[torch.Tensor] = [] | ||
self.window_length = window_length | ||
self.num_sink_tokens = num_sink_tokens | ||
self.cos_sin_cache = {} | ||
self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen | ||
|
||
@staticmethod | ||
def _rotate_half(x): | ||
x1 = x[..., : x.shape[-1] // 2] | ||
x2 = x[..., x.shape[-1] // 2 :] | ||
return torch.cat((-x2, x1), dim=-1) | ||
|
||
def _apply_key_rotary_pos_emb( | ||
self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor | ||
) -> torch.Tensor: | ||
rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin) | ||
return rotated_key_states | ||
|
||
def _get_rerotation_cos_sin( | ||
self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor | ||
) -> Tuple[torch.Tensor, torch.Tensor]: | ||
if key_states.shape[-2] not in self.cos_sin_cache: | ||
# Upcast to float32 temporarily for better accuracy | ||
cos = cos.to(torch.float32) | ||
sin = sin.to(torch.float32) | ||
|
||
# Compute the cos and sin required for back- and forward-rotating to `key_states.shape[-2]` positions earlier in the sequence | ||
original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :] | ||
shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]] | ||
original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :] | ||
shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]] | ||
rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin | ||
rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin | ||
|
||
self.cos_sin_cache[key_states.shape[-2]] = ( | ||
rerotation_cos.to(key_states.dtype).unsqueeze(0), | ||
rerotation_sin.to(key_states.dtype).unsqueeze(0), | ||
) | ||
return self.cos_sin_cache[key_states.shape[-2]] | ||
|
||
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: | ||
"""Returns the sequence length of the cached states. A layer index can be optionally passed.""" | ||
# Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length | ||
if len(self.key_cache) <= layer_idx: | ||
return 0 | ||
cache_length = self.key_cache[layer_idx].shape[-2] | ||
return min(cache_length, self.window_length - 1) | ||
|
||
def update( | ||
self, | ||
key_states: torch.Tensor, | ||
value_states: torch.Tensor, | ||
layer_idx: int, | ||
cache_kwargs: Optional[Dict[str, Any]] = None, | ||
) -> Tuple[torch.Tensor, torch.Tensor]: | ||
""" | ||
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. | ||
Parameters: | ||
key_states (`torch.Tensor`): | ||
The new key states to cache. | ||
value_states (`torch.Tensor`): | ||
The new value states to cache. | ||
layer_idx (`int`): | ||
The index of the layer to cache the states for. | ||
cache_kwargs (`Dict[str, Any]`, `optional`): | ||
Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`, | ||
`cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the | ||
rotation as the tokens are shifted. | ||
Return: | ||
A tuple containing the updated key and value states. | ||
""" | ||
# Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models | ||
# with partially rotated position embeddings, like Phi or Persimmon. | ||
cache_kwargs = cache_kwargs or {}  # `cache_kwargs` defaults to `None`, which has no `.get` | ||
sin = cache_kwargs.get("sin") | ||
cos = cache_kwargs.get("cos") | ||
partial_rotation_size = cache_kwargs.get("partial_rotation_size") | ||
using_rope = cos is not None and sin is not None | ||
|
||
# Update the number of seen tokens | ||
if layer_idx == 0: | ||
self.seen_tokens += key_states.shape[-2] | ||
|
||
# [bsz, num_heads, seq_len, head_dim] | ||
if len(self.key_cache) <= layer_idx: | ||
# Empty cache | ||
self.key_cache.append(key_states) | ||
self.value_cache.append(value_states) | ||
|
||
elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length: | ||
# Growing cache | ||
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) | ||
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) | ||
|
||
else: | ||
# Shifting cache | ||
keys_to_keep = self.key_cache[layer_idx][ | ||
:, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] : | ||
] | ||
|
||
# On RoPE models, we need to recompute the Key rotation as the tokens are shifted | ||
if using_rope: | ||
rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(key_states, cos, sin) | ||
if partial_rotation_size is not None: | ||
keys_to_keep, keys_pass = ( | ||
keys_to_keep[..., :partial_rotation_size], | ||
keys_to_keep[..., partial_rotation_size:], | ||
) | ||
keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin) | ||
if partial_rotation_size is not None: | ||
keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1) | ||
|
||
# Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens | ||
sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens] | ||
self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2) | ||
|
||
sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens] | ||
values_to_keep = self.value_cache[layer_idx][ | ||
:, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] : | ||
] | ||
self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2) | ||
|
||
return self.key_cache[layer_idx], self.value_cache[layer_idx] | ||
|
||
def reorder_cache(self, beam_idx: torch.LongTensor): | ||
"""Reorders the cache for beam search, given the selected beam indices.""" | ||
for layer_idx in range(len(self.key_cache)): | ||
device = self.key_cache[layer_idx].device | ||
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) | ||
device = self.value_cache[layer_idx].device | ||
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) |
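
The three branches of `SinkCache.update` (empty, growing, shifting) can be illustrated without tensors. The sketch below is a simplified stand-in, not the code from this commit: a plain list of token ids replaces the sequence axis of the key and value tensors, RoPE re-rotation is left out, and the function name `sink_update` is hypothetical. The branch conditions and slice bounds mirror the ones in the diff.

```python
from typing import List


def sink_update(
    cache: List[int], new: List[int], window_length: int, num_sink_tokens: int
) -> List[int]:
    """List-based stand-in for the sequence axis handled by `SinkCache.update`."""
    if not cache:
        # Empty cache: store the new entries as they are.
        return list(new)

    # Mirrors `get_seq_length`, which caps the reported length at window_length - 1.
    seq_len = min(len(cache), window_length - 1)
    if len(new) + seq_len < window_length:
        # Growing cache: append until the window is full.
        return cache + new

    # Shifting cache: keep the sink tokens, drop the oldest non-sink entries,
    # then append the new entries.
    to_keep = cache[-window_length + num_sink_tokens + len(new):]
    return cache[:num_sink_tokens] + to_keep + new


cache: List[int] = []
for token_id in range(10):
    cache = sink_update(cache, [token_id], window_length=6, num_sink_tokens=2)
```

With a window of 6 and 2 sink tokens, feeding token ids 0 to 9 one at a time leaves `[0, 1, 6, 7, 8, 9]`: the first two entries are never evicted, and the rest is a sliding window over the most recent tokens.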