From ef316082ce745b13ac698cbc28327d5a18abfb4e Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Sun, 15 Sep 2024 23:59:50 -0700 Subject: [PATCH] Add sdpa arg comments (#5323) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/5323 Reviewed By: JacobSzwejbka Differential Revision: D62623249 Pulled By: dvorjackz fbshipit-source-id: 468abd913a4dcb9b2474ec34881cfcec2654a024 --- examples/models/llama2/llama_transformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py index 534d90c6ed..3c75b9c75f 100644 --- a/examples/models/llama2/llama_transformer.py +++ b/examples/models/llama2/llama_transformer.py @@ -218,9 +218,9 @@ def __init__( def forward( self, input_pos: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, + q: torch.Tensor, # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim) + k: torch.Tensor, # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim) + v: torch.Tensor, # (bs, seqlen, n_local_kv_heads, head_dim) bsz, seqlen, mask: torch.Tensor,