Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate: fix logits processors doctests #29718

Merged
merged 2 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 26 additions & 50 deletions src/transformers/generation/logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,8 @@ class TemperatureLogitsWarper(LogitsWarper):
>>> generate_kwargs = {"max_new_tokens": 10, "do_sample": True, "temperature": 1.0, "num_return_sequences": 2}
>>> outputs = model.generate(**inputs, **generate_kwargs)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
['Hugging Face Company is a joint venture between GEO Group, one of',
'Hugging Face Company is not an exact science – but what we believe does']
['Hugging Face Company is one of these companies that is going to take a',
"Hugging Face Company is a brand created by Brian A. O'Neil"]

>>> # However, with temperature close to 0, it approximates greedy decoding strategies (invariant)
>>> generate_kwargs["temperature"] = 0.0001
Expand Down Expand Up @@ -414,7 +414,7 @@ class TopPLogitsWarper(LogitsWarper):
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

>>> set_seed(0)
>>> set_seed(1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why change the seed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The seed is changed because the sample output is changed (more on that below), and a new seed was selected to illustrate the point of the example 🤗 I wanted a seed that produced a bad output in the unparameterized call and a good output in the parameterized call. Bear in mind that the model used in the examples is very small, and thus noisy with sampling.

We need to change the seed because the output of sampling has changed. There are many innocuous changes that can cause this: tiny numerical differences due to different versions, tiny numerical differences due to reordering of operations, corrections in the architecture code, different RNG behavior in torch (unlikely), and so on. As I've written in the PR header, I don't think it's worth our time finding the exact cause. The results in most other sampling tests are unchanged, there are many innocuous changes that can cause this, and it may be time-consuming to pin the cause.

>>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")

Expand All @@ -423,7 +423,9 @@ class TopPLogitsWarper(LogitsWarper):
>>> # With sampling, the output is unexpected -- sometimes too unexpected.
>>> outputs = model.generate(**inputs, do_sample=True)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
A sequence: 1, 2, 0, 2, 2. 2, 2, 2, 2
A sequence: 1, 2, 3 | < 4 (left-hand pointer) ;
<BLANKLINE>
<BLANKLINE>
amyeroberts marked this conversation as resolved.
Show resolved Hide resolved

>>> # With `top_p` sampling, the output gets restricted to high-probability tokens.
>>> # Pro tip: In practice, LLMs use `top_p` in the 0.9-0.95 range.
Expand Down Expand Up @@ -478,7 +480,7 @@ class TopKLogitsWarper(LogitsWarper):
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

>>> set_seed(0)
>>> set_seed(1)
amyeroberts marked this conversation as resolved.
Show resolved Hide resolved
>>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")

Expand All @@ -487,7 +489,7 @@ class TopKLogitsWarper(LogitsWarper):
>>> # With sampling, the output is unexpected -- sometimes too unexpected.
>>> outputs = model.generate(**inputs, do_sample=True)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
A sequence: A, B, C, D, G, H, I. A, M
A sequence: A, B, C, D, E — S — O, P — R

>>> # With `top_k` sampling, the output gets restricted the k most likely tokens.
>>> # Pro tip: In practice, LLMs use `top_k` in the 5-50 range.
Expand Down Expand Up @@ -619,7 +621,7 @@ class EpsilonLogitsWarper(LogitsWarper):
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

>>> set_seed(0)
>>> set_seed(1)
amyeroberts marked this conversation as resolved.
Show resolved Hide resolved
>>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")

Expand All @@ -628,7 +630,9 @@ class EpsilonLogitsWarper(LogitsWarper):
>>> # With sampling, the output is unexpected -- sometimes too unexpected.
>>> outputs = model.generate(**inputs, do_sample=True)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
A sequence: 1, 2, 0, 2, 2. 2, 2, 2, 2
A sequence: 1, 2, 3 | < 4 (left-hand pointer) ;
<BLANKLINE>
<BLANKLINE>

>>> # With epsilon sampling, the output gets restricted to high-probability tokens. Note that this is similar to
>>> # Top P sampling, which restricts tokens based on their cumulative probability.
Expand Down Expand Up @@ -696,7 +700,7 @@ class EtaLogitsWarper(LogitsWarper):
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

>>> set_seed(0)
>>> set_seed(1)
amyeroberts marked this conversation as resolved.
Show resolved Hide resolved
>>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")

Expand All @@ -705,7 +709,9 @@ class EtaLogitsWarper(LogitsWarper):
>>> # With sampling, the output is unexpected -- sometimes too unexpected.
>>> outputs = model.generate(**inputs, do_sample=True)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
A sequence: 1, 2, 0, 2, 2. 2, 2, 2, 2
A sequence: 1, 2, 3 | < 4 (left-hand pointer) ;
<BLANKLINE>
<BLANKLINE>

>>> # With eta sampling, the output gets restricted to high-probability tokens. You can see it as a dynamic form of
>>> # epsilon sampling that adapts its cutoff probability based on the entropy (high entropy = lower cutoff).
Expand Down Expand Up @@ -1204,16 +1210,16 @@ class PrefixConstrainedLogitsProcessor(LogitsProcessor):

>>> # We can contrain it with `prefix_allowed_tokens_fn` to force a certain behavior based on a prefix.
>>> # For instance, we can force an entire entity to be generated when its beginning is detected.
>>> entity = tokenizer(" Bob Marley", return_tensors="pt").input_ids[0] # 3 tokens
>>> entity = tokenizer(" Bob Marley", return_tensors="pt").input_ids[0] # 3 tokens
>>> def prefix_allowed_tokens_fn(batch_id, input_ids):
... '''
... Attempts to generate 'Bob Marley' when 'Bob' is detected.
... In this case, `batch_id` is not used, but you can set rules for each batch member.
... '''
... if input_ids[-1] == entity[0]:
... return entity[1]
... return [entity[1].item()]
Copy link
Member Author

@gante gante Mar 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prefix_allowed_tokens_fn should be a Callable[[int, torch.Tensor], List[int]], as explained in the docs

... elif input_ids[-2] == entity[0] and input_ids[-1] == entity[1]:
... return entity[2]
... return [entity[2].item()]
... return list(range(tokenizer.vocab_size)) # If no match, allow all tokens

>>> outputs = model.generate(**inputs, max_new_tokens=5, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn)
Expand Down Expand Up @@ -1604,13 +1610,13 @@ class LogitNormalization(LogitsProcessor, LogitsWarper):
>>> # By default, the scores are not normalized -- the sum of their exponentials is NOT a normalized probability
>>> # distribution, summing to 1
>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
>>> print(torch.sum(torch.exp(outputs.scores[-1])))
tensor(816.3250)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This value was sensible to numerical fluctuations across versions, and this exact value was not relevant for the test. The main point is that it is not approximately 1.0 :)

