Commit f8b496a

jiminha authored and regisss committed
Revert common KVCache not to check token_idx (huggingface#1594)
1 parent d3973e0 commit f8b496a

2 files changed: +20 −5 lines changed

optimum/habana/transformers/models/llama/modeling_llama.py (+18 −4)

@@ -33,7 +33,6 @@
 from ..modeling_all_models import KVCache, Matmul, apply_customized_rope_module
 from .configuration_llama import LlamaConfig
 
-
 try:
     from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE  # noqa
 
@@ -58,7 +57,6 @@
 
 import habana_frameworks.torch.core as htcore
 
-
 def gaudi_llama_rmsnorm_forward(self, hidden_states):
     """
     Copied from LlamaRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
@@ -384,6 +382,22 @@ def forward(
             padding_side,
         )
 
+class LlamaKVCache(KVCache):
+    @staticmethod
+    def update(prev, cur, dim, idx, inp_seq_len):
+        orig_cur = cur
+        if prev.shape == cur.shape:
+            prev.copy_(cur)
+            return orig_cur
+        if idx is not None and cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
+            # Initialize
+            prev[:, :, :inp_seq_len, :].copy_(cur)
+            return orig_cur
+        if idx is not None:
+            prev.index_copy_(dim, idx - 1, cur)
+            return prev
+        else:
+            return torch.cat((prev, cur), dim=dim)
 
 def GaudiDistributedAttention(fused_scaled_dot_product_attention, fused_scaled_dot_product_attention_distributed):
     if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
@@ -398,8 +412,8 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
 
         self.matmul_qk = Matmul()
         self.matmul_av = Matmul()
-        self.k_cache = KVCache()
-        self.v_cache = KVCache()
+        self.k_cache = LlamaKVCache()
+        self.v_cache = LlamaKVCache()
 
         if hasattr(config, "fused_qkv") and config.fused_qkv:
             self.num_heads = config.num_attention_heads
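The re-added LlamaKVCache keeps the token_idx-aware update for Llama only: when the incoming keys/values match the pre-allocated buffer they are copied wholesale, a multi-token prompt fills the first inp_seq_len slots, a single decode token is written in place at idx - 1, and a missing idx falls back to concatenation. Below is a minimal standalone sketch of that flow; the sketch_update name, the buffer shape (batch, heads, max_seq_len, head_dim) = (1, 2, 8, 4), and the example token_idx values are illustrative assumptions, not part of the commit.

import torch

def sketch_update(prev, cur, dim, idx, inp_seq_len):
    # Standalone reproduction of LlamaKVCache.update from the diff above.
    orig_cur = cur
    if prev.shape == cur.shape:
        prev.copy_(cur)
        return orig_cur
    if idx is not None and cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
        # Prefill: copy the prompt into the first inp_seq_len slots of the static buffer.
        prev[:, :, :inp_seq_len, :].copy_(cur)
        return orig_cur
    if idx is not None:
        # Decode: write the single new token at slot idx - 1.
        prev.index_copy_(dim, idx - 1, cur)
        return prev
    # No token_idx: fall back to dynamic-shape concatenation.
    return torch.cat((prev, cur), dim=dim)

# Example shapes: (batch, num_heads, max_seq_len, head_dim) with a 3-token prompt.
cache = torch.zeros(1, 2, 8, 4)
prompt_kv = torch.randn(1, 2, 3, 4)
sketch_update(cache, prompt_kv, dim=2, idx=torch.tensor([3]), inp_seq_len=3)  # prefill
step_kv = torch.randn(1, 2, 1, 4)
sketch_update(cache, step_kv, dim=2, idx=torch.tensor([4]), inp_seq_len=3)    # decode next token

After these two calls the first four slots of cache hold the prompt plus one generated token and the buffer shape never changes, which is the static-shape behaviour the revert preserves for Llama.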

optimum/habana/transformers/models/modeling_all_models.py (+2 −1)

@@ -59,10 +59,11 @@ def update(prev, cur, dim, idx, inp_seq_len):
         if prev.shape == cur.shape:
             prev.copy_(cur)
             return orig_cur
-        if idx is not None and cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
+        if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
             # Initialize
             prev[:, :, :inp_seq_len, :].copy_(cur)
             return orig_cur
+        assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}"
         if idx is not None:
             prev.index_copy_(dim, idx - 1, cur)
             return prev
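With this change the common KVCache.update no longer looks at token_idx when deciding whether to initialize the buffer, and it guards the remaining path with an assert instead of silently falling through. A small standalone sketch of the new behaviour follows; the common_update name, the example shapes, and the final torch.cat fallback (assumed to mirror the Llama copy above, since this hunk does not show the end of the function) are illustrative assumptions.

import torch

def common_update(prev, cur, dim, idx, inp_seq_len):
    # Standalone reproduction of the reverted common KVCache.update.
    orig_cur = cur
    if prev.shape == cur.shape:
        prev.copy_(cur)
        return orig_cur
    if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
        # Prefill no longer requires token_idx.
        prev[:, :, :inp_seq_len, :].copy_(cur)
        return orig_cur
    assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}"
    if idx is not None:
        prev.index_copy_(dim, idx - 1, cur)
        return prev
    # Fallback assumed from the Llama copy above (not shown in this hunk).
    return torch.cat((prev, cur), dim=dim)

cache = torch.zeros(1, 2, 8, 4)
prompt_kv = torch.randn(1, 2, 3, 4)
common_update(cache, prompt_kv, dim=2, idx=None, inp_seq_len=3)     # now takes the prefill path even without token_idx
too_long = torch.randn(1, 2, 16, 4)
# common_update(cache, too_long, dim=2, idx=None, inp_seq_len=16)   # would now raise AssertionError

The commented-out call shows the case the new assert rejects: a cur longer than the pre-allocated buffer previously fell through to concatenation, silently growing the supposedly static cache.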
