@@ -33,7 +33,6 @@
 from ..modeling_all_models import KVCache, Matmul, apply_customized_rope_module
 from .configuration_llama import LlamaConfig

-
 try:
     from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE  # noqa

@@ -58,7 +57,6 @@

 import habana_frameworks.torch.core as htcore

-
 def gaudi_llama_rmsnorm_forward(self, hidden_states):
     """
     Copied from LlamaRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
@@ -384,6 +382,22 @@ def forward(
             padding_side,
         )

+class LlamaKVCache(KVCache):
+    @staticmethod
+    def update(prev, cur, dim, idx, inp_seq_len):
+        orig_cur = cur
+        if prev.shape == cur.shape:
+            prev.copy_(cur)
+            return orig_cur
+        if idx is not None and cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
+            # Initialize
+            prev[:, :, :inp_seq_len, :].copy_(cur)
+            return orig_cur
+        if idx is not None:
+            prev.index_copy_(dim, idx - 1, cur)
+            return prev
+        else:
+            return torch.cat((prev, cur), dim=dim)

 def GaudiDistributedAttention(fused_scaled_dot_product_attention, fused_scaled_dot_product_attention_distributed):
     if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
@@ -398,8 +412,8 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):

         self.matmul_qk = Matmul()
         self.matmul_av = Matmul()
-        self.k_cache = KVCache()
-        self.v_cache = KVCache()
+        self.k_cache = LlamaKVCache()
+        self.v_cache = LlamaKVCache()

         if hasattr(config, "fused_qkv") and config.fused_qkv:
             self.num_heads = config.num_attention_heads
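
For reference, the added `LlamaKVCache.update` covers three cases: when the incoming tensor already has the cache's shape it overwrites the cache in place; when a multi-token chunk arrives and fits inside the pre-allocated cache (prefill) it is copied into the first `inp_seq_len` positions; and when a single decode-step tensor arrives it is scattered at position `idx - 1` via `index_copy_`, with `torch.cat` as the fallback when no index is given. Below is a minimal CPU-only sketch of those semantics, not the PR itself: the logic is re-declared as a free function (no `KVCache` base class), and the toy shapes and the 1-based token index are illustrative assumptions.

import torch

# Illustrative re-declaration of the update logic; the real code subclasses
# optimum-habana's KVCache and runs on HPU tensors.
def kv_update(prev, cur, dim, idx, inp_seq_len):
    orig_cur = cur
    if prev.shape == cur.shape:
        # Update covers the whole cache: overwrite in place.
        prev.copy_(cur)
        return orig_cur
    if idx is not None and cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
        # Prefill: write the prompt's keys/values into the first inp_seq_len
        # slots of the pre-allocated cache.
        prev[:, :, :inp_seq_len, :].copy_(cur)
        return orig_cur
    if idx is not None:
        # Decode: scatter the single new position at index idx - 1.
        prev.index_copy_(dim, idx - 1, cur)
        return prev
    else:
        # No index: let the cache grow by concatenation.
        return torch.cat((prev, cur), dim=dim)

# Toy shapes: (batch=1, heads=2, seq, head_dim=4); cache pre-allocated for 8 positions.
cache = torch.zeros(1, 2, 8, 4)
prompt_kv = torch.randn(1, 2, 3, 4)            # 3-token prompt -> prefill branch
kv_update(cache, prompt_kv, dim=2, idx=torch.tensor([3]), inp_seq_len=3)

step_kv = torch.randn(1, 2, 1, 4)              # one decode step, 1-based token index 4
kv_update(cache, step_kv, dim=2, idx=torch.tensor([4]), inp_seq_len=3)
assert torch.equal(cache[:, :, 3, :], step_kv[:, :, 0, :])

In `GaudiLlamaAttention.__init__` the only change is that `self.k_cache` and `self.v_cache` are now instances of this subclass, so the surrounding cache allocation and attention code presumably keeps calling `update` as before.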