Pipeline: simple API for assisted generation #34504
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@@ -416,16 +416,6 @@ Assisted decoding assumes the main and assistant models have the same tokenizer,
Currently, only greedy search and sampling are supported with assisted decoding, and assisted decoding doesn't support batched inputs.
To learn more about assisted decoding, check [this blog post](https://huggingface.co/blog/assisted-generation).

#### Universal Assisted Decoding
The current version of the docs had the normal assisted generation docs under Universal Assisted Decoding [a modification of the original technique to support different tokenizers]. Most of the diff is to isolate Universal Assisted Decoding into its own section.
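For readers of this thread, a minimal sketch of the technique these docs describe. The checkpoints are illustrative, and the Universal Assisted Decoding call is an assumption about the `assistant_tokenizer` argument in recent `transformers` versions, so double-check it against the current docs:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Plain assisted decoding: the main and assistant models share the same tokenizer.
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
assistant = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")

inputs = tokenizer("Assisted decoding lets a small model draft tokens that", return_tensors="pt")
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

# Universal Assisted Decoding (assistant with a different tokenizer) additionally
# takes both tokenizers -- assumed API, shown here only as a comment:
# outputs = model.generate(
#     **inputs,
#     assistant_model=other_family_assistant,
#     tokenizer=tokenizer,
#     assistant_tokenizer=other_family_tokenizer,
# )
```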
```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
```
<Tip>
new docs: an example of how to use assisted generation with pipelines
@@ -347,7 +347,6 @@ def generate(
    eos_token_id = generation_config.eos_token_id
    if isinstance(eos_token_id, list):
        eos_token_id = eos_token_id[0]
    logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
This is a legacy warning (before I joined 👀) that:
a) is noisy
b) doesn't affect generation, beyond the attempt to infer the attention mask when it is not passed
In the specific case of assisted generation, it was emitted once per assistant model call, so many times 😅 More harmful than useful.
Strong agree on removing this one!
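For reference, a minimal sketch (with an illustrative checkpoint) of the behaviour the warning referred to: when `pad_token_id` is passed explicitly, `generate` never needs to fall back to `eos_token_id` in the first place.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")

# The tokenizer output already contains an attention mask, and passing
# `pad_token_id` explicitly avoids the eos-as-pad fallback the warning announced.
inputs = tokenizer("Hello world", return_tensors="pt")
outputs = model.generate(**inputs, pad_token_id=tokenizer.eos_token_id, max_new_tokens=10)
```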
if len(stop_sequence_ids) > 1:
    warnings.warn(
        "Stopping on a multiple token sequence is not yet supported on transformers. The first token of"
        " the stop sequence will be used as the stop sequence string in the interim."
    )
generate_kwargs["eos_token_id"] = stop_sequence_ids[0]
(no longer true)
Do we support this for sequences of IDs as well as stop strings?
Yes, `generate` supports using multiple token IDs as stopping criteria! In this particular case, the warning predates the introduction of multiple token IDs as stopping criteria in `generate` :) So the warning is no longer true (and hasn't been for a while)
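A minimal sketch of both options (illustrative checkpoint; the `stop_strings` argument assumes a reasonably recent `transformers` release):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
inputs = tokenizer("Count with me: 1, 2, 3", return_tensors="pt")

# Multi-token stop sequences: `stop_strings` also needs the tokenizer passed to `generate`.
out = model.generate(**inputs, stop_strings=["10"], tokenizer=tokenizer, max_new_tokens=40)

# Several token IDs can also act as end-of-sequence markers
# (198 is a newline in the GPT-2 vocabulary).
out = model.generate(**inputs, eos_token_id=[tokenizer.eos_token_id, 198], max_new_tokens=40)
```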
Thanks for the nice PR @gante! @Rocketknight1, can you give it a first look as the pipeline owner? Thanks!
This seems very clean overall! The changes to pipeline code are relatively small, because all the actual action happens in `generate()`, so this PR really just takes care of forwarding assistant models/tokenizers to that method.
@require_torch
def test_pipeline_assisted_generation(self):
    """Tests that we can run assisted generation in the pipeline"""
    model = "distilbert/distilgpt2"
Distilgpt2 is still relatively large for a non-slow test when we're just checking for errors, rather than comparing outputs! Maybe there's a tiny-random model we can use?
done 👍 (and confirmed that it works)
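A sketch of what the updated test could look like; the tiny checkpoint name (`hf-internal-testing/tiny-random-gpt2`) and the surrounding test-class plumbing are assumptions, not the PR's final code:

```python
import unittest

from transformers import pipeline
from transformers.testing_utils import require_torch


class AssistedGenerationPipelineTest(unittest.TestCase):
    @require_torch
    def test_pipeline_assisted_generation(self):
        """Tests that we can run assisted generation in the pipeline"""
        # A tiny random checkpoint keeps the non-slow test fast; we only check
        # that assisted generation runs, not what it generates.
        model = "hf-internal-testing/tiny-random-gpt2"
        pipe = pipeline("text-generation", model=model, assistant_model=model)
        output = pipe("Hello", max_new_tokens=5)
        self.assertIsInstance(output[0]["generated_text"], str)
```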
_, loaded_assistant_model = infer_framework_load_model(assistant_model, config=assistant_config)
loaded_assistant_model = loaded_assistant_model.to(device=model.device, dtype=model.dtype)
`infer_framework_load_model` is framework-agnostic code, but `.to()` is PyTorch-only. Maybe we should add a warning if a user tries to use this with TF, since it's not supported at all?
Added the following after checking that `assistant_model` is not `None`, and before the commented lines. In a nutshell, the cases where `.to` is not supported are caught in advance:
if not isinstance(model, PreTrainedModel):
    raise ValueError(
        "Assisted generation, triggered by the `assistant_model` argument, is only available for "
        "`PreTrainedModel` model instances. For instance, TF or JAX models are not supported."
    )
Yes, looks good!
@gante will this PR make it to the next release? 🙏
Hi @LysandreJik 👋
hi @gante, could you please merge when free? 🙏
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
@Rocketknight1 PR comments addressed 🫡 Let me know if you are happy with the PR! Assuming CI is green: do I need to ping more folks for review?
@gante yes, you can go ahead and merge this whenever you're happy!
What does this PR do?
Assisted generation + pipelines currently has a bad UX: the user must manually load the assistant model (and the assistant tokenizer, if applicable), which defeats the purpose of pipelines being simple to use.
This PR adds the ability to specify an assistant checkpoint at pipeline definition time -- the pipeline will take care of the rest 🤗
Although the new argument is forwarded to `.generate()` by the pipelines that call it, I haven't added a test on all of them. Many pipelines don't forward kwargs properly to `.generate()`, which makes testing this transparent [same output, similar runtime] feature hard -- the best way to confirm assisted generation is running is by passing incompatible flags to `.generate()` to make it crash.

Example usage
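A sketch of the new API described above, with illustrative checkpoints that share a tokenizer (a requirement for plain assisted generation):

```python
from transformers import pipeline

# The assistant checkpoint is given at pipeline definition time; the pipeline
# loads it and forwards it to `generate()` on every call.
pipe = pipeline(
    "text-generation",
    model="openai-community/gpt2-large",
    assistant_model="distilbert/distilgpt2",
)
result = pipe("Assisted generation with pipelines is now", max_new_tokens=20, do_sample=False)
print(result[0]["generated_text"])
```

The same `assistant_model` argument should apply to the other generate-based pipelines touched by this PR.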