Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate: fix logits processors doctests #29718

Merged
merged 2 commits into from
Apr 2, 2024

Conversation

gante
Copy link
Member

@gante gante commented Mar 18, 2024

What does this PR do?

The doctests got stale 👀 (related PR to prevent this from happening again: #29716)

There are 2 main categories of fixes:

  1. Fixes where there is randomness involved: updates seed and potentially the output of the "bad" example. I can reproduce the existing doctest results if I go back to an older version (like v4.35), but I don't think it's worth diving through to find the root cause, as many harmless things can change the output of sampling;
  2. Whisper fixes (cc @sanchit-gandhi )

All tests are passing after these changes (pytest --doctest-modules src/transformers/generation/logits_process.py -vv)

@gante
Copy link
Member Author

gante commented Mar 18, 2024

cc @zucchini-nlp, to rebase your PRs after this gets merged :)

>>> def prefix_allowed_tokens_fn(batch_id, input_ids):
... '''
... Attempts to generate 'Bob Marley' when 'Bob' is detected.
... In this case, `batch_id` is not used, but you can set rules for each batch member.
... '''
... if input_ids[-1] == entity[0]:
... return entity[1]
... return [entity[1].item()]
Copy link
Member Author

@gante gante Mar 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prefix_allowed_tokens_fn should be a Callable[[int, torch.Tensor], List[int]], as explained in the docs

@@ -1604,13 +1610,13 @@ class LogitNormalization(LogitsProcessor, LogitsWarper):
>>> # By default, the scores are not normalized -- the sum of their exponentials is NOT a normalized probability
>>> # distribution, summing to 1
>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
>>> print(torch.sum(torch.exp(outputs.scores[-1])))
tensor(816.3250)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This value was sensible to numerical fluctuations across versions, and this exact value was not relevant for the test. The main point is that it is not approximately 1.0 :)

@@ -1641,7 +1647,7 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
>>> # Whisper has `begin_suppress_tokens` set by default (= `[220, 50256]`). 50256 is the EOS token, so this means
>>> # it can't generate and EOS token in the first iteration, but it can in the others.
>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
>>> print(outputs.scores[1][0, 50256]) # 1 (and not 0) is the first freely generated token
>>> print(outputs.scores[0][0, 50256])
Copy link
Member Author

@gante gante Mar 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whisper processor changes: @sanchit-gandhi let me know if they make sense, according to recent changes in Whisper

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me - thanks for the updated @gante!

@@ -1714,36 +1720,6 @@ class ForceTokensLogitsProcessor(LogitsProcessor):
indices that will be forced before generation. The processor will set their log probs to `inf` so that they are
sampled at their corresponding index. Originally created for
[Whisper](https://huggingface.co/docs/transformers/model_doc/whisper).

Examples:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This processor is going to be removed in v4.40, so I didn't want to spend time fixing the test :D

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cheeky :D

else:
generation_config = copy.deepcopy(generation_config)
# 1. prepare generation config
generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function from the main generate body (_prepare_generation_config) pulls generation parameterization from kwargs into generation_config.

Some Whisper-based doctests were incorrect without this functionality.

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on this fix!

I have a few questions about the changes, in particular why we need to change the seed

else:
generation_config = copy.deepcopy(generation_config)
# 1. prepare generation config
generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The lines above imply there's a self.generation_config which should be used if generation_config is None

Copy link
Member Author

@gante gante Mar 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self._prepare_generation_config() does precisely that:

generation_config = self.generation_config

It is a more complex version of this original if/else that preserves additional backward (and forward!) compatibility features of generate :)

>>> set_seed(0)
>>> set_seed(1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why change the seed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The seed is changed because the sample output is changed (more on that below), and a new seed was selected to illustrate the point of the example 🤗 I wanted a seed that produced a bad output in the unparameterized call and a good output in the parameterized call. Bear in mind that the model used in the examples is very small, and thus noisy with sampling.

We need to change the seed because the output of sampling has changed. There are many innocuous changes that can cause this: tiny numerical differences due to different versions, tiny numerical differences due to reordering of operations, corrections in the architecture code, different RNG behavior in torch (unlikely), and so on. As I've written in the PR header, I don't think it's worth our time finding the exact cause. The results in most other sampling tests are unchanged, there are many innocuous changes that can cause this, and it may be time-consuming to pin the cause.

src/transformers/generation/logits_process.py Show resolved Hide resolved
src/transformers/generation/logits_process.py Show resolved Hide resolved
src/transformers/generation/logits_process.py Show resolved Hide resolved
src/transformers/generation/logits_process.py Show resolved Hide resolved
Comment on lines -1717 to -1746

Examples:
```python
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration
>>> from datasets import load_dataset

>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")

>>> # This Whisper model forces the generation to start with `50362` at the first position by default, i.e.
>>> # `"forced_decoder_ids": [[1, 50362]]`. This means all other tokens are masked out.
>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
>>> print(
... all(outputs.scores[0][0, i] == float("-inf") for i in range(processor.tokenizer.vocab_size) if i != 50362)
... )
True
>>> print(outputs.scores[0][0, 50362])
tensor(0.)

>>> # If we disable `forced_decoder_ids`, we stop seeing that effect
>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, forced_decoder_ids=None)
>>> print(
... all(outputs.scores[0][0, i] == float("-inf") for i in range(processor.tokenizer.vocab_size) if i != 50362)
... )
False
>>> print(outputs.scores[0][0, 50362])
tensor(19.3140)
```
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why remove the example here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This processor is going to be removed in v4.40, so I didn't want to spend time fixing the test :D

:)

Comment on lines +1613 to +1614
>>> print(torch.allclose(torch.sum(torch.exp(outputs.scores[-1])), torch.Tensor((1.000,)), rtol=1e-4))
False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous output was more informative imo - there's infinitely many ways to not be close to 1

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, but it is beyond the scope of the example -- the key point here is adding the flag normalizes the probability distribution.

Testing against the exact number caused the test to fail. In fact, if we run this test on different hardware (local compute vs DGX), we get a slightly different number. We could work around it with torch.allclose, but I don't think it adds value to the test :)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK 👍

@@ -1641,7 +1647,7 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
>>> # Whisper has `begin_suppress_tokens` set by default (= `[220, 50256]`). 50256 is the EOS token, so this means
>>> # it can't generate and EOS token in the first iteration, but it can in the others.
>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
>>> print(outputs.scores[1][0, 50256]) # 1 (and not 0) is the first freely generated token
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

out of interest - what changed here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe the indexing of first freely decoded token changed recently in Whisper, but I'd like to have @sanchit-gandhi confirming the correctness of these changes :)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be a possible BC issue :/

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on this!

Happy with the changes - only concern is the difference in the whisper processor @sanchit-gandhi can you confirm this?

Comment on lines +1613 to +1614
>>> print(torch.allclose(torch.sum(torch.exp(outputs.scores[-1])), torch.Tensor((1.000,)), rtol=1e-4))
False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK 👍

@@ -1714,36 +1720,6 @@ class ForceTokensLogitsProcessor(LogitsProcessor):
indices that will be forced before generation. The processor will set their log probs to `inf` so that they are
sampled at their corresponding index. Originally created for
[Whisper](https://huggingface.co/docs/transformers/model_doc/whisper).

Examples:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cheeky :D

Copy link
Contributor

@sanchit-gandhi sanchit-gandhi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fixes @gante!

@gante gante merged commit 5080ab1 into huggingface:main Apr 2, 2024
19 checks passed
@gante gante deleted the fix_logits_doctests branch April 2, 2024 16:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants