From 6375d4bc9e88ceb4a0ad6ed44f9b0028c0e27a25 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Wed, 3 Jul 2024 13:38:38 +0100 Subject: [PATCH 01/11] [whisper] compile compatibility with long-form decoding --- .../models/whisper/generation_whisper.py | 48 ++++++++++++--- .../models/whisper/modeling_whisper.py | 4 ++ tests/models/whisper/test_modeling_whisper.py | 60 +++++++++++++++++++ 3 files changed, 104 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 2a205f9f9bc5..5909aeec7c4a 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -126,12 +126,24 @@ def _get_attr_from_logit_processors(logits_processor, logit_processor_class, att def _pad_to_max_length( - current_segments, pad_token_id, device, padding="right", bos_token_tensor=None, cut_off_length=None + current_segments, + pad_token_id, + device, + padding_side="right", + padding="longest", + bos_token_tensor=None, + cut_off_length=None, ): max_total_length = 0 sequences = [] - if padding not in ["right", "left"]: - raise ValueError(f"`padding` must be either 'right' or 'left', not {padding}") + + if padding_side not in ["right", "left"]: + raise ValueError(f"`padding_side` must be either 'right' or 'left', not {padding_side}") + + if padding not in ["longest", "max_length"]: + raise ValueError(f"`padding` must be either 'longest' or 'max_length', not {padding}") + elif padding == "max_length" and cut_off_length is None: + raise ValueError("`cut_off_length` must be specified when `padding='max_length'`") for current_segment_list in current_segments: if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0: @@ -150,9 +162,10 @@ def _pad_to_max_length( else: sequences.append(torch.tensor([], device=device)) + max_total_length = cut_off_length + 1 if padding == "max_length" else max_total_length for i in range(len(current_segments)): pad_length = max_total_length - len(sequences[i]) - pad = (0, pad_length) if padding == "right" else (pad_length, 0) + pad = (0, pad_length) if padding_side == "right" else (pad_length, 0) sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id) sequences = torch.stack(sequences, dim=0) @@ -671,7 +684,7 @@ def generate( synced_gpus=synced_gpus, return_token_timestamps=return_token_timestamps, do_condition_on_prev_tokens=do_condition_on_prev_tokens, - is_shortform=is_shortform, + batch_size=batch_size, kwargs=kwargs, ) @@ -712,7 +725,7 @@ def generate( ) sequences = _pad_to_max_length( - final_segments, generation_config.pad_token_id, device=self.device, padding="right" + final_segments, generation_config.pad_token_id, device=self.device, padding_side="right" ) # 8. If we return all segments, the predicted output sequences are put under `"sequences"`. 
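The new `padding="max_length"` branch above pads every prompt to `cut_off_length + 1` tokens so that, when the static cache is selected, the prompt width stays constant across compiled decoding calls instead of tracking the longest segment seen so far. Below is a minimal standalone sketch of that padding behaviour — a toy re-implementation for illustration only, not the library helper itself; the token ids and names are made up:

```python
import torch
import torch.nn.functional as F

def pad_prompts(sequences, pad_token_id, padding_side="left", padding="longest", cut_off_length=None):
    # "longest" pads to the longest prompt in the batch; "max_length" pads to a fixed
    # width of cut_off_length + 1, so a compiled graph always sees a static prompt shape.
    if padding == "max_length" and cut_off_length is None:
        raise ValueError("`cut_off_length` must be specified when `padding='max_length'`")
    max_len = cut_off_length + 1 if padding == "max_length" else max(len(seq) for seq in sequences)
    padded = []
    for seq in sequences:
        pad_length = max_len - len(seq)
        pad = (0, pad_length) if padding_side == "right" else (pad_length, 0)
        padded.append(F.pad(seq, pad=pad, value=pad_token_id))
    return torch.stack(padded, dim=0)

prompts = [torch.tensor([50361, 7, 8, 9]), torch.tensor([50361, 11])]
print(pad_prompts(prompts, pad_token_id=50257, padding="max_length", cut_off_length=5).shape)
# torch.Size([2, 6]) -- fixed width, independent of the longest prompt in the batch
```

With `padding="longest"` the same call would instead return a `[2, 4]` tensor, and that varying prompt width is exactly what forces torchdynamo to recompile between chunks of a long-form transcription.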
@@ -774,7 +787,7 @@ def generate_with_fallback( synced_gpus, return_token_timestamps, do_condition_on_prev_tokens, - is_shortform, + batch_size, kwargs, ): kwargs = copy.copy(kwargs) @@ -798,6 +811,18 @@ def generate_with_fallback( for key in ["do_sample", "temperature", "num_beams"]: if key in generate_kwargs: del generate_kwargs[key] + + cur_bsz = segment_input.shape[0] + if generation_config.cache_implementation == "static" and cur_bsz < batch_size: + segment_input = F.pad(segment_input, (0, 0, 0, 0, 0, batch_size - cur_bsz), value=0) + decoder_input_ids = F.pad( + decoder_input_ids, (0, 0, 0, batch_size - cur_bsz), value=generation_config.pad_token_id + ) + if generate_kwargs.get("decoder_attention_mask") is not None: + generate_kwargs["decoder_attention_mask"] = F.pad( + generate_kwargs["decoder_attention_mask"], (0, 0, 0, batch_size - cur_bsz), value=True + ) + seek_outputs = super().generate( segment_input, generation_config=generation_config, @@ -820,6 +845,10 @@ def generate_with_fallback( is_shortform=is_shortform, ) + if cur_bsz < batch_size: + seek_sequences = seek_sequences[:cur_bsz] + seek_outputs = seek_outputs[:cur_bsz] + # 6.7 Extract cut sequences from every sequence and check if fallback should be applied # Loop over each decoded audio individually as each decoding can be of a different length new_fallback_index_map = [] @@ -1613,11 +1642,14 @@ def _prepare_decoder_input_ids( one_tensor = torch.ones((cur_bsz, 1), device=device, dtype=torch.long) prev_ids = prev_start_of_text * one_tensor[0] if prev_start_of_text is not None else None + padding = "max_length" if generation_config.cache_implementation == "static" else "longest" + prev_tokens = _pad_to_max_length( active_segments, generation_config.pad_token_id, device=device, - padding="left", + padding_side="left", + padding=padding, bos_token_tensor=prev_ids, cut_off_length=cut_off_length, ) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 8785d5681f73..f2aca32024b8 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1845,6 +1845,10 @@ def prepare_inputs_for_generation( elif use_cache: cache_position = cache_position[-decoder_input_ids.shape[1] :] + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. 
Ref: https://github.com/huggingface/transformers/pull/29114 + decoder_input_ids = decoder_input_ids.contiguous() + return { "encoder_outputs": encoder_outputs, "past_key_values": past_key_values, diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 00d5189de47f..1d7a5edcd185 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -3386,6 +3386,66 @@ def test_tiny_static_generation(self): # assert re-ordered generations match those from eager assert (eager_generated_ids[permutation_idx, :] == static_generated_ids).all() + @slow + def test_tiny_static_generation_long_form(self): + import torch._dynamo.config + + # only permit 4 compilations: 2 prefill steps and 2 decoding steps (1 for each of conditioned/not conditioned) + torch._dynamo.config.cache_size_limit = 4 + + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") + model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") + model.to(torch_device) + + dataset = load_dataset("distil-whisper/meanwhile", "default")["test"] + dataset = dataset.cast_column("audio", Audio(sampling_rate=16000)) + input_speech = [audio["array"] for audio in dataset[2:4]["audio"]] + + inputs = processor( + input_speech, + return_tensors="pt", + padding="longest", + truncation=False, + return_attention_mask=True, + sampling_rate=16_000, + ) + inputs = inputs.to(torch_device) + + gen_kwargs = { + "return_timestamps": True, + "no_speech_threshold": 0.6, + "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), + "compression_ratio_threshold": 1.35, + "condition_on_prev_tokens": True, # conditioning on prev tokens introduces a recompile on the second decoding step + "logprob_threshold": -1.0, + "num_beams": 1, + } + + set_seed(42) + eager_generated_ids = model.generate(**inputs, **gen_kwargs) + + # compile the forward pass and assert equivalence + model.generation_config.cache_implementation = "static" + model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) + + set_seed(42) + static_generated_ids = model.generate(**inputs, **gen_kwargs) + assert (eager_generated_ids == static_generated_ids).all() + + # check the compiled graph can be re-used and that the cache is correctly reset + # reverse the ordering of the input features + input_features = inputs.input_features + permutation_idx = ( + torch.arange(input_features.shape[0], 0, step=-1, dtype=torch.long, device=input_features.device) - 1 + ) + input_features = input_features[permutation_idx, ...] + attention_mask = inputs.attention_mask[permutation_idx, ...] 
+ + set_seed(42) + static_generated_ids = model.generate(input_features, attention_mask=attention_mask, **gen_kwargs) + # assert re-ordered generations match those from eager + assert (eager_generated_ids[permutation_idx, :] == static_generated_ids).all() + def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None): if head_mask is None: From 1fa363762dd7caed39f6b30678338b5e35d1e395 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Wed, 3 Jul 2024 13:46:13 +0100 Subject: [PATCH 02/11] clarify comment --- tests/models/whisper/test_modeling_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 1d7a5edcd185..f43f29f56510 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -3416,7 +3416,7 @@ def test_tiny_static_generation_long_form(self): "no_speech_threshold": 0.6, "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), "compression_ratio_threshold": 1.35, - "condition_on_prev_tokens": True, # conditioning on prev tokens introduces a recompile on the second decoding step + "condition_on_prev_tokens": True, # conditioning on prev tokens introduces a recompile on the second time step "logprob_threshold": -1.0, "num_beams": 1, } From 8507ee01f93dc43d4634bc4654222c0939fa4e71 Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Wed, 24 Jul 2024 14:10:58 +0800 Subject: [PATCH 03/11] fix after rebase --- src/transformers/models/whisper/generation_whisper.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 5909aeec7c4a..1c662fb9169a 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -684,6 +684,7 @@ def generate( synced_gpus=synced_gpus, return_token_timestamps=return_token_timestamps, do_condition_on_prev_tokens=do_condition_on_prev_tokens, + is_shortform=is_shortform, batch_size=batch_size, kwargs=kwargs, ) @@ -787,6 +788,7 @@ def generate_with_fallback( synced_gpus, return_token_timestamps, do_condition_on_prev_tokens, + is_shortform, batch_size, kwargs, ): From aa7577c356f1417aeedc5a56a81b717cc440f2dd Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Wed, 24 Jul 2024 15:36:22 +0800 Subject: [PATCH 04/11] finalise --- .../models/whisper/generation_whisper.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 1c662fb9169a..612140673e02 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -956,17 +956,21 @@ def split_by_batch_index(values, key, batch_idx, is_shortform): if not is_shortform: # we don't save `past_key_values` as this is too costly for longform return None + elif isinstance(values, EncoderDecoderCache): + all_past_key_values = [] + for layer_idx in range(self.config.decoder_layers): + layer_past_key_values = [] + for cache_cls in [values.self_attention_cache, values.cross_attention_cache]: + for v in [cache_cls.key_cache, cache_cls.value_cache]: + layer_past_key_values.append(v[layer_idx][None].cpu()) + all_past_key_values.append(tuple(layer_past_key_values)) + return tuple(all_past_key_values) else: return 
tuple(tuple(w[batch_idx][None].cpu() for w in values[v]) for v in range(len(values))) return values[batch_idx].cpu() sequence_tokens = seek_outputs["sequences"] - - if hasattr(seek_outputs, "past_key_values") and seek_outputs.past_key_values is not None: - if isinstance(seek_outputs["past_key_values"], EncoderDecoderCache): - seek_outputs.past_key_values = seek_outputs.past_key_values.to_legacy_cache() - seek_outputs = [ {k: split_by_batch_index(v, k, i, is_shortform) for k, v in seek_outputs.items()} for i in range(sequence_tokens.shape[0]) From 2828f51d96e4d1d2ee163e3e3c3d4321b935c7b7 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Thu, 1 Aug 2024 10:07:53 +0800 Subject: [PATCH 05/11] fix bsz --- src/transformers/models/whisper/generation_whisper.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 612140673e02..61e0b1984ae0 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -814,7 +814,7 @@ def generate_with_fallback( if key in generate_kwargs: del generate_kwargs[key] - cur_bsz = segment_input.shape[0] + cur_bsz = decoder_input_ids.shape[0] if generation_config.cache_implementation == "static" and cur_bsz < batch_size: segment_input = F.pad(segment_input, (0, 0, 0, 0, 0, batch_size - cur_bsz), value=0) decoder_input_ids = F.pad( @@ -824,6 +824,8 @@ def generate_with_fallback( generate_kwargs["decoder_attention_mask"] = F.pad( generate_kwargs["decoder_attention_mask"], (0, 0, 0, batch_size - cur_bsz), value=True ) + if generate_kwargs.get("encoder_outputs") is not None: + generate_kwargs["encoder_outputs"] = F.pad(generate_kwargs["encoder_outputs"], (0, 0, 0, 0, 0, batch_size - cur_bsz), value=0) seek_outputs = super().generate( segment_input, From 4150700e3f3d4b2e25a217a56148493abfce398f Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Thu, 1 Aug 2024 10:36:41 +0800 Subject: [PATCH 06/11] fix cache split --- src/transformers/models/whisper/generation_whisper.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 61e0b1984ae0..bdc412277f6e 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -964,11 +964,17 @@ def split_by_batch_index(values, key, batch_idx, is_shortform): layer_past_key_values = [] for cache_cls in [values.self_attention_cache, values.cross_attention_cache]: for v in [cache_cls.key_cache, cache_cls.value_cache]: - layer_past_key_values.append(v[layer_idx][None].cpu()) + layer_past_key_values.append(v[layer_idx][batch_idx][None].cpu()) all_past_key_values.append(tuple(layer_past_key_values)) return tuple(all_past_key_values) else: - return tuple(tuple(w[batch_idx][None].cpu() for w in values[v]) for v in range(len(values))) + all_past_key_values = [] + for v in range(len(values)): + layer_past_key_values = [] + for w in values[v]: + layer_past_key_values.append(w[batch_idx][None].cpu()) + all_past_key_values.append(tuple(layer_past_key_values)) + return tuple(all_past_key_values) return values[batch_idx].cpu() From 4f696938cf25d3f11fe02550d3393bb8b5e6ccf2 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Thu, 1 Aug 2024 10:43:42 +0800 Subject: [PATCH 07/11] remove contiguous --- src/transformers/models/whisper/modeling_whisper.py | 4 
---- 1 file changed, 4 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index f2aca32024b8..8785d5681f73 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1845,10 +1845,6 @@ def prepare_inputs_for_generation( elif use_cache: cache_position = cache_position[-decoder_input_ids.shape[1] :] - # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise - # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 - decoder_input_ids = decoder_input_ids.contiguous() - return { "encoder_outputs": encoder_outputs, "past_key_values": past_key_values, From 0b6acfb0076129b61d38d49466139f1afaf1d19b Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Thu, 1 Aug 2024 10:44:18 +0800 Subject: [PATCH 08/11] style --- src/transformers/models/whisper/generation_whisper.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index bdc412277f6e..3c4b5795e461 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -825,7 +825,9 @@ def generate_with_fallback( generate_kwargs["decoder_attention_mask"], (0, 0, 0, batch_size - cur_bsz), value=True ) if generate_kwargs.get("encoder_outputs") is not None: - generate_kwargs["encoder_outputs"] = F.pad(generate_kwargs["encoder_outputs"], (0, 0, 0, 0, 0, batch_size - cur_bsz), value=0) + generate_kwargs["encoder_outputs"] = F.pad( + generate_kwargs["encoder_outputs"], (0, 0, 0, 0, 0, batch_size - cur_bsz), value=0 + ) seek_outputs = super().generate( segment_input, From 719a80ae5ab64c9daa3025974bcc25afc107efe6 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Thu, 1 Aug 2024 10:56:33 +0800 Subject: [PATCH 09/11] finish --- src/transformers/models/whisper/modeling_whisper.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 8785d5681f73..f2aca32024b8 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1845,6 +1845,10 @@ def prepare_inputs_for_generation( elif use_cache: cache_position = cache_position[-decoder_input_ids.shape[1] :] + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 + decoder_input_ids = decoder_input_ids.contiguous() + return { "encoder_outputs": encoder_outputs, "past_key_values": past_key_values, From 21002ace80194676a44276f9ad1da7793db5354f Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Thu, 1 Aug 2024 16:47:09 +0800 Subject: [PATCH 10/11] update doc --- docs/source/en/model_doc/whisper.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/en/model_doc/whisper.md b/docs/source/en/model_doc/whisper.md index 0565bd5aae11..6c83407f7666 100644 --- a/docs/source/en/model_doc/whisper.md +++ b/docs/source/en/model_doc/whisper.md @@ -72,7 +72,7 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained ' Mr. 
Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' ``` -Whisper is compatible with the following optimisations: +Whisper is compatible with the following optimisations for both short and long-form generation: - [PyTorch Scaled Dot Product Attention (SDPA)](../perf_infer_gpu_one#pytorch-scaled-dot-product-attention): flash attention and memory-efficient attention kernels. Enabled by default for `torch>=2.1.1`. - [Flash Attention 2](../perf_infer_gpu_one#flashattention-2): improved implementation of flash attention through better parallelism and work partitioning. - [torch.compile](../llm_optims#static-kv-cache-and-torchcompile): JIT-compile the forward pass to dispatch to efficient fused kernels. @@ -101,7 +101,8 @@ As an example, the following codesnippet enables SDPA and `torch.compile` for up ... ).input_features >>> # Compile the forward pass ->>> _ = model.generate(input_features) +>>> for _ in range(2): +>>> model.generate(input_features) >>> # Generate token ids using compiled graph (fast!) >>> predicted_ids = model.generate(input_features) From ef47d06e6a409e44da6e4702cf0bc78f854f7287 Mon Sep 17 00:00:00 2001 From: sanchit-gandhi Date: Thu, 1 Aug 2024 17:27:28 +0800 Subject: [PATCH 11/11] prevent cuda graph trace --- .../models/whisper/modeling_whisper.py | 30 ++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index f2aca32024b8..49f305fecf51 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1835,8 +1835,10 @@ def prepare_inputs_for_generation( decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - if decoder_position_ids is not None and decoder_position_ids.shape[1] > decoder_input_ids.shape[1]: + if decoder_position_ids is not None: decoder_position_ids = decoder_position_ids[:, remove_prefix_length:] + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + decoder_position_ids = decoder_position_ids.clone(memory_format=torch.contiguous_format) if cache_position is None: cache_position = torch.arange( @@ -1849,6 +1851,32 @@ def prepare_inputs_for_generation( # recompiles graphs as the stride of the inputs is a guard. 
Ref: https://github.com/huggingface/transformers/pull/29114 decoder_input_ids = decoder_input_ids.contiguous() + if ( + isinstance(past_key_values, EncoderDecoderCache) + and ( + isinstance(past_key_values.self_attention_cache, StaticCache) + or isinstance(past_key_values.cross_attention_cache, StaticCache) + ) + and decoder_attention_mask is not None + and decoder_attention_mask.ndim == 2 + ): + batch_size, sequence_length = decoder_input_ids.shape + device = decoder_input_ids.device + + dtype = self.proj_out.weight.dtype + min_dtype = torch.finfo(dtype).min + + decoder_attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + decoder_attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.self_attention_cache.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, + ) + return { "encoder_outputs": encoder_outputs, "past_key_values": past_key_values,
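Taken together, the series makes `torch.compile` with a static KV cache usable for long-form (>30s) Whisper transcription, including temperature fallback and conditioning on previous tokens. The following is a hedged end-to-end sketch of the resulting usage, mirroring the new slow test and the updated doc snippet (same checkpoint and dataset as the test; it assumes a CUDA device, and exact compile/recompile behaviour may vary with the torch version):

```python
import torch
from datasets import Audio, load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en").to("cuda")

# static decoder cache + compiled forward pass, as in the updated docs and the new slow test
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

dataset = load_dataset("distil-whisper/meanwhile", "default")["test"]
dataset = dataset.cast_column("audio", Audio(sampling_rate=16_000))
audio = [sample["array"] for sample in dataset[2:4]["audio"]]

# long-form inputs: no truncation, with an attention mask for the >30s case
inputs = processor(
    audio,
    return_tensors="pt",
    padding="longest",
    truncation=False,
    return_attention_mask=True,
    sampling_rate=16_000,
).to("cuda")

# the first calls trigger compilation of the prefill and decode graphs; subsequent calls reuse them
generated_ids = model.generate(**inputs, return_timestamps=True, condition_on_prev_tokens=True)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))
```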