Commit

Merge branch 'main' into enable-ort-gpu-tests
IlyasMoutawwakil authored Jan 29, 2025
2 parents 24d682e + d1bcdf7 commit def5fdb
Showing 7 changed files with 273 additions and 116 deletions.
95 changes: 95 additions & 0 deletions optimum/exporters/onnx/_traceable_cache.py
@@ -0,0 +1,95 @@
import logging
from typing import Any, Dict, Optional, Tuple

import torch


logger = logging.getLogger(__name__)


# Same as the transformers `Cache` base class but without inheriting from nn.Module, as in https://github.com/huggingface/transformers/pull/35873
class TraceableCache:
    """
    Base, abstract class for all caches. The actual data structure is specific to each subclass.
    """

    def __init__(self):
        super().__init__()

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
                cache to be created.

        Return:
            A tuple containing the updated key and value states.
        """
        raise NotImplementedError("Make sure to implement `update` in a subclass.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # TODO: deprecate this function in favor of `cache_position`
        raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

    # Deprecated in favor of `get_max_cache_shape` because we want to be specific about what we mean by "max_length".
    # Previously, some cache objects (SlidingWindowCache or SinkCache) didn't have a "max_length" because they technically
    # handle an infinite number of tokens. What the codebase really needs to check is the maximum capacity of certain
    # cache instances, so the naming is made more explicit.
    def get_max_length(self) -> Optional[int]:
        logger.warning(
            "`get_max_length()` is deprecated for all Cache classes. Use `get_max_cache_shape()` instead. "
            "Calling `get_max_length()` will raise an error from v4.48"
        )
        return self.get_max_cache_shape()

    def get_max_cache_shape(self) -> Optional[int]:
        """Returns the maximum sequence length (i.e. max capacity) of the cache object."""
        raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.")

    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        """Given the sequence length of the new inputs, returns the usable length of the cache."""
        # Cache without size limit -> all cache is usable
        # Cache with size limit -> if the cached length plus the length of the new inputs is larger than the maximum
        # cache length, we will need to evict part of the cache (and thus not all of the cache is usable)
        max_length = self.get_max_cache_shape()
        previous_seq_length = self.get_seq_length(layer_idx)
        if max_length is not None and previous_seq_length + new_seq_length > max_length:
            return max_length - new_seq_length
        return previous_seq_length

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            if self.key_cache[layer_idx] != []:
                device = self.key_cache[layer_idx].device
                self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            if self.value_cache[layer_idx] != []:
                device = self.value_cache[layer_idx].device
                self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

    @property
    def seen_tokens(self):
        logger.warning(
            "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
            "model input instead."
        )
        if hasattr(self, "_seen_tokens"):
            return self._seen_tokens
        else:
            return None
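
As a point of reference, here is a minimal sketch (not part of this commit) of a concrete cache built on top of TraceableCache, assuming the same concatenation semantics as transformers' DynamicCache; the class name SimpleTraceableCache and its attributes are illustrative only:

# Minimal illustrative subclass (hypothetical, not part of the commit): keys and values are
# concatenated along the sequence dimension on every update, mirroring transformers' DynamicCache.
class SimpleTraceableCache(TraceableCache):
    def __init__(self):
        super().__init__()
        self.key_cache = []    # one tensor per layer
        self.value_cache = []  # one tensor per layer

    def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
        if len(self.key_cache) <= layer_idx:
            # First update for this layer: store the states as-is.
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            # Subsequent updates: append along the sequence dimension (dim=-2).
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx=0):
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_cache_shape(self):
        return None  # no size limit, so get_usable_length simply returns the previous sequence length

Because the class no longer inherits from nn.Module, tracing-based exporters do not try to register the cache as a submodule, which appears to be the motivation for the change referenced above.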
7 changes: 3 additions & 4 deletions optimum/exporters/onnx/model_configs.py
@@ -861,7 +861,7 @@ class DeiTOnnxConfig(ViTOnnxConfig):


class BeitOnnxConfig(ViTOnnxConfig):
-    DEFAULT_ONNX_OPSET = 11
+    DEFAULT_ONNX_OPSET = 14  # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.


class ConvNextOnnxConfig(ViTOnnxConfig):
@@ -1598,13 +1598,12 @@ class Data2VecTextOnnxConfig(DistilBertOnnxConfig):


class Data2VecVisionOnnxConfig(ViTOnnxConfig):
-    DEFAULT_ONNX_OPSET = 11
+    DEFAULT_ONNX_OPSET = 14  # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.


class Data2VecAudioOnnxConfig(AudioOnnxConfig):
-    NORMALIZED_CONFIG_CLASS = NormalizedConfig
    ATOL_FOR_VALIDATION = 1e-4
-    DEFAULT_ONNX_OPSET = 14  # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
+    NORMALIZED_CONFIG_CLASS = NormalizedConfig


class PerceiverDummyInputGenerator(DummyVisionInputGenerator):
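For context on the opset changes above: exporting a model that goes through torch.nn.functional.scaled_dot_product_attention requires ONNX opset 14 or higher, which is why DEFAULT_ONNX_OPSET is bumped from 11 to 14. A minimal sketch (not part of this commit; the module and output file names are made up):

import torch
import torch.nn.functional as F


class SDPABlock(torch.nn.Module):
    # Toy stand-in for an attention block that relies on scaled_dot_product_attention.
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v)


q = k = v = torch.randn(1, 4, 8, 16)
# opset_version=14 matches the new DEFAULT_ONNX_OPSET; exporting this block with opset 11
# fails on recent torch versions because the SDPA operator is only supported from opset 14.
torch.onnx.export(SDPABlock(), (q, k, v), "sdpa_block.onnx", opset_version=14)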
