Skip to content

Commit

Permalink
Use Index_copy method to update static cache inplace and avoid recomp…
Browse files Browse the repository at this point in the history
…ilation during each iteration in XLA
  • Loading branch information
huzama committed May 30, 2024
1 parent 10c06b3 commit 1ad0a9a
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch

from .configuration_utils import PretrainedConfig
from .utils import is_hqq_available, is_quanto_available, logging
from .utils import is_hqq_available, is_quanto_available, logging, is_torch_xla_available

if is_quanto_available():
from quanto import QBitsTensor, qint2, qint4
Expand Down Expand Up @@ -791,6 +791,23 @@ def update(
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]

if is_torch_xla_available(): # If torch_xla is available, do out-of-place operation on KV_Cache and create a new list
k_out = k_out.index_copy(2, cache_position, key_states)
v_out = v_out.index_copy(2, cache_position, value_states)

updated_key_cache = [
k_out if i == layer_idx else self.key_cache[i] for i in range(len(self.key_cache))
]

updated_value_cache = [
v_out if i == layer_idx else self.value_cache[i] for i in range(len(self.value_cache))
]

self.key_cache = updated_key_cache
self.value_cache = updated_value_cache

return k_out, v_out

k_out.index_copy_(2, cache_position, key_states)
v_out.index_copy_(2, cache_position, value_states)

Expand Down

0 comments on commit 1ad0a9a

Please sign in to comment.