
Generate: consistently handle special tokens as tensors #29788

Closed · wants to merge 16 commits from the special_tokens_as_tensors branch

Conversation

gante (Member) commented Mar 21, 2024

What does this PR do?

To enable torch.compile with generate, some special-token-related operations have to be rewritten as torch operations. That requires the special tokens to be tensors, rather than integers or lists of integers. (See #29374 for a working prototype.)

This PR reworks special token usage in generate so that they are consistently treated as tensors, as opposed to e.g. keeping track of eos_token_id in both integer and tensor form.

👉 Review suggestion: start by reading _prepare_special_tokens and how it fits in generate.
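
For intuition, here is a minimal hedged sketch of the idea, with made-up token ids and not the PR's actual code: once the EOS ids live in a tensor, the per-step "which sequences just finished?" check becomes a single tensor op that torch.compile can trace.

```python
import torch

# Illustrative only -- token ids are invented and this is not the PR's code.
eos_token_id = torch.tensor([2, 32000])        # EOS ids as a tensor, not a list of ints
next_tokens = torch.tensor([2, 15, 32000, 7])  # one next token per sequence in the batch

# Tensor-level membership test, traceable by torch.compile, instead of a
# Python-level `token in eos_token_id_list` check.
unfinished_sequences = ~torch.isin(next_tokens, eos_token_id)
# tensor([False,  True, False,  True])
```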


Requirements before merging this PR:

Tests run locally:

  • logits processors doctests (pytest --doctest-modules src/transformers/generation/logits_process.py -vv), needs the requirement above to be merged first
  • generate doctests (pytest --doctest-modules src/transformers/generation/utils.py -vv)
  • generate integration tests (RUN_SLOW=1 py.test tests/generation/test_utils.py -vv)
  • cache integration tests (RUN_SLOW=1 py.test tests/test_cache_utils.py -vv) -- same failures as in main
  • llama slow tests (RUN_SLOW=1 py.test tests/models/llama/test_modeling_llama.py -vv)
  • whisper slow tests (RUN_SLOW=1 py.test tests/models/whisper/test_modeling_whisper.py -vv) -- same failures as in main

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante changed the title from "Generate: special tokens as tensors" to "Generate: consistently handle special tokens as tensors" on Mar 22, 2024
@gante force-pushed the special_tokens_as_tensors branch from a746766 to 7863c5d on March 25, 2024 at 12:03
@gante marked this pull request as ready for review on March 25, 2024 at 16:30
@@ -567,24 +570,6 @@ def _prepare_decoder_input_ids_for_generation(

return decoder_input_ids, model_kwargs

def _get_decoder_start_token_id(
gante (Member Author):

The logic of this function is now within _prepare_special_tokens, which preprocesses all special tokens

@@ -1221,6 +1208,55 @@ def _prepare_generation_config(

return generation_config, model_kwargs

def _prepare_special_tokens(
gante (Member Author):

ALL preprocessing logic for the special tokens now resides in this function 🧹

def _prepare_special_tokens(
self, generation_config: GenerationConfig, kwargs_has_attention_mask: Optional[bool] = None
gante (Member Author):

kwargs_has_attention_mask is an optional argument so that we can also use this function in tests to prepare special tokens.

eos_token_id = [eos_token_id]
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None

if not isinstance(pad_token_id, torch.Tensor):
gante (Member Author) commented Mar 25, 2024:

The decoding functions are backward compatible (for now), and we can still pass int/list(int) as special tokens.

The doctests in generate test this.
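
As an illustration of such a backward-compatibility shim (the helper below is hypothetical, not the PR's code), an int or list of ints can simply be normalized to a tensor on entry:

```python
import torch

def to_token_tensor(token, device="cpu"):
    # Hypothetical helper: accept int / list of ints / tensor / None and
    # return a 1D long tensor on `device` (or None).
    if token is None:
        return None
    if isinstance(token, int):
        token = [token]
    if not isinstance(token, torch.Tensor):
        token = torch.tensor(token)
    return token.to(device)

assert torch.equal(to_token_tensor(2), torch.tensor([2]))
assert torch.equal(to_token_tensor([2, 3]), torch.tensor([2, 3]))
assert torch.equal(to_token_tensor(torch.tensor([2])), torch.tensor([2]))
```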

@@ -2170,6 +2170,24 @@ def _maybe_initialize_input_ids_for_generation(
break
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id

def _get_decoder_start_token_id(
gante (Member Author):

Musicgen (and its melody variant) have their own custom generate, relying on this method.

I've intentionally not updated this custom generate, to pressure us into moving towards a single generate function.

) -> torch.LongTensor:
# No information for attention mask inference -> return default attention mask
gante (Member Author):

The logic rewritten in functions like this is torch.compile(..., fullgraph=True) compatible 😉
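
As an example of the pattern (a sketch assuming pad-token-based attention-mask inference expressed with tensor ops only; this is not the exact PR code):

```python
import torch

def default_attention_mask(input_ids: torch.LongTensor, pad_token_id: torch.Tensor) -> torch.LongTensor:
    # Elementwise comparison against the pad token tensor -- no Python-level
    # membership checks, so the graph never breaks.
    return (input_ids != pad_token_id).long()

# Compiles without graph breaks (requires PyTorch 2.x).
compiled = torch.compile(default_attention_mask, fullgraph=True)
print(compiled(torch.tensor([[5, 6, 0, 0]]), torch.tensor(0)))
# tensor([[1, 1, 0, 0]])
```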

zucchini-nlp (Member) left a comment:

Thanks for working on this 😄

src/transformers/generation/utils.py (resolved)
src/transformers/generation/utils.py (outdated, resolved)
src/transformers/generation/logits_process.py (resolved)
@gante force-pushed the special_tokens_as_tensors branch from 7537207 to dd5bf8c on March 29, 2024 at 13:06
gante (Member Author) commented Mar 29, 2024:

Let's merge #29956 first, so the diff here becomes much smaller (the EOS-as-stopping-criteria change made the diff more elaborate).

(Arthur -- don't review this one until that is merged, I'll ping you again)

@gante removed the request for review from ArthurZucker on March 29, 2024 at 17:47

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