>>> print(torch.allclose(torch.sum(torch.exp(outputs.scores[-1])), torch.Tensor((1.000,)), rtol=1e-4))
False
Comment on lines +1613 to +1614
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous output was more informative imo - there's infinitely many ways to not be close to 1

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, but it is beyond the scope of the example -- the key point here is adding the flag normalizes the probability distribution.

Testing against the exact number caused the test to fail. In fact, if we run this test on different hardware (local compute vs DGX), we get a slightly different number. We could work around it with torch.allclose, but I don't think it adds value to the test :)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK 👍


>>> # Normalizing them may have a positive impact on beam methods, or when using the scores on your application
>>> outputs = model.generate(**inputs, renormalize_logits=True, return_dict_in_generate=True, output_scores=True)
>>> print(torch.sum(torch.exp(outputs.scores[-1])))
tensor(1.0000)
>>> print(torch.allclose(torch.sum(torch.exp(outputs.scores[-1])), torch.Tensor((1.000,)), rtol=1e-4))
True
```
"""

Expand Down Expand Up @@ -1641,7 +1647,7 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
>>> # Whisper has `begin_suppress_tokens` set by default (= `[220, 50256]`). 50256 is the EOS token, so this means
>>> # it can't generate and EOS token in the first iteration, but it can in the others.
>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
>>> print(outputs.scores[1][0, 50256]) # 1 (and not 0) is the first freely generated token
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

out of interest - what changed here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe the indexing of first freely decoded token changed recently in Whisper, but I'd like to have @sanchit-gandhi confirming the correctness of these changes :)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be a possible BC issue :/

>>> print(outputs.scores[0][0, 50256])
Copy link
Member Author

@gante gante Mar 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whisper processor changes: @sanchit-gandhi let me know if they make sense, according to recent changes in Whisper

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me - thanks for the updated @gante!

tensor(-inf)
>>> print(outputs.scores[-1][0, 50256]) # in other places we can see some probability mass for EOS
tensor(29.9010)
Expand All @@ -1650,7 +1656,7 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
>>> outputs = model.generate(
... **inputs, return_dict_in_generate=True, output_scores=True, begin_suppress_tokens=None
... )
>>> print(outputs.scores[1][0, 50256])
>>> print(outputs.scores[0][0, 50256])
tensor(11.2027)
```
"""
Expand Down Expand Up @@ -1695,7 +1701,7 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
>>> # If we disable `suppress_tokens`, we can generate it.
>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, suppress_tokens=None)
>>> print(outputs.scores[1][0, 1])
tensor(5.7738)
tensor(6.0678)
```
"""

