diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py index 534d90c6ed..3c75b9c75f 100644 --- a/examples/models/llama2/llama_transformer.py +++ b/examples/models/llama2/llama_transformer.py @@ -218,9 +218,9 @@ def __init__( def forward( self, input_pos: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, + q: torch.Tensor, # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim) + k: torch.Tensor, # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim) + v: torch.Tensor, # (bs, seqlen, n_local_kv_heads, head_dim) bsz, seqlen, mask: torch.Tensor,