[generate] fix eos/pad id check on mps devices #31695
Conversation
Thanks for fixing!
src/transformers/generation/utils.py (outdated)

```diff
@@ -1510,7 +1510,7 @@ def _tensor_or_none(token_kwargs, token_self, device=None):
     logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.")

     # we can't infer attn mask if pad token is set to be eos token in model's generation config
     if eos_token_id is not None and torch.isin(elements=eos_token_id, test_elements=pad_token_id).any():
```
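For context, a minimal illustration of what the reviewed line computes (the token ids below are made up); on an mps device the same call fails on torch versions where `torch.isin` is not implemented:

```python
import torch

eos_token_id = torch.tensor([1, 2])  # hypothetical ids; models may define several EOS tokens
pad_token_id = torch.tensor(2)

# True iff the pad token is one of the EOS tokens
print(torch.isin(elements=eos_token_id, test_elements=pad_token_id).any())  # tensor(True)
```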
cc @gante Are there any reasons, e.g. compilation, for using `torch.isin`?
Yes, compilation requires this `torch.isin` 💔 cc @sanchit-gandhi

IMO, we should create an internal function containing this `torch.isin` workaround (one that works on compile AND mps devices), and replace all `torch.isin` calls with this function.
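A minimal sketch of what such a helper could look like (the name `_isin_mps_friendly` and the broadcast fallback are assumptions for illustration, not the implementation that eventually landed):

```python
import torch

def _isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor) -> torch.Tensor:
    """Drop-in replacement for `torch.isin` that also runs on mps devices,
    where the native op is not implemented (hypothetical helper sketch)."""
    if elements.device.type == "mps":
        # Broadcast fallback: compare every element against every test element,
        # then reduce over the test dimension. Output shape matches `elements`.
        return elements.unsqueeze(-1).eq(test_elements.flatten()).any(dim=-1)
    # Elsewhere (including under torch.compile) keep the native op.
    return torch.isin(elements=elements, test_elements=test_elements)
```

Routing every call site through one helper would keep the compile-friendly `torch.isin` on CPU/CUDA while giving mps a pure-broadcast fallback.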
Force-pushed from 18837f2 to 1fc3e1e.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
(see comment above)

@sanchit-gandhi lmk if you have the bandwidth to address the change :)
(I'll iron out compilation issues later. Merging this shouldn't break any existing feature :) )
What does this PR do?

Generation currently fails on `main` for mps devices (traceback omitted). This is due to the `torch.isin` operator not being implemented on torch mps. This PR removes the `torch.isin` operator from the main body of generation, while keeping compatibility with the eos/pad checks added in #31254.

Following this PR, Gemma-2 (and other generate-compatible models) can be run on mps.
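For reference, a minimal sketch of an mps-safe way to express the removed check, assuming `pad_token_id` is a scalar (0-dim) tensor; this is an illustration, not necessarily the exact replacement the PR landed:

```python
import torch

eos_token_id = torch.tensor([1, 2])  # hypothetical token ids
pad_token_id = torch.tensor(2)

# Original check (fails on mps, where torch.isin is not implemented):
#     torch.isin(elements=eos_token_id, test_elements=pad_token_id).any()
# Broadcasted equality computes the same boolean on cpu, cuda, and mps:
pad_is_eos = (eos_token_id == pad_token_id).any()
print(pad_is_eos)  # tensor(True)
```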