Skip to content

Commit

Permalink
SlidingWindowCache: reduce differences to other Cache classes (#30970)
Browse files Browse the repository at this point in the history
* tmp commit

* sliding window with fewer differences

* make fixup + rebase

* missing overwrite
  • Loading branch information
gante authored Jun 3, 2024
1 parent 221aaec commit d475f76
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 66 deletions.
81 changes: 25 additions & 56 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,22 +829,22 @@ def reset(self):
self.value_cache[layer_idx].zero_()


class SlidingWindowCache(Cache):
class SlidingWindowCache(StaticCache):
"""
Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention.
Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window_size - 1`,
Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window - 1`,
if true(which means the cache can not hold all the old key value states and new states together because of the sliding window constraint),
we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in.
The `to_shift` is only true once we are above sliding_window_size. Thus with `sliding_window_size==64`:
The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`:
indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window_size
indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
55, 56, 57, 58, 59, 60, 61, 62, 63, 0])
We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window_size`)
We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window`)
Parameters:
config (`PretrainedConfig):
Expand All @@ -866,38 +866,11 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len:
"sliding window attention, please check if there is a `sliding_window` field in the model "
"config and it's not set to None."
)

super().__init__()
self.max_batch_size = max_batch_size
# take the minimum of max_cache_len and config.sliding_window so that we allocate less memory
# when we do short-sentence generation
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.model_sliding_window_size = config.sliding_window
self.sliding_window_size = min(self.max_cache_len, self.model_sliding_window_size)
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
max_cache_len = min(config.sliding_window, max_cache_len)
super().__init__(
config=config, max_batch_size=max_batch_size, max_cache_len=max_cache_len, device=device, dtype=dtype
)

self.dtype = dtype if dtype is not None else torch.float32
self.num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)

cache_shape = (
config.num_hidden_layers,
max_batch_size,
self.num_key_value_heads,
self.sliding_window_size,
self.head_dim,
)

self.key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)

torch._dynamo.mark_static_address(self.key_cache)
torch._dynamo.mark_static_address(self.value_cache)

def update(
self,
key_states: torch.Tensor,
Expand All @@ -909,42 +882,38 @@ def update(
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]

# assume this only happens in prefill phase when prompt length > sliding_window_size
if cache_position.shape[0] > self.sliding_window_size:
k_out = key_states[:, :, -self.sliding_window_size :, :]
v_out = value_states[:, :, -self.sliding_window_size :, :]
self.key_cache[layer_idx] = k_out
self.value_cache[layer_idx] = v_out
# assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len)
if cache_position.shape[0] > self.max_cache_len:
k_out = key_states[:, :, -self.max_cache_len :, :]
v_out = value_states[:, :, -self.max_cache_len :, :]
# Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
self.key_cache[layer_idx] += k_out
self.value_cache[layer_idx] += v_out
# we should return the whole states instead of k_out, v_out to take the whole prompt
# into consideration when building kv cache instead of just throwing away tokens outside of the window
return key_states, value_states

slicing = torch.ones(self.sliding_window_size, dtype=torch.long, device=value_states.device).cumsum(0)
cache_position = cache_position.clamp(0, self.sliding_window_size - 1)
to_shift = cache_position >= self.sliding_window_size - 1
indices = (slicing + to_shift[-1].int() - 1) % self.sliding_window_size
slicing = torch.ones(self.max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
cache_position = cache_position.clamp(0, self.max_cache_len - 1)
to_shift = cache_position >= self.max_cache_len - 1
indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len

k_out = k_out[:, :, indices]
v_out = v_out[:, :, indices]

k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states

self.key_cache[layer_idx] = k_out
self.value_cache[layer_idx] = v_out
# `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()

return k_out, v_out
self.key_cache[layer_idx] += k_out
self.value_cache[layer_idx] += v_out

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
# assume this will be called only in the first generation step
# `cache_postion` will be used in other cases
return 0
return k_out, v_out

def get_max_length(self) -> Optional[int]:
# in theory there is no limit because the sliding window size is fixed
# no matter how long the sentence is
return None

def reset(self):
self.key_cache.zero_()
self.value_cache.zero_()
13 changes: 5 additions & 8 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,18 +1368,15 @@ def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_l
Returns the resulting cache object.
"""
cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
if cache_implementation == "sliding_window":
max_cache_len = min(self.config.sliding_window, max_cache_len)

need_new_cache = (
not hasattr(self, "_cache")
or (not isinstance(self._cache, cache_cls))
or self._cache.max_batch_size < max_batch_size
or self._cache.max_batch_size != max_batch_size
or self._cache.max_cache_len < max_cache_len
)
if cache_implementation == "sliding_window":
need_new_cache = need_new_cache or (
self._cache.sliding_window_size < self._cache.model_sliding_window_size
and max_cache_len > self._cache.max_cache_len
)
elif cache_implementation == "static":
need_new_cache = need_new_cache or self._cache.max_cache_len < max_cache_len

if need_new_cache:
if hasattr(self.config, "_pre_quantization_dtype"):
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/mistral/modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -1290,9 +1290,9 @@ def prepare_inputs_for_generation(
past_length > 0
and attention_mask is not None
and isinstance(past_key_values, SlidingWindowCache)
and attention_mask.shape[1] > past_key_values.sliding_window_size
and attention_mask.shape[1] > past_key_values.max_cache_len
):
attention_mask = attention_mask[:, -past_key_values.sliding_window_size :]
attention_mask = attention_mask[:, -past_key_values.max_cache_len :]

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
Expand Down

0 comments on commit d475f76

Please sign in to comment.