Skip to content

Commit

Permalink
Inherit StaticCacheXLA from StaticCache instead for compatibility with…
Browse files Browse the repository at this point in the history
… isinstance(past_key_value, StaticCache)
  • Loading branch information
huzama committed Jun 10, 2024
1 parent 9d67ac1 commit 71efd4c
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,7 @@ def reset(self):
self.key_cache.zero_()
self.value_cache.zero_()

class StaticCacheXLA(Cache):
class StaticCacheXLA(StaticCache):
"""
Static Cache class to be used with `torch.compile(model)`.
Expand All @@ -953,7 +953,7 @@ class StaticCacheXLA(Cache):
"""

def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
super().__init__()
super().__init__(config, max_batch_size, max_cache_len, device, dtype)
self.max_batch_size = max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
Expand Down

0 comments on commit 71efd4c

Please sign in to comment.