Commit
Merge branch 'main' into enable-ort-gpu-tests
Showing 7 changed files with 273 additions and 116 deletions.
@@ -0,0 +1,95 @@
import logging
from typing import Any, Dict, Optional, Tuple

import torch


logger = logging.getLogger(__name__)


# Simply removing the nn.Module, same as in https://github.com/huggingface/transformers/pull/35873
class TraceableCache:
    """
    Base, abstract class for all caches. The actual data structure is specific to each subclass.
    """

    def __init__(self):
        super().__init__()

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
                cache to be created.
        Return:
            A tuple containing the updated key and value states.
        """
        raise NotImplementedError("Make sure to implement `update` in a subclass.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # TODO: deprecate this function in favor of `cache_position`
        raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

    # Deprecated in favor of `get_max_cache_shape` because we want to be specific about what we mean by "max_length".
    # Previously, some cache objects did not have a "max_length" (SlidingWindowCache or SinkCache) because the cache
    # object technically handles an infinite amount of tokens. What we really need to check in the codebase is the max
    # capacity of certain cache instances, so we change the naming to be more explicit.
    def get_max_length(self) -> Optional[int]:
        logger.warning_once(
            "`get_max_length()` is deprecated for all Cache classes. Use `get_max_cache_shape()` instead. "
            "Calling `get_max_length()` will raise an error from v4.48"
        )
        return self.get_max_cache_shape()

    def get_max_cache_shape(self) -> Optional[int]:
        """Returns the maximum sequence length (i.e. max capacity) of the cache object"""
        raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.")

    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        """Given the sequence length of the new inputs, returns the usable length of the cache."""
        # Cache without size limit -> all cache is usable
        # Cache with size limit -> if the cached length plus the length of the new inputs is larger than the maximum
        # cache length, we will need to evict part of the cache (and thus not all cache is usable)
        max_length = self.get_max_cache_shape()
        previous_seq_length = self.get_seq_length(layer_idx)
        if max_length is not None and previous_seq_length + new_seq_length > max_length:
            return max_length - new_seq_length
        return previous_seq_length
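    # Illustrative note, not part of the original diff: with a bounded cache whose max capacity is 8 and which
    # already holds 6 tokens, requesting room for 4 new tokens returns 8 - 4 = 4 usable positions (2 of the old
    # entries will be evicted); with an unbounded cache (max_length is None), the full previous length of 6 is
    # returned.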

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            if self.key_cache[layer_idx] != []:
                device = self.key_cache[layer_idx].device
                self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            if self.value_cache[layer_idx] != []:
                device = self.value_cache[layer_idx].device
                self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
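    # Illustrative note, not part of the original diff: during beam search the cached key/value tensors have shape
    # (batch_size * num_beams, num_heads, seq_len, head_dim), so `index_select` along dim 0 keeps, for every layer,
    # only the cache rows belonging to the beams selected by `beam_idx`.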

    @property
    def seen_tokens(self):
        logger.warning_once(
            "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
            "model input instead."
        )
        if hasattr(self, "_seen_tokens"):
            return self._seen_tokens
        else:
            return None
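To make the intended usage concrete, here is a minimal sketch of a concrete subclass built on the `TraceableCache` above. It mirrors the grow-as-you-go behaviour of a dynamic cache in `transformers`, but the class name `SimpleDynamicCache`, the per-layer list layout, and the example shapes are illustrative assumptions, not part of this commit.

from typing import Any, Dict, List, Optional, Tuple

import torch


class SimpleDynamicCache(TraceableCache):
    """Unbounded cache: concatenates new key/value states along the sequence axis for each layer."""

    def __init__(self):
        super().__init__()
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self._seen_tokens = 0  # read by the deprecated `seen_tokens` property of the base class

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if layer_idx == 0:
            # Count new tokens only once per forward pass, not once per layer.
            self._seen_tokens += key_states.shape[-2]
        if len(self.key_cache) <= layer_idx:
            # First call for this layer: store the states as-is.
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            # Later calls: append along the sequence dimension.
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_cache_shape(self) -> Optional[int]:
        return None  # no preallocated capacity, so there is no maximum shape


# Example usage; tensors follow the usual (batch, num_heads, seq_len, head_dim) layout.
cache = SimpleDynamicCache()
keys = torch.randn(1, 8, 4, 64)
values = torch.randn(1, 8, 4, 64)
cache.update(keys, values, layer_idx=0)
print(cache.get_seq_length(0))        # 4
print(cache.get_usable_length(3, 0))  # 4: unbounded cache, so the whole cache is usable

A bounded variant (for example a sliding-window cache) would instead return a finite value from `get_max_cache_shape()`, which is exactly what `get_usable_length` relies on to decide how much of the existing cache remains usable.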