Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Whisper] Deprecate forced ids for v4.39 #29485

Merged
merged 1 commit into from
Mar 13, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 11 additions & 37 deletions src/transformers/models/whisper/generation_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

from ...generation.configuration_utils import GenerationConfig
from ...generation.logits_process import (
ForceTokensLogitsProcessor,
LogitsProcessorList,
SuppressTokensAtBeginLogitsProcessor,
SuppressTokensLogitsProcessor,
Expand Down Expand Up @@ -537,11 +536,9 @@ def generate(
num_segment_frames=num_segment_frames,
kwargs=kwargs,
)
# TODO(Sanchit) - passing `decoder_input_ids` is deprecated. One should use `prompt_ids` instead
# This function should be be removed in v4.39
self._check_decoder_input_ids(
prompt_ids=prompt_ids, init_tokens=init_tokens, is_shortform=is_shortform, kwargs=kwargs
)
# passing `decoder_input_ids` is deprecated - the only exception is for assisted generation
# where the input ids are handled explicitly by the generate method
self._check_decoder_input_ids(kwargs=kwargs)

# 3. Retrieve logits processors
begin_index = len(init_tokens)
Expand Down Expand Up @@ -1127,15 +1124,13 @@ def replace_or_add(lst: List[int], num: int, itr: Iterator[int]):
forced_decoder_ids = forced_decoder_ids[1:]
i += 1

# TODO(Sanchit): Let's make sure we don't allow incorrectly / weirdly formatted `forced_decoder_ids` after transformers v4.39
if len(forced_decoder_ids) > 0:
warnings.warn(
f"You are using token ids in `forced_decoder_ids` that do not seem to correctly follow the prompt pattern of Whisper. Make sure that {forced_decoder_ids} has an entry for all indices >= 1 and < {forced_decoder_ids[0][0]}. `forced_decoder_ids` will be passed as a logit processor, but note that this functionality has been deprecated and will throw an error in v4.39.",
FutureWarning,
raise ValueError(
f"You are using token ids in `forced_decoder_ids` that do not seem to correctly follow the prompt pattern of Whisper. Make sure that {forced_decoder_ids} has an entry for all indices >= 1 and < {forced_decoder_ids[0][0]}.",
)

# TODO(Sanchit): set generation_config.forced_decoder_ids to None for v4.39
generation_config.forced_decoder_ids = forced_decoder_ids if len(forced_decoder_ids) > 0 else None
# from v4.39 the forced decoder ids are always None in favour of decoder input ids
generation_config.forced_decoder_ids = None

is_lang_id_undefined = len(init_tokens) <= 1 or (len(init_tokens) > 1 and init_tokens[1] is None)
if language is not None:
Expand Down Expand Up @@ -1280,20 +1275,12 @@ def detect_language(
return lang_ids

@staticmethod
def _check_decoder_input_ids(prompt_ids, init_tokens, is_shortform, kwargs):
def _check_decoder_input_ids(kwargs):
decoder_input_ids = kwargs.get("decoder_input_ids", None)
if prompt_ids is not None and decoder_input_ids is not None:
assistant_model = kwargs.get("assistant_model", None)
if decoder_input_ids is not None and assistant_model is not None:
raise ValueError(
f"Cannot pass both `prompt_ids`: {prompt_ids} and `decoder_input_ids`: {decoder_input_ids}. Passing `decoder_input_ids` is deprecated, consider not passing it."
)
elif decoder_input_ids is not None and not is_shortform:
raise ValueError(
f"Cannot pass both `decoder_input_ids`: {decoder_input_ids} for long-form generation. Consider passing `prompt_ids` instead."
)
elif decoder_input_ids is not None and is_shortform:
warnings.warn(
f"You have provided `decoder_input_ids` which will overwrite the `init_tokens` {init_tokens}. This might lead to unexpected behavior. Passing `decoder_input_ids` is deprecated and will be removed in v4.39. Consider passing `prompt_ids` instead.",
FutureWarning,
"Passing `decoder_input_ids` is deprecated. Consider passing `prompt_ids` instead.",
)

@staticmethod
Expand Down Expand Up @@ -1434,19 +1421,6 @@ def _retrieve_logit_processors(self, generation_config, logits_processor, begin_
)
no_speech_detector.set_model(self)

if is_shortform and generation_config.forced_decoder_ids is not None:
forced_tokens_proc = ForceTokensLogitsProcessor(generation_config.forced_decoder_ids)
# It's important that the `forced_tokens_proc` processor is appended after
# the suppress_tokens processor or else it might happen that all token logits are suppressed to -inf
# which would lead to unexpected behavior
# The better approach here is to NOT make use of the `forced_tokens_proc` for Whisper and instead
# initialize all of them as `decoder_input_ids`.
# TODO(Sanchit): Make sure to deprecate this in v4.39 as there will be no `forced_decoder_ids` anymore.
logits_processor = (
[forced_tokens_proc] if logits_processor is None else logits_processor + [forced_tokens_proc]
)
generation_config.forced_decoder_ids = None

return logits_processor

@staticmethod
Expand Down
Loading