Generate: consistently handle special tokens as tensors #29788
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
a746766 to 7863c5d
@@ -567,24 +570,6 @@ def _prepare_decoder_input_ids_for_generation(

        return decoder_input_ids, model_kwargs

    def _get_decoder_start_token_id(
The logic of this function is now within `_prepare_special_tokens`, which preprocesses all special tokens.
@@ -1221,6 +1208,55 @@ def _prepare_generation_config(

        return generation_config, model_kwargs

    def _prepare_special_tokens(
ALL preprocessing logic for the special tokens now resides in this function 🧹
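For illustration, here is a minimal sketch of the kind of normalization such a function performs — the helper name and the exact attributes touched are illustrative, not this PR's literal code:

```python
import torch
from typing import List, Optional, Union


def _tensor_or_none(
    token: Optional[Union[int, List[int], torch.Tensor]], device: torch.device
) -> Optional[torch.Tensor]:
    # Illustrative helper: None passes through, existing tensors are moved to
    # `device`, and ints / lists of ints are wrapped into a long tensor.
    if token is None:
        return None
    if isinstance(token, torch.Tensor):
        return token.to(device)
    return torch.tensor(token, device=device, dtype=torch.long)


# Every special token gets the same treatment, so downstream code only ever
# sees tensors (or None).
device = torch.device("cpu")
print(_tensor_or_none(2, device))           # 0-d long tensor
print(_tensor_or_none([2, 32000], device))  # 1-d long tensor with two entries
print(_tensor_or_none(None, device))        # None
```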
@@ -1221,6 +1208,55 @@ def _prepare_generation_config(

        return generation_config, model_kwargs

    def _prepare_special_tokens(
        self, generation_config: GenerationConfig, kwargs_has_attention_mask: Optional[bool] = None
`kwargs_has_attention_mask` is an optional argument so we can use this function in tests to prepare special tokens.
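A hedged sketch of how a test could call it — the checkpoint name is just an example, and the in-place post-condition asserted at the end is an assumption about the helper's behavior:

```python
import torch
from transformers import AutoModelForCausalLM, GenerationConfig

# Example tiny checkpoint; any small decoder-only model would do.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
generation_config = GenerationConfig(bos_token_id=1, eos_token_id=2, pad_token_id=0)

# No model kwargs in this test, so we tell the helper explicitly that no
# attention mask was passed.
model._prepare_special_tokens(generation_config, kwargs_has_attention_mask=False)

# Assumption: the helper rewrites the config's special tokens as tensors in place.
assert isinstance(generation_config.eos_token_id, torch.Tensor)
```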
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None

        if not isinstance(pad_token_id, torch.Tensor):
The decoding functions are backward compatible (for now), and we can still pass `int`/`list(int)` as special tokens. The doctests in `generate` test this.
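In other words, a shim of roughly this shape at the top of a decoding function keeps the old interface alive (a sketch, not the literal diff):

```python
import torch
from typing import List, Optional, Union


def _coerce_eos(
    eos_token_id: Optional[Union[int, List[int], torch.Tensor]], device: torch.device
) -> Optional[torch.Tensor]:
    # Deprecated inputs (int or list of ints) are converted to the 1-D tensor
    # form that the rest of the decoding function now expects.
    if eos_token_id is None or isinstance(eos_token_id, torch.Tensor):
        return eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    return torch.tensor(eos_token_id, device=device)


device = torch.device("cpu")
assert torch.equal(_coerce_eos(2, device), torch.tensor([2]))
assert torch.equal(_coerce_eos([2, 32000], device), torch.tensor([2, 32000]))
assert _coerce_eos(None, device) is None
```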
@@ -2170,6 +2170,24 @@ def _maybe_initialize_input_ids_for_generation(
                break
        return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id

    def _get_decoder_start_token_id(
Musicgen (and its melody variant) has its own custom `generate` that relies on this method. I've intentionally not updated that custom `generate`, to pressure us into moving towards a single generate function.
    ) -> torch.LongTensor:
        # No information for attention mask inference -> return default attention mask
The logic rewritten in functions like this one is `torch.compile(..., fullgraph=True)` compatible 😉
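As a toy illustration of why the tensor form matters (not this PR's code): a Python-level membership test against an int list causes a graph break, whereas the equivalent tensor op can be captured in a single graph.

```python
import torch

eos_token_id = torch.tensor([2, 32000])       # special tokens as a tensor
next_tokens = torch.tensor([5, 2, 7, 32000])  # one freshly sampled token per sequence
unfinished = torch.ones_like(next_tokens, dtype=torch.bool)


def step(unfinished: torch.Tensor, next_tokens: torch.Tensor, eos: torch.Tensor) -> torch.Tensor:
    # Pure tensor ops: no Python ints, no data-dependent branching, so the
    # whole function traces with fullgraph=True.
    return unfinished & ~torch.isin(next_tokens, eos)


compiled_step = torch.compile(step, fullgraph=True)
print(compiled_step(unfinished, next_tokens, eos_token_id))  # True, False, True, False
```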
Thanks for working on this 😄
Co-authored-by: Raushan Turganbay <raushan.turganbay@alumni.nu.edu.kz>
7537207 to dd5bf8c
Let's merge #29956 first, so the diff here becomes much smaller (the EOS-as-stopping-criteria made the diff more elaborate). (Arthur -- don't review this one until that is merged, I'll ping you again)
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
What does this PR do?

To enable `torch.compile` with `generate`, some special token-related operations have to be rewritten into torch operations. That requires special tokens to be tensors instead of integers or a list of integers (see #29374 for a working prototype).

This PR reworks special token usage in `generate` to consistently treat them as a tensor, as opposed to e.g. keeping track of `eos_token_id` in integer and in tensor form.

👉 Review suggestion: start by reading `_prepare_special_tokens` and how it fits in `generate`.
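From the caller's side nothing should change: plain-int special tokens keep working and are normalized to tensors internally. A hedged usage sketch (the checkpoint name is just an example):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

inputs = tokenizer("Hello world", return_tensors="pt")

# Special tokens passed as plain ints, exactly as before this PR; `generate`
# converts them to tensors under the hood.
output = model.generate(
    **inputs,
    max_new_tokens=5,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(output[0]))
```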
Requirements before merging this PR:

Tests ran locally:
- `pytest --doctest-modules src/transformers/generation/logits_process.py -vv` -- needs requirement to be merged first
- `pytest --doctest-modules src/transformers/generation/utils.py -vv`
- `RUN_SLOW=1 py.test tests/generation/test_utils.py -vv`
- `RUN_SLOW=1 py.test tests/test_cache_utils.py -vv` -- same failures as in `main`
- `RUN_SLOW=1 py.test tests/models/llama/test_modeling_llama.py -vv`
- `RUN_SLOW=1 py.test tests/models/whisper/test_modeling_whisper.py -vv` -- same failures as in `main`