diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index 72ce034ef1..55d4475a87 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -30,9 +30,10 @@ from ...modeling_attn_mask_utils import (
     _gaudi_prepare_4d_causal_attention_mask,
 )
-from ..modeling_all_models import KVCache, Matmul, apply_customized_rope_module
+from ..modeling_all_models import Matmul, apply_customized_rope_module
 from .configuration_llama import LlamaConfig
 
+
 try:
     from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE  # noqa
@@ -57,6 +58,7 @@ import habana_frameworks.torch.core as htcore
 
 
+
 def gaudi_llama_rmsnorm_forward(self, hidden_states):
     """
     Copied from LlamaRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
@@ -382,7 +384,23 @@ def forward(
             padding_side,
         )
 
-class LlamaKVCache(KVCache):
+
+class KVCache(torch.nn.Module):
+    def __init__(self):
+        super(KVCache, self).__init__()
+        self.cache = None
+        self.inp_seq_len = -1
+
+    def allocate(self, inp_seq_len, dtype, device, shape):
+        if self.cache is None or self.cache.shape != shape:
+            self.inp_seq_len = inp_seq_len
+            self.cache = torch.zeros(shape, dtype=dtype, device=device)
+        else:
+            assert (
+                self.inp_seq_len == inp_seq_len
+            ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
+            self.cache.fill_(0)
+
     @staticmethod
     def update(prev, cur, dim, idx, inp_seq_len):
         orig_cur = cur
@@ -399,6 +417,15 @@ def update(prev, cur, dim, idx, inp_seq_len):
         else:
             return torch.cat((prev, cur), dim=dim)
 
+    def get_shape(self):
+        if self.cache is None:
+            return None
+        return self.cache.shape
+
+    def forward(self, cur, dim, idx):
+        return self.update(self.cache, cur, dim, idx, self.inp_seq_len)
+
+
 def GaudiDistributedAttention(fused_scaled_dot_product_attention, fused_scaled_dot_product_attention_distributed):
     if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
         return fused_scaled_dot_product_attention_distributed
@@ -412,8 +439,8 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
         self.matmul_qk = Matmul()
         self.matmul_av = Matmul()
 
-        self.k_cache = LlamaKVCache()
-        self.v_cache = LlamaKVCache()
+        self.k_cache = KVCache()
+        self.v_cache = KVCache()
 
         if hasattr(config, "fused_qkv") and config.fused_qkv:
             self.num_heads = config.num_attention_heads
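
For context on what the new class is doing (this sketch is not part of the patch), the standalone example below illustrates the pre-allocate-then-index_copy_ pattern the KVCache is built around: the key/value tensor is allocated once at the maximum sequence length, the prompt is copied into the leading slots, and each decode step writes a single position in place so tensor shapes stay static across steps. The function name, shapes, and token-index handling here are illustrative assumptions; the elided middle of update() is not reproduced.

import torch


def demo_kv_cache_pattern():
    # Illustrative shapes only: (batch, heads, max_seq_len, head_dim).
    batch, heads, max_len, head_dim = 1, 2, 8, 4
    dtype, device = torch.float32, "cpu"
    seq_dim = 2  # sequence-length dimension of the cache

    # Prefill phase: allocate the full-size cache once and copy the prompt
    # keys/values into the first inp_seq_len slots.
    cache = torch.zeros((batch, heads, max_len, head_dim), dtype=dtype, device=device)
    inp_seq_len = 3
    prompt_kv = torch.randn(batch, heads, inp_seq_len, head_dim, dtype=dtype, device=device)
    cache[:, :, :inp_seq_len, :].copy_(prompt_kv)

    # Decode phase: write one new position in place with index_copy_ instead of
    # concatenating a growing tensor.
    token_idx = torch.tensor([inp_seq_len], device=device)  # next free slot (assumed indexing)
    new_kv = torch.randn(batch, heads, 1, head_dim, dtype=dtype, device=device)
    cache.index_copy_(seq_dim, token_idx, new_kv)

    # Fallback when no slot index is tracked: grow the cache dynamically,
    # mirroring the torch.cat branch visible at the end of update().
    grown = torch.cat((cache, new_kv), dim=seq_dim)
    return cache, grown


if __name__ == "__main__":
    demo_kv_cache_pattern()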