From 07e7db32332668c9d1116cb287f6e8fe82e3118a Mon Sep 17 00:00:00 2001
From: kamilakesbi
Date: Fri, 24 May 2024 11:13:44 +0200
Subject: [PATCH] add is_shortform conditions

---
 .../models/whisper/generation_whisper.py | 42 ++++++++++++-------
 1 file changed, 26 insertions(+), 16 deletions(-)

diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py
index 3808e1d1d2aa..5a03d4cd82b6 100644
--- a/src/transformers/models/whisper/generation_whisper.py
+++ b/src/transformers/models/whisper/generation_whisper.py
@@ -122,8 +122,7 @@ def _get_attr_from_logit_processors(logits_processor, logit_processor_class, att
     return None
 
 
-def _pad_to_max_length(current_segments, pad_token_id, padding="right", bos_token_id=None, cut_off_length=None):
-
+def _pad_to_max_length(current_segments, pad_token_id, padding="right", bos_token_tensor=None, cut_off_length=None):
     max_total_length = 0
     sequences = []
     if padding not in ["right", "left"]:
@@ -136,14 +135,12 @@ def _pad_to_max_length(current_segments, pad_token_id, padding="right", bos_toke
             if cut_off_length is not None:
                 sequence = sequence[-cut_off_length:]
 
-            if bos_token_id is not None:
-                bos_token_tensor = torch.tensor([bos_token_id]).to(sequence.device)
+            if bos_token_tensor is not None:
                 sequence = torch.cat([bos_token_tensor, sequence])
 
             sequences.append(sequence)
             max_total_length = max(max_total_length, len(sequences[-1]))
-        elif bos_token_id is not None:
-            bos_token_tensor = torch.tensor([bos_token_id]).to(sequence.device)
+        elif bos_token_tensor is not None:
             sequences.append(bos_token_tensor)
         else:
             sequences.append(torch.tensor([]))
@@ -611,7 +608,11 @@ def generate(
             condition_on_prev_tokens=condition_on_prev_tokens, generation_config=generation_config
         )
 
-        timestamp_begin = generation_config.no_timestamps_token_id + 1
+        if not is_shortform:
+            timestamp_begin = generation_config.no_timestamps_token_id + 1
+        else:
+            timestamp_begin = None
+
         temperatures = [temperature] if not isinstance(temperature, (list, tuple)) else temperature
         temperature = temperatures[0]
         batch_size = input_features.shape[0]
@@ -658,9 +659,13 @@ def generate(
             )
 
             # 6.5 prepare decoder input ids
-            suppress_tokens = _get_attr_from_logit_processors(
-                logits_processor, SuppressTokensLogitsProcessor, "suppress_tokens"
-            )
+            if not is_shortform:
+                suppress_tokens = _get_attr_from_logit_processors(
+                    logits_processor, SuppressTokensLogitsProcessor, "suppress_tokens"
+                )
+            else:
+                suppress_tokens = None
+
             decoder_input_ids, kwargs = self._prepare_decoder_input_ids(
                 cur_bsz=cur_bsz,
                 init_tokens=init_tokens,
@@ -683,9 +688,10 @@ def generate(
             )
 
             # 6.7 Set current `begin_index` for all logit processors
-            for proc in logits_processor:
-                if hasattr(proc, "set_begin_index"):
-                    proc.set_begin_index(decoder_input_ids.shape[-1])
+            if not is_shortform:
+                for proc in logits_processor:
+                    if hasattr(proc, "set_begin_index"):
+                        proc.set_begin_index(decoder_input_ids.shape[-1])
 
             # 6.8 Run generate with fallback
             seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens = self.generate_with_fallback(
@@ -749,7 +755,7 @@ def generate(
                 # add decoder_input_ids tokens:
                 sequences = torch.cat([decoder_input_ids, sequences], dim=-1)
                 # add eos token:
-                sequences = torch.cat([sequences, torch.full((2,1,), generation_config.eos_token_id).to(sequences.device)], dim=-1)
+                sequences = torch.cat([sequences, torch.full((sequences.shape[0],1,), generation_config.eos_token_id).to(sequences.device)], dim=-1)
                 if return_token_timestamps:
                     outputs = {}
                     outputs['sequences'] = sequences
@@ -882,7 +888,7 @@ def generate_with_fallback(
             # if no sequence needs to be run with temperature fallback, we're finished
             if len(fallback_index_map) == 0 or fallback_idx == len(temperatures) - 1:
                 seek_sequences = seek_sequence_list
-                seek_outputs = seek_outputs_list
+                seek_outputs = seek_outputs_list
                 break
 
             # if we're still in the loop, make sure that decoder_input_ids and segment inputs are tensors
@@ -1617,7 +1623,11 @@ def _retrieve_segment(
 ):
     # find the predicted "end of segment" predictions of Whisper
     # "end of segment" predictions occur whenever Whisper predicts a timestamp token
-    timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin)
+    if timestamp_begin is not None:
+        timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin)
+    else:
+        timestamp_tokens: torch.Tensor = torch.full((seek_sequence.shape[0],), False, dtype=torch.bool).to(seek_sequence.device)
+
     single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
     timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
    timestamp_segment_indices.add_(1)
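
For review purposes, a minimal pure-torch sketch of two behaviors this diff changes
(the token ids below are placeholders chosen for illustration, not values read from
the Whisper vocabulary):

import torch

# 1) eos padding: the removed line hardcoded torch.full((2, 1), ...), which only
#    matched a batch of exactly 2; the patched line sizes the eos column from the
#    actual batch dimension.
eos_token_id = 50257  # placeholder standing in for generation_config.eos_token_id
sequences = torch.randint(0, 100, (3, 10))  # batch of 3 decoded sequences
eos = torch.full((sequences.shape[0], 1), eos_token_id, device=sequences.device)
sequences = torch.cat([sequences, eos], dim=-1)
print(sequences.shape)  # torch.Size([3, 11])

# 2) timestamp detection: timestamp_begin is now None on the short-form path, so
#    _retrieve_segment substitutes an all-False mask and simply finds no
#    "end of segment" timestamps instead of failing on seek_sequence.ge(None).
seek_sequence = torch.tensor([50258, 50359, 1234, 5678])  # placeholder token ids
timestamp_begin = None
if timestamp_begin is not None:
    timestamp_tokens = seek_sequence.ge(timestamp_begin)
else:
    timestamp_tokens = torch.full(
        (seek_sequence.shape[0],), False, dtype=torch.bool, device=seek_sequence.device
    )
print(timestamp_tokens)  # tensor([False, False, False, False])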