Avoid rearranging all caches #1483

Merged (5 commits), Jul 6, 2023
15 changes: 9 additions & 6 deletions whisper/decoding.py
@@ -146,6 +146,10 @@ def __init__(self, model: "Whisper", initial_token_length: int):
        self.kv_cache = {}
        self.hooks = []

+        key_modules = [block.attn.key for block in self.model.decoder.blocks]
+        value_modules = [block.attn.value for block in self.model.decoder.blocks]
+        self.kv_modules = key_modules + value_modules
+
    def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
        if not self.kv_cache:
            self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
@@ -164,9 +168,10 @@ def cleanup_caching(self):
        self.hooks = []

    def rearrange_kv_cache(self, source_indices):
-        for module, tensor in self.kv_cache.items():
-            # update the key/value cache to contain the selected sequences
-            self.kv_cache[module] = tensor[source_indices].detach()
+        if source_indices != list(range(len(source_indices))):
+            for module in self.kv_modules:
+                # update the key/value cache to contain the selected sequences
+                self.kv_cache[module] = self.kv_cache[module][source_indices].detach()


class SequenceRanker:
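For context on the change above: kv_modules covers only the decoder self-attention key/value projections, so the cross-attention caches (identical across beams of the same audio) are no longer re-indexed, and the new identity check skips the copy when the selected beams are already in order. A minimal standalone sketch of that check, using a made-up string key in place of the real module objects:

import torch

# toy stand-in: the real cache is keyed by nn.Module objects, not strings
kv_cache = {"block0.attn.key": torch.arange(15).reshape(5, 3).float()}
kv_modules = list(kv_cache.keys())

def rearrange_kv_cache(source_indices):
    # identity permutation: nothing would move, so skip the indexing copy
    if source_indices != list(range(len(source_indices))):
        for module in kv_modules:
            kv_cache[module] = kv_cache[module][source_indices].detach()

rearrange_kv_cache([0, 1, 2, 3, 4])   # no-op, cache left untouched
rearrange_kv_cache([2, 2, 0, 1, 4])   # rows copied into beam order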
@@ -668,7 +673,6 @@ def _detect_language(self, audio_features: Tensor, tokens: Tensor):
        return languages, lang_probs

    def _main_loop(self, audio_features: Tensor, tokens: Tensor):
-        assert audio_features.shape[0] == tokens.shape[0]
        n_batch = tokens.shape[0]
        sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
        no_speech_probs = [np.nan] * n_batch
@@ -721,8 +725,7 @@ def run(self, mel: Tensor) -> List[DecodingResult]:
                )
            ]

-        # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
-        audio_features = audio_features.repeat_interleave(self.n_group, dim=0)

@wangchou I was wondering: if you remove the repeat of audio_features, where do you repeat the kv_cache for cross attention? Otherwise, during cross attention, q @ k seems to have mismatched dims, since tokens are repeated according to the beam_size.

@jongwook, would you mind checking this please? Thanks.

@wangchou (Contributor, Author) commented on Jan 15, 2024:

@yuekaizhang The @ operator (matmul) should support broadcasting, shouldn't it?


import torch
q = torch.ones(70, 4, 16, 4)    # batch_size 7 * beam_size 10 = 70
k = torch.ones(7, 4, 4, 16)     # batch_size 7, not repeated per beam
k2 = torch.ones(70, 4, 4, 16)   # repeated per beam

context2 = q @ k2
print(context2.shape)           # torch.Size([70, 4, 16, 16])

context1 = q @ k                # raises: 70 vs 7 cannot broadcast
print(context1.shape)

I ran this with torch==2.0.1 and got:

RuntimeError: The size of tensor a (70) must match the size of tensor b (7) at non-singleton dimension 0

@wangchou (Contributor, Author) commented on Jan 15, 2024:


@yuekaizhang I ran whisper with beam_size=5. It works.

python -m whisper ../samples/thatBand2ch_short.wav --language ja --model small --beam_size=5

After adding a print in qkv_attention() like

        qk = q @ k
        print("q.shape=",q.shape,", k.shape=", k.shape)

it outputs

...
q.shape= torch.Size([5, 12, 1, 64]) , k.shape= torch.Size([5, 12, 64, 6])
q.shape= torch.Size([5, 12, 1, 64]) , k.shape= torch.Size([1, 12, 64, 1500])

What arguments did you use to get k with shape (7, 4, 4, 16)? Where does that 7 come from?

P.S.: I only tested this on the Mac CPU backend. I guess that 7 comes from GPU-related code?
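For what it's worth, both observations above are consistent with PyTorch's matmul broadcasting rule: the leading batch dimensions broadcast only when they are equal or one of them is 1. A small sketch mirroring the two cases in this thread:

import torch

q = torch.ones(70, 4, 16, 4)       # queries repeated per beam (7 * 10)

k_ok = torch.ones(1, 4, 4, 16)     # batch dim 1 broadcasts against 70
print((q @ k_ok).shape)            # torch.Size([70, 4, 16, 16])

k_bad = torch.ones(7, 4, 4, 16)    # 7 is neither 70 nor 1
try:
    q @ k_bad
except RuntimeError as err:
    print(err)                     # the size-mismatch error reported above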

@yuekaizhang commented on Jan 15, 2024:


@wangchou Did you try running inference with batch_size > 1? I hit this issue when both batch_size and beam_size were > 1. The code snippet above uses batch_size 7 and beam_size 10.
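If batched decoding (batch_size > 1) were to be layered on top of this change, the cross-attention keys/values (or audio_features) would presumably need to be repeated per beam as well so that the batch dimensions line up again. A rough sketch of the shape arithmetic only, not something this PR implements; all sizes are hypothetical:

import torch

batch_size, n_group, n_head, n_ctx, d = 7, 10, 4, 16, 4

# token-side queries are repeated per beam: batch dim becomes 7 * 10 = 70
q = torch.ones(batch_size * n_group, n_head, 1, d)

# cross-attention keys come from the audio, so their batch dim stays 7 ...
k = torch.ones(batch_size, n_head, d, n_ctx)

# ... and would have to be repeated the same way before q @ k is valid
k_rep = k.repeat_interleave(n_group, dim=0)   # (70, 4, 4, 16)
print((q @ k_rep).shape)                      # torch.Size([70, 4, 1, 16])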

@wangchou (Contributor, Author) commented:


@yuekaizhang I don't even know of a batch_size option, and I cannot find it in whisper --help. Sorry.

+        # repeat text tensors by the group size, for beam search or best-of-n sampling
        tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)

        # call the main sampling loop