Skip to content

Commit

Permalink
fix #4160
Browse files Browse the repository at this point in the history
The split heads should be concatenated along dim=2 (the head dimension), not the default dim=0.
  • Loading branch information
hiyouga committed Jun 10, 2024
1 parent 949e990 commit a793e84
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/llamafactory/model/model_utils/longlora.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ def shift(state: torch.Tensor) -> torch.Tensor:
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
)
),
dim=2,
)

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
Expand Down Expand Up @@ -194,7 +195,8 @@ def shift(state: torch.Tensor) -> torch.Tensor:
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
)
),
dim=2,
)

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
Expand Down Expand Up @@ -293,7 +295,8 @@ def shift(state: torch.Tensor) -> torch.Tensor:
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
)
),
dim=2,
)

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
Expand All @@ -303,7 +306,7 @@ def shift(state: torch.Tensor) -> torch.Tensor:


def _apply_llama_patch() -> None:
require_version("transformers==4.40.2", "To fix: pip install transformers==4.40.2")
require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2")
LlamaAttention.forward = llama_attention_forward
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
LlamaSdpaAttention.forward = llama_sdpa_attention_forward
Expand Down

0 comments on commit a793e84

Please sign in to comment.