Generate: store special token tensors under a unique variable name #31980
Conversation
```python
    # If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
    if not using_model_generation_config:
        if generation_config.bos_token_id is None:
            generation_config.bos_token_id = self.generation_config.bos_token_id
        if generation_config.eos_token_id is None:
            generation_config.eos_token_id = self.generation_config.eos_token_id
        if generation_config.pad_token_id is None:
            generation_config.pad_token_id = self.generation_config.pad_token_id
        if generation_config.decoder_start_token_id is None:
            generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id
```
This is equivalent to the changes in this PR, which are better suited to this function -- handling backward compatibility with respect to config files.
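For illustration, a minimal sketch of the fallback behavior under discussion (the checkpoint name and setup are assumptions for illustration, not from the PR):

```python
from transformers import AutoModelForCausalLM, GenerationConfig

# Assumed setup: any generative checkpoint whose generation config defines
# the special tokens.
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")

# A user-provided config with no special tokens set: inside generate(), each
# unset token falls back to the model's own generation_config defaults.
user_config = GenerationConfig(max_new_tokens=5)
assert user_config.eos_token_id is None  # unset before generate()
# After the fallback logic above runs, generation proceeds with
# model.generation_config.eos_token_id (50256 for GPT-2).
```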
```diff
@@ -3196,6 +3196,39 @@ def test_assisted_decoding_in_gpu_cpu(self):
        )
        self.assertTrue(input_length <= out.shape[-1] <= input_length + 20)

    def test_special_tokens_fall_back_to_model_default(self):
```
This test was missing in #31254 😛
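A hedged sketch of what a test like this might check (the helper and fixture names are assumptions, not the actual test body from the PR):

```python
def test_special_tokens_fall_back_to_model_default(self):
    # Hypothetical sketch: a config with no special tokens set should still
    # generate, because the unset tokens fall back to the model's defaults.
    model, input_ids = self.get_test_model_and_inputs()  # assumed helper
    config = GenerationConfig(max_new_tokens=5)
    self.assertIsNone(config.eos_token_id)  # nothing set by the user
    out = model.generate(input_ids, generation_config=config)
    self.assertIsNotNone(out)
```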
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Wow, nice find!
src/transformers/generation/utils.py (outdated)
```python
        # NOTE: this must be written into a different attribute name than the one holding the original special tokens
        # (in their non-tensor form), in order to enable end-to-end compilation. See
        # https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations
        generation_config.bos_token_tensor = bos_token_tensor
        generation_config.eos_token_tensor = eos_token_tensor
        generation_config.pad_token_tensor = pad_token_tensor
        generation_config.decoder_start_token_tensor = decoder_start_token_tensor
```
Would it work to have a property, `eos_token_id`, with `_eos_token_tensor` underlying it? When you set it, you cast to tensor format. Might be simpler in general?
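For illustration, a minimal sketch of the suggested property approach (hypothetical; as the reply below notes, it was not adopted):

```python
import torch

class GenerationConfigSketch:
    """Toy illustration of the property idea: setting the integer id also
    refreshes a tensor twin stored under a private name."""

    def __init__(self, eos_token_id=None):
        self.eos_token_id = eos_token_id  # goes through the setter below

    @property
    def eos_token_id(self):
        return self._eos_token_id

    @eos_token_id.setter
    def eos_token_id(self, value):
        self._eos_token_id = value
        # Cast to tensor format on assignment, as suggested above.
        self._eos_token_tensor = torch.tensor(value) if value is not None else None

cfg = GenerationConfigSketch(eos_token_id=2)
print(cfg.eos_token_id, cfg._eos_token_tensor)  # 2 tensor(2)
```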
uhmmm, `@property` will further tangle us with state/python classes, which I'm not a fan of for compile purposes 🤔 I am going to rename the tensor variables from `xxx_token_tensor` to `_xxx_token_tensor` though, to help with readability!
```diff
@@ -1539,75 +1539,43 @@ def generate(
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
```
Some more musicgen standardization :) (= copy-pasting the new, upgraded patterns from the main `generate`)
Confirmed: slow tests have the same failures as on `main`.
An interesting behavior from `compile`. Thanks for handling!
What does this PR do?
See this comment in the end-to-end generation PR: #30788 (comment)
Problem TL;DR
Compiled code can't create tensors as `x = torch.tensor(x)` (i.e. with the tensor output variable name equal to the non-tensor input variable name): the graph overwrites the argument (non-tensor) with the output (tensor), so the input to this node is missing the second time the graph is called (because it was overwritten). This is a known limitation of `torch.compile`. Our code that converts the special tokens into tensors falls into this pattern.
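A minimal sketch of the two patterns side by side (the problem itself only surfaces under `torch.compile` with CUDA graphs, so running this sketch eagerly will not reproduce it):

```python
import torch

def to_tensor_problematic(pad_token_id):
    # Problematic: the tensor output reuses the non-tensor input's variable
    # name, so a captured graph overwrites its own argument and loses the
    # input for the next invocation.
    pad_token_id = torch.tensor(pad_token_id)
    return pad_token_id

def to_tensor_safe(pad_token_id):
    # Safe: the tensor lives under a distinct name; the original non-tensor
    # input is never clobbered.
    pad_token_tensor = torch.tensor(pad_token_id)
    return pad_token_tensor
```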
Solution
Write the special tokens converted into tensors under a new variable name. A bit annoying: we will now have e.g. `pad_token_id` (integer) and `pad_token_tensor` (tensor) in the `generation_config` object throughout `generate`.

Note: we can't change the variable name where these tokens are read from, which would be much cleaner. This is because:
- users can set the special tokens at any point before they are used inside `generate` (at `GenerationConfig` creation time, for instance);
- as such, even if we create an auxiliary attribute with a different name to read from, the original attribute will be overwritten, leading to the original issue 😢