
Pipeline: simple API for assisted generation #34504

Merged: 14 commits into huggingface:main on Jan 8, 2025

Conversation

gante (Member) commented on Oct 30, 2024

What does this PR do?

Assisted generation with pipelines currently has poor UX: the user must manually load the assistant model (and the assistant tokenizer, if applicable), which defeats the point of pipelines being simple to use.

This PR adds the ability to specify an assistant checkpoint at pipeline definition time -- the pipeline will take care of the rest 🤗

⚠️ While the feature was added to all pipelines that call .generate(), I haven't added a test for all of them. Many pipelines don't forward kwargs properly to .generate(), which makes this transparent feature (same output, similar runtime) hard to test -- the best way to confirm assisted generation is running is to pass incompatible flags to .generate() and make it crash (see the sketch after the example below).

Example usage

```python
from transformers import pipeline
import torch

pipe = pipeline(
    "text-generation",
    model="meta-llama/Llama-3.1-8B",
    assistant_model="meta-llama/Llama-3.2-1B",  # This extra line is all that's needed!
    torch_dtype=torch.bfloat16,
)
pipe_output = pipe("Once upon a time, ", max_new_tokens=50, do_sample=False)
print(pipe_output[0]["generated_text"])
```
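
As noted above, passing incompatible flags is the clearest signal that assisted generation is active. A minimal sketch of such a check, continuing the example above (the exact error type is an assumption; assisted decoding only supports greedy search and sampling, so beam search should be rejected):

```python
try:
    # num_beams > 1 is incompatible with assisted generation; if the
    # assistant model is really being forwarded to .generate(), this crashes.
    pipe("Once upon a time, ", max_new_tokens=5, num_beams=2)
except ValueError as e:
    print(f"Assisted generation is active: {e}")
```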

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante changed the title from "mvp" to "Pipeline: simple API for assisted generation" on Oct 30, 2024
@gante requested a review from LysandreJik on October 30, 2024 16:31
```diff
@@ -416,16 +416,6 @@ Assisted decoding assumes the main and assistant models have the same tokenizer,
 Currently, only greedy search and sampling are supported with assisted decoding, and assisted decoding doesn't support batched inputs.
 To learn more about assisted decoding, check [this blog post](https://huggingface.co/blog/assisted-generation).

 #### Universal Assisted Decoding
```
gante (Member Author):

The current version of the docs had the regular assisted generation docs under Universal Assisted Decoding (a modification of the original technique that supports assistants with a different tokenizer).

Most of the diff is to isolate Universal Assisted Decoding.
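
For reference, a minimal sketch of Universal Assisted Decoding as the reorganized docs describe it (the checkpoints are illustrative, not necessarily the ones in the docs): since the assistant uses a different tokenizer, both tokenizers are passed to `generate`.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

target_ckpt = "google/gemma-2-9b"      # main model
assistant_ckpt = "double7/vicuna-68m"  # assistant with a different tokenizer

tokenizer = AutoTokenizer.from_pretrained(target_ckpt)
model = AutoModelForCausalLM.from_pretrained(target_ckpt)
assistant_tokenizer = AutoTokenizer.from_pretrained(assistant_ckpt)
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_ckpt)

inputs = tokenizer("Alice and Bob", return_tensors="pt")
# Passing both tokenizers triggers Universal Assisted Decoding, which
# re-encodes candidate sequences between the two vocabularies.
outputs = model.generate(
    **inputs,
    assistant_model=assistant_model,
    tokenizer=tokenizer,
    assistant_tokenizer=assistant_tokenizer,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```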


gante (Member Author):

New docs: an example of how to use assisted generation with pipelines.

```diff
@@ -347,7 +347,6 @@ def generate(
     eos_token_id = generation_config.eos_token_id
     if isinstance(eos_token_id, list):
         eos_token_id = eos_token_id[0]
-    logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
```
gante (Member Author):

This is a legacy warning (from before I joined 👀) that:
a) is noisy, and
b) doesn't affect generation, other than the ability to try to infer the attention mask when it is not passed.

In the specific case of assisted generation, it was emitted once per assistant model call, so many times 😅 More harmful than useful.

Rocketknight1 (Member):

Strong agree on removing this one!

Comment on lines -195 to -202:

```python
if len(stop_sequence_ids) > 1:
    warnings.warn(
        "Stopping on a multiple token sequence is not yet supported on transformers. The first token of"
        " the stop sequence will be used as the stop sequence string in the interim."
    )
generate_kwargs["eos_token_id"] = stop_sequence_ids[0]
```
gante (Member Author):

(no longer true)

Rocketknight1 (Member):

Do we support this for sequences of IDs as well as stop strings?

gante (Member Author) commented on Jan 7, 2025:

Yes, generate supports using multiple token IDs as stopping criteria!

In this particular case, this warning predates the introduction of multiple token IDs as stopping criteria in generate :) So the warning is no longer true (and hasn't been for a while).
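
For illustration, one way multi-token stopping works in current `generate` (a hedged sketch; `stop_strings` needs the tokenizer passed along so the strings can be matched against token IDs):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")

inputs = tok("Once upon a time", return_tensors="pt")
# Generation halts when the multi-token string is produced, instead of
# truncating the stop sequence to its first token as the old warning described.
out = model.generate(**inputs, max_new_tokens=30, stop_strings=["the end"], tokenizer=tok)
print(tok.decode(out[0], skip_special_tokens=True))
```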

@LysandreJik (Member) commented:

Thanks for the nice PR @gante!

@Rocketknight1, can you give it a first look as the pipeline owner? Thanks!

@Rocketknight1 (Member) left a review:

This seems very clean overall! The changes to pipeline code are relatively small, because all the actual action happens in generate(), so this PR really just takes care of forwarding assistant models/tokenizers to that method.
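
(A hypothetical sketch of the forwarding pattern described in that summary; this is not the PR's literal code, and the function name is illustrative:)

```python
from typing import Any, Optional

def build_generate_kwargs(
    assistant_model: Optional[Any],
    tokenizer: Optional[Any] = None,
    assistant_tokenizer: Optional[Any] = None,
) -> dict:
    """Sketch: the pipeline stores the preloaded assistant (and tokenizers,
    for universal assisted decoding) and injects them into every
    `.generate()` call as plain kwargs."""
    kwargs: dict = {}
    if assistant_model is not None:
        kwargs["assistant_model"] = assistant_model
        if assistant_tokenizer is not None:
            kwargs["tokenizer"] = tokenizer
            kwargs["assistant_tokenizer"] = assistant_tokenizer
    return kwargs
```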


```python
@require_torch
def test_pipeline_assisted_generation(self):
    """Tests that we can run assisted generation in the pipeline"""
    model = "distilbert/distilgpt2"
```
Rocketknight1 (Member):


Distilgpt2 is still relatively large for a non-slow test when we're just checking for errors, rather than comparing outputs! Maybe there's a tiny-random model we can use?

gante (Member Author):

done 👍 (and confirmed that it works)
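
(For illustration, a sketch of what the updated fast test might look like; the checkpoint below is a guess at a typical tiny test model, not necessarily the one this PR settled on:)

```python
from transformers import pipeline

# Hypothetical tiny checkpoint; the PR may use a different one.
tiny = "hf-internal-testing/tiny-random-gpt2"
pipe = pipeline("text-generation", model=tiny, assistant_model=tiny)
pipe("test", max_new_tokens=2)  # only checking that assisted generation runs
```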

Comment on lines +453 to +454:

```python
_, loaded_assistant_model = infer_framework_load_model(assistant_model, config=assistant_config)
loaded_assistant_model = loaded_assistant_model.to(device=model.device, dtype=model.dtype)
```
Rocketknight1 (Member):

`infer_framework_load_model` is framework-agnostic code, but `.to()` is PyTorch-only. Maybe we should add a warning if a user tries to use this with TF, since it's not supported at all?

gante (Member Author) commented on Jan 7, 2025:

Added the following after the check that `assistant_model` is not None, and before the lines commented on above. In a nutshell, the cases where `.to()` is not supported are now caught in advance:

```python
if not isinstance(model, PreTrainedModel):
    raise ValueError(
        "Assisted generation, triggered by the `assistant_model` argument, is only available for "
        "`PreTrainedModel` model instances. For instance, TF or JAX models are not supported."
    )
```

Rocketknight1 (Member):

Yes, looks good!

@danielkorat (Contributor) commented:

@gante will this PR make it to the next release? 🙏

@danielkorat (Contributor) commented:

Hi @LysandreJik 👋
Can we try to merge this before the next release?
This will drastically shorten the code examples we used in the HF blog and social posts.
Thanks!

@danielkorat (Contributor) commented:

Hi @gante, could you please merge when you're free? 🙏

@gante (Member Author) commented on Jan 7, 2025:

@Rocketknight1 PR comments addressed 🫡 Let me know if you are happy with the PR!

Assuming CI is green: do I need to ping more folks for review?

@Rocketknight1 (Member) commented:

@gante yes, you can go ahead and merge this whenever you're happy!

@gante merged commit 76da6ca into huggingface:main on Jan 8, 2025 (25 checks passed)
@gante deleted the pipeline_assistant branch on January 8, 2025 17:08