Add sdpa arg comments (#5323)
Summary: Pull Request resolved: #5323

Reviewed By: JacobSzwejbka

Differential Revision: D62623249

Pulled By: dvorjackz

fbshipit-source-id: 468abd913a4dcb9b2474ec34881cfcec2654a024
jackzhxng authored and facebook-github-bot committed Sep 16, 2024
1 parent 08f16d0 commit ef31608
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions examples/models/llama2/llama_transformer.py
@@ -218,9 +218,9 @@ def __init__
     def forward(
         self,
         input_pos: torch.Tensor,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
+        q: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim)
+        k: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim)
+        v: torch.Tensor,  # (bs, seqlen, n_local_kv_heads, head_dim)
         bsz,
         seqlen,
         mask: torch.Tensor,
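The new comments document the shape contract for the SDPA inputs: q arrives as (bs, seqlen, n_local_heads, head_dim) with rotary embeddings already applied, while k and v use n_local_kv_heads, which can be smaller than n_local_heads under grouped-query attention. A minimal NumPy sketch of that contract (the function name and the KV-head repetition helper logic here are illustrative, not ExecuTorch's actual implementation):

```python
import numpy as np

def sdpa_sketch(q, k, v):
    """Illustrative SDPA over (bs, seqlen, n_heads, head_dim) inputs.

    k/v may carry fewer KV heads (grouped-query attention); they are
    repeated to match q's head count before attention.
    """
    bs, seqlen, n_heads, head_dim = q.shape
    n_kv_heads = k.shape[2]
    assert n_heads % n_kv_heads == 0
    rep = n_heads // n_kv_heads
    # Repeat KV heads so every query head has a matching key/value head.
    k = np.repeat(k, rep, axis=2)
    v = np.repeat(v, rep, axis=2)
    # Move heads before seqlen: (bs, n_heads, seqlen, head_dim).
    q, k, v = (x.transpose(0, 2, 1, 3) for x in (q, k, v))
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim)
    # Numerically stable softmax over the key dimension.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    out = weights @ v  # (bs, n_heads, seqlen, head_dim)
    return out.transpose(0, 2, 1, 3)  # back to (bs, seqlen, n_heads, head_dim)

q = np.random.randn(1, 4, 8, 16)  # (bs, seqlen, n_local_heads, head_dim)
k = np.random.randn(1, 4, 2, 16)  # (bs, seqlen, n_local_kv_heads, head_dim)
v = np.random.randn(1, 4, 2, 16)
print(sdpa_sketch(q, k, v).shape)
```

The output shape matches q's layout, (1, 4, 8, 16), even though k and v carry only 2 KV heads; this is why the comments distinguish n_local_heads from n_local_kv_heads.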