Expand All @@ -1714,36 +1720,6 @@ class ForceTokensLogitsProcessor(LogitsProcessor):
indices that will be forced before generation. The processor will set their log probs to `inf` so that they are
sampled at their corresponding index. Originally created for
[Whisper](https://huggingface.co/docs/transformers/model_doc/whisper).

Examples:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This processor is going to be removed in v4.40, so I didn't want to spend time fixing the test :D

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cheeky :D

```python
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration
>>> from datasets import load_dataset

>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")

>>> # This Whisper model forces the generation to start with `50362` at the first position by default, i.e.
>>> # `"forced_decoder_ids": [[1, 50362]]`. This means all other tokens are masked out.
>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
>>> print(
... all(outputs.scores[0][0, i] == float("-inf") for i in range(processor.tokenizer.vocab_size) if i != 50362)
... )
True
>>> print(outputs.scores[0][0, 50362])
tensor(0.)

>>> # If we disable `forced_decoder_ids`, we stop seeing that effect
>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, forced_decoder_ids=None)
>>> print(
... all(outputs.scores[0][0, i] == float("-inf") for i in range(processor.tokenizer.vocab_size) if i != 50362)
... )
False
>>> print(outputs.scores[0][0, 50362])
tensor(19.3140)
```
Comment on lines -1717 to -1746
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why remove the example here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This processor is going to be removed in v4.40, so I didn't want to spend time fixing the test :D

:)

"""

def __init__(self, force_token_map: List[List[int]], _has_warned: Optional[bool] = False):
Expand Down
8 changes: 2 additions & 6 deletions src/transformers/models/whisper/generation_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import math
import warnings
import zlib
Expand Down Expand Up @@ -474,11 +473,8 @@ def generate(
"The input name `inputs` is deprecated. Please make sure to use `input_features` instead.",
FutureWarning,
)
# 1. copy generation config
if generation_config is None:
generation_config = copy.deepcopy(self.generation_config)
else:
generation_config = copy.deepcopy(generation_config)
# 1. prepare generation config
generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The lines above imply there's a self.generation_config which should be used if generation_config is None

Copy link
Member Author

@gante gante Mar 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self._prepare_generation_config() does precisely that:

generation_config = self.generation_config

It is a more complex version of this original if/else that preserves additional backward (and forward!) compatibility features of generate :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function from the main generate body (_prepare_generation_config) pulls generation parameterization from kwargs into generation_config.

Some Whisper-based doctests were incorrect without this functionality.


# 2. set global generate variables
input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
Expand Down
Loading