[whisper] compile compatibility with long-form decoding #31772

Merged
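In short: this PR makes Whisper's long-form (chunked, >30 s audio) generation compatible with a static KV cache and a torch.compile'd forward pass. A minimal usage sketch, adapted from the test_tiny_static_generation_long_form test added below — the checkpoint, dataset, and generation kwargs are the ones that test uses, while the CUDA device placement is an assumption:

import torch
from datasets import Audio, load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en").to("cuda")

# static KV cache + compiled forward pass -- the combination this PR makes usable for long-form audio
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

dataset = load_dataset("distil-whisper/meanwhile", "default")["test"]
dataset = dataset.cast_column("audio", Audio(sampling_rate=16_000))
raw_audio = [audio["array"] for audio in dataset[2:4]["audio"]]

# long-form inputs: no truncation to 30 s, attention mask returned for batched decoding
inputs = processor(
    raw_audio,
    return_tensors="pt",
    padding="longest",
    truncation=False,
    return_attention_mask=True,
    sampling_rate=16_000,
).to("cuda")

generated_ids = model.generate(
    **inputs,
    return_timestamps=True,
    condition_on_prev_tokens=True,
    temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    logprob_threshold=-1.0,
    compression_ratio_threshold=1.35,
    no_speech_threshold=0.6,
)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))
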
Changes from 9 commits
72 changes: 60 additions & 12 deletions src/transformers/models/whisper/generation_whisper.py
@@ -126,12 +126,24 @@ def _get_attr_from_logit_processors(logits_processor, logit_processor_class, att


def _pad_to_max_length(
current_segments, pad_token_id, device, padding="right", bos_token_tensor=None, cut_off_length=None
current_segments,
pad_token_id,
device,
padding_side="right",
padding="longest",
bos_token_tensor=None,
cut_off_length=None,
):
max_total_length = 0
sequences = []
if padding not in ["right", "left"]:
raise ValueError(f"`padding` must be either 'right' or 'left', not {padding}")

if padding_side not in ["right", "left"]:
raise ValueError(f"`padding_side` must be either 'right' or 'left', not {padding_side}")

if padding not in ["longest", "max_length"]:
raise ValueError(f"`padding` must be either 'longest' or 'max_length', not {padding}")
elif padding == "max_length" and cut_off_length is None:
raise ValueError("`cut_off_length` must be specified when `padding='max_length'`")

for current_segment_list in current_segments:
if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0:
@@ -150,9 +162,10 @@ def _pad_to_max_length(
else:
sequences.append(torch.tensor([], device=device))

max_total_length = cut_off_length + 1 if padding == "max_length" else max_total_length
for i in range(len(current_segments)):
pad_length = max_total_length - len(sequences[i])
pad = (0, pad_length) if padding == "right" else (pad_length, 0)
pad = (0, pad_length) if padding_side == "right" else (pad_length, 0)
sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id)

sequences = torch.stack(sequences, dim=0)
@@ -672,6 +685,7 @@ def generate(
return_token_timestamps=return_token_timestamps,
do_condition_on_prev_tokens=do_condition_on_prev_tokens,
is_shortform=is_shortform,
batch_size=batch_size,
kwargs=kwargs,
)

@@ -712,7 +726,7 @@ def generate(
)

sequences = _pad_to_max_length(
final_segments, generation_config.pad_token_id, device=self.device, padding="right"
final_segments, generation_config.pad_token_id, device=self.device, padding_side="right"
)

# 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
@@ -775,6 +789,7 @@ def generate_with_fallback(
return_token_timestamps,
do_condition_on_prev_tokens,
is_shortform,
batch_size,
kwargs,
):
kwargs = copy.copy(kwargs)
@@ -798,6 +813,22 @@ def generate_with_fallback(
for key in ["do_sample", "temperature", "num_beams"]:
if key in generate_kwargs:
del generate_kwargs[key]

cur_bsz = decoder_input_ids.shape[0]
if generation_config.cache_implementation == "static" and cur_bsz < batch_size:
segment_input = F.pad(segment_input, (0, 0, 0, 0, 0, batch_size - cur_bsz), value=0)
decoder_input_ids = F.pad(
decoder_input_ids, (0, 0, 0, batch_size - cur_bsz), value=generation_config.pad_token_id
)
if generate_kwargs.get("decoder_attention_mask") is not None:
generate_kwargs["decoder_attention_mask"] = F.pad(
generate_kwargs["decoder_attention_mask"], (0, 0, 0, batch_size - cur_bsz), value=True
)
if generate_kwargs.get("encoder_outputs") is not None:
generate_kwargs["encoder_outputs"] = F.pad(
generate_kwargs["encoder_outputs"], (0, 0, 0, 0, 0, batch_size - cur_bsz), value=0
)

seek_outputs = super().generate(
segment_input,
generation_config=generation_config,
@@ -820,6 +851,10 @@ def generate_with_fallback(
is_shortform=is_shortform,
)

if cur_bsz < batch_size:
seek_sequences = seek_sequences[:cur_bsz]
seek_outputs = seek_outputs[:cur_bsz]

# 6.7 Extract cut sequences from every sequence and check if fallback should be applied
# Loop over each decoded audio individually as each decoding can be of a different length
new_fallback_index_map = []
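
The padding and slicing added above can be read in isolation as follows (an illustrative sketch, not code from the patch): when the last chunk of a long-form batch has fewer active sequences than the batch size the static graph was traced for, the inputs are padded up to that batch size so the compiled graph can be reused, and the rows corresponding to padding are dropped from the outputs afterwards.

import torch
import torch.nn.functional as F

compiled_batch_size = 4                     # batch size the static graph was traced with
segment_input = torch.randn(3, 80, 3000)    # only 3 sequences still active in this chunk
cur_bsz = segment_input.shape[0]

if cur_bsz < compiled_batch_size:
    # F.pad pads from the last dimension backwards: (dim2_l, dim2_r, dim1_l, dim1_r, dim0_l, dim0_r),
    # so this appends `compiled_batch_size - cur_bsz` zero rows along the batch dimension
    segment_input = F.pad(segment_input, (0, 0, 0, 0, 0, compiled_batch_size - cur_bsz), value=0)

outputs = segment_input.mean(dim=(1, 2))    # stand-in for the compiled forward/generate call
outputs = outputs[:cur_bsz]                 # discard the rows that correspond to padding
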
@@ -925,17 +960,27 @@ def split_by_batch_index(values, key, batch_idx, is_shortform):
if not is_shortform:
# we don't save `past_key_values` as this is too costly for longform
return None
elif isinstance(values, EncoderDecoderCache):
all_past_key_values = []
for layer_idx in range(self.config.decoder_layers):
layer_past_key_values = []
for cache_cls in [values.self_attention_cache, values.cross_attention_cache]:
for v in [cache_cls.key_cache, cache_cls.value_cache]:
layer_past_key_values.append(v[layer_idx][batch_idx][None].cpu())
all_past_key_values.append(tuple(layer_past_key_values))
return tuple(all_past_key_values)
else:
return tuple(tuple(w[batch_idx][None].cpu() for w in values[v]) for v in range(len(values)))
all_past_key_values = []
Review comment (Contributor Author):
This is just a refactor that splits the single-line list iteration over several lines, to make it more explicit.

for v in range(len(values)):
layer_past_key_values = []
for w in values[v]:
layer_past_key_values.append(w[batch_idx][None].cpu())
all_past_key_values.append(tuple(layer_past_key_values))
return tuple(all_past_key_values)

return values[batch_idx].cpu()

sequence_tokens = seek_outputs["sequences"]

if hasattr(seek_outputs, "past_key_values") and seek_outputs.past_key_values is not None:
if isinstance(seek_outputs["past_key_values"], EncoderDecoderCache):
seek_outputs.past_key_values = seek_outputs.past_key_values.to_legacy_cache()

seek_outputs = [
{k: split_by_batch_index(v, k, i, is_shortform) for k, v in seek_outputs.items()}
for i in range(sequence_tokens.shape[0])
@@ -1613,11 +1658,14 @@ def _prepare_decoder_input_ids(
one_tensor = torch.ones((cur_bsz, 1), device=device, dtype=torch.long)
prev_ids = prev_start_of_text * one_tensor[0] if prev_start_of_text is not None else None

padding = "max_length" if generation_config.cache_implementation == "static" else "longest"

prev_tokens = _pad_to_max_length(
active_segments,
generation_config.pad_token_id,
device=device,
padding="left",
padding_side="left",
padding=padding,
bos_token_tensor=prev_ids,
cut_off_length=cut_off_length,
)
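Taken together with the padding="max_length" branch above, the prompt built from previous segments is padded to a fixed width of cut_off_length + 1 whenever the cache implementation is static, so the decoder input shape does not change from chunk to chunk. A small illustration of the two modes (an illustrative call of the private helper under its new signature; not part of the patch):

import torch
from transformers.models.whisper.generation_whisper import _pad_to_max_length

# two fake "previous segment" lists of different lengths
segments = [
    [{"tokens": torch.tensor([5, 6, 7])}],
    [{"tokens": torch.tensor([5])}],
]

# "longest": width follows the longest entry in the batch, so it varies from chunk to chunk
dynamic = _pad_to_max_length(segments, pad_token_id=0, device="cpu", padding_side="left", padding="longest")
print(dynamic.shape)  # torch.Size([2, 3])

# "max_length": width is always cut_off_length + 1, so decoder prompt shapes stay static
static = _pad_to_max_length(
    segments, pad_token_id=0, device="cpu", padding_side="left", padding="max_length", cut_off_length=7
)
print(static.shape)  # torch.Size([2, 8])
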
4 changes: 4 additions & 0 deletions src/transformers/models/whisper/modeling_whisper.py
@@ -1845,6 +1845,10 @@ def prepare_inputs_for_generation(
elif use_cache:
cache_position = cache_position[-decoder_input_ids.shape[1] :]

# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
# recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
decoder_input_ids = decoder_input_ids.contiguous()

return {
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
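For intuition about the contiguous() comment above (a standalone illustration, not part of the patch): torchdynamo includes input strides among its guards, so calling a compiled function with a tensor of the same shape but different strides can force a re-trace, whereas .contiguous() keeps the stride pattern stable.

import torch

def double(x):
    return x * 2

compiled = torch.compile(double)

x = torch.randn(8, 8)          # contiguous layout, strides (8, 1)
compiled(x)                    # first call: graph traced and guarded on these input properties
compiled(x.t())                # same shape, strides (1, 8): the stride guard can fail -> re-trace
compiled(x.t().contiguous())   # copied back to strides (8, 1): should hit the first graph again
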
60 changes: 60 additions & 0 deletions tests/models/whisper/test_modeling_whisper.py
@@ -3386,6 +3386,66 @@ def test_tiny_static_generation(self):
# assert re-ordered generations match those from eager
assert (eager_generated_ids[permutation_idx, :] == static_generated_ids).all()

@slow
def test_tiny_static_generation_long_form(self):
import torch._dynamo.config

# only permit 4 compilations: 2 prefill steps and 2 decoding steps (1 for each of conditioned/not conditioned)
torch._dynamo.config.cache_size_limit = 4

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model.to(torch_device)

dataset = load_dataset("distil-whisper/meanwhile", "default")["test"]
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
input_speech = [audio["array"] for audio in dataset[2:4]["audio"]]

inputs = processor(
input_speech,
return_tensors="pt",
padding="longest",
truncation=False,
return_attention_mask=True,
sampling_rate=16_000,
)
inputs = inputs.to(torch_device)

gen_kwargs = {
"return_timestamps": True,
"no_speech_threshold": 0.6,
"temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
"compression_ratio_threshold": 1.35,
"condition_on_prev_tokens": True, # conditioning on prev tokens introduces a recompile on the second time step
"logprob_threshold": -1.0,
"num_beams": 1,
}

set_seed(42)
eager_generated_ids = model.generate(**inputs, **gen_kwargs)

# compile the forward pass and assert equivalence
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

set_seed(42)
static_generated_ids = model.generate(**inputs, **gen_kwargs)
assert (eager_generated_ids == static_generated_ids).all()

# check the compiled graph can be re-used and that the cache is correctly reset
# reverse the ordering of the input features
input_features = inputs.input_features
permutation_idx = (
torch.arange(input_features.shape[0], 0, step=-1, dtype=torch.long, device=input_features.device) - 1
)
input_features = input_features[permutation_idx, ...]
attention_mask = inputs.attention_mask[permutation_idx, ...]

set_seed(42)
static_generated_ids = model.generate(input_features, attention_mask=attention_mask, **gen_kwargs)
# assert re-ordered generations match those from eager
assert (eager_generated_ids[permutation_idx, :] == static_generated_ids).all()


def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
if head_mask is None: