fix llama rope (wenet-e2e#2459)
Mddct authored Apr 5, 2024
1 parent 4d12918 commit 648fee8
Showing 1 changed file with 0 additions and 12 deletions.
wenet/utils/rope_utils.py (0 additions, 12 deletions)
@@ -26,20 +26,8 @@ def google_apply_rotary_emb(x: torch.Tensor,
     return x_out
 
 
-def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
-    ndim = x.ndim
-    assert 0 <= 1 < ndim
-    assert freqs_cis.shape[2:] == (x.shape[1], x.shape[-1])
-    # 2 is seq_len in wenet
-    shape = [
-        d if i == 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)
-    ]
-    return freqs_cis.view(*shape)
-
-
 def llama_apply_rotary_emb(x: torch.Tensor,
                            freqs_cis: torch.Tensor) -> torch.Tensor:
     x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
-    freqs_cis = reshape_for_broadcast(freqs_cis, x_)
     x_out = torch.view_as_real(x_ * freqs_cis).flatten(3)
     return x_out.type_as(x)
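For context, a minimal usage sketch of the post-fix behavior (not part of the commit): with reshape_for_broadcast gone, llama_apply_rotary_emb now expects a freqs_cis that already broadcasts against the complex view of x. The tensor layout and the precomputation below follow the standard llama RoPE recipe and are assumptions for illustration, not code from wenet.

import torch
from wenet.utils.rope_utils import llama_apply_rotary_emb

# Assumed layout: x is (batch, n_head, seq_len, head_dim), with seq_len at
# dim 2, matching the "# 2 is seq_len in wenet" comment in the deleted helper.
batch, n_head, seq_len, head_dim = 2, 4, 16, 64

# Hypothetical precomputation (standard llama recipe, not from this repo):
# complex rotations e^{i * t * theta_k} per position t and frequency theta_k.
base = 10000.0
freqs = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.outer(torch.arange(seq_len).float(), freqs)  # (seq_len, head_dim // 2)
freqs_cis = torch.polar(torch.ones_like(angles), angles)

# After this commit the caller must hand over a broadcast-ready freqs_cis;
# (1, 1, seq_len, head_dim // 2) lines up with the complex view (B, H, T, D/2).
freqs_cis = freqs_cis[None, None, :, :]

x = torch.randn(batch, n_head, seq_len, head_dim)
x_out = llama_apply_rotary_emb(x, freqs_cis)
assert x_out.shape == x.shape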
