
Generate: store special token tensors under a unique variable name #31980

Merged · 5 commits · Jul 22, 2024

Conversation

@gante (Member) commented Jul 15, 2024

What does this PR do?

See this comment in the end-to-end generation PR: #30788 (comment)

Problem TL;DR

Compiled code can't create tensors with the pattern `x = torch.tensor(x)` (i.e. the tensor output reusing the non-tensor input's variable name): the graph overwrites the argument (non-tensor) with the output (tensor), so the input to this node is missing the second time the graph is called (because it was overwritten). This is a known limitation of torch.compile.

Our code that converts the special tokens into tensors falls into this pattern.
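A minimal sketch of the pattern at issue (illustrative code, not taken from the PR): the first function reuses the input's variable name for the tensor output, which is the shape that breaks under end-to-end compilation; the second writes the tensor under a new name, which is the shape the PR adopts.

```python
import torch

# Problematic pattern: the tensor output reuses the non-tensor input's
# variable name, so the graph overwrites the original argument.
def problematic(pad_token_id: int) -> torch.Tensor:
    pad_token_id = torch.tensor(pad_token_id)  # output name == input name
    return pad_token_id

# Pattern adopted instead: write the tensor under a new variable name,
# keeping the original non-tensor value intact.
def fixed(pad_token_id: int) -> torch.Tensor:
    pad_token_tensor = torch.tensor(pad_token_id)
    return pad_token_tensor
```

Both functions behave identically in eager mode; the difference only matters to the graph that torch.compile records.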

Solution

Write the special tokens converted into tensors under a new variable name. This is a bit annoying: we will now have e.g. `pad_token_id` (an integer) and `pad_token_tensor` (a tensor) in the `generation_config` object throughout generate.

Note: We can't change the variable name where these tokens are read from, which would be much cleaner. This is because:

  1. in end-to-end compilation we can't deepcopy
  2. the attribute is set outside generate (at GenerationConfig creation time)

As such, even if we create an auxiliary attribute with a different name to read from, the original attribute would still be overwritten, leading back to the original issue 😢
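The solution above can be sketched as follows (a hypothetical, minimal stand-in for transformers' much larger `GenerationConfig` and token-preparation logic; only the attribute-naming pattern mirrors the PR):

```python
import torch

# Minimal stand-in for transformers' GenerationConfig (the real class has
# many more fields); used only to illustrate the naming scheme.
class GenerationConfig:
    def __init__(self, pad_token_id=None, eos_token_id=None):
        self.pad_token_id = pad_token_id
        self.eos_token_id = eos_token_id

def prepare_special_tokens(generation_config: GenerationConfig) -> None:
    # Read the integer attributes, but store the tensor versions under NEW
    # attribute names so the original (non-tensor) attributes are never
    # overwritten -- the property torch.compile needs.
    def to_tensor(token):
        return None if token is None else torch.tensor(token)

    generation_config._pad_token_tensor = to_tensor(generation_config.pad_token_id)
    generation_config._eos_token_tensor = to_tensor(generation_config.eos_token_id)
```

After this runs, both the integer and the tensor form coexist on the config object, which is exactly the "a bit annoying" duplication the PR accepts.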

Comment on lines +1395 to +1404
```python
# If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
if not using_model_generation_config:
    if generation_config.bos_token_id is None:
        generation_config.bos_token_id = self.generation_config.bos_token_id
    if generation_config.eos_token_id is None:
        generation_config.eos_token_id = self.generation_config.eos_token_id
    if generation_config.pad_token_id is None:
        generation_config.pad_token_id = self.generation_config.pad_token_id
    if generation_config.decoder_start_token_id is None:
        generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id
```
gante (Member Author):

This is equivalent to the changes in this PR, but better suited to this function -- handling backward compatibility w.r.t. config files

```python
@@ -3196,6 +3196,39 @@ def test_assisted_decoding_in_gpu_cpu(self):
        )
        self.assertTrue(input_length <= out.shape[-1] <= input_length + 20)

    def test_special_tokens_fall_back_to_model_default(self):
```
gante (Member Author):

This test was missing in #31254 😛

@gante force-pushed the different_var_special_tokens branch from 3e4948d to d21329f on July 15, 2024 16:40
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker (Collaborator) left a comment:

Wow nice finding!

Comment on lines 1563 to 1569
```python
# NOTE: this must be written into a different attribute name than the one holding the original special tokens
# (in their non-tensor form), in order to enable end-to-end compilation. See
# https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations
generation_config.bos_token_tensor = bos_token_tensor
generation_config.eos_token_tensor = eos_token_tensor
generation_config.pad_token_tensor = pad_token_tensor
generation_config.decoder_start_token_tensor = decoder_start_token_tensor
```
ArthurZucker (Collaborator):

Would it work to have a property, `eos_token_id`, with `_eos_token_tensor` underlying? When you set it, you cast to tensor format. Might be simpler in general?
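For concreteness, the suggested (and, per the reply below, ultimately not adopted) property-based approach might look something like this sketch; the class shape and attribute names here are assumptions, not code from the PR:

```python
import torch

# Sketch of the property suggestion: `eos_token_id` is backed by
# `_eos_token_tensor`, and assignment casts to a tensor automatically.
class GenerationConfig:
    def __init__(self, eos_token_id=None):
        self.eos_token_id = eos_token_id  # routed through the setter below

    @property
    def eos_token_id(self):
        return self._eos_token_tensor

    @eos_token_id.setter
    def eos_token_id(self, value):
        self._eos_token_tensor = None if value is None else torch.tensor(value)
```

The trade-off the reply points at: properties add Python-level indirection and state to a class that generate wants to keep simple for torch.compile.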

gante (Member Author):

uhmmm @property will further tangle us with state/python classes, which I'm not a fan of for compile purposes 🤔

I am going to rename the tensor variables from xxx_token_tensor to _xxx_token_tensor though, to help with readability!

```python
@@ -1539,75 +1539,43 @@ def generate(
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
```
@gante (Member Author) commented Jul 16, 2024:

Some more musicgen standardization :) (i.e. copy-pasting the new, upgraded patterns from the main generate)

gante (Member Author):

confirmed: slow tests have the same failures as on main

@zucchini-nlp (Member) left a comment:

An interesting behavior from compile. Thanks for handling!

@gante merged commit c38c55f into huggingface:main on Jul 22, 2024 (23 checks passed), and deleted the different_var_special_tokens branch on July 22, 2024 13:06.
Labels: none yet · Projects: none yet · 4 participants