add attention_mask and position_ids in assisted model #26892
Conversation
Hi @jiqing-feng 👋 I agree in principle with the changes you are proposing, but you probably need to make a few changes to get our CI green :)
Hi @gante . I use …
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Added a few nits -- after those are addressed, we're ready to merge :)
Hi @gante . Would you please review it again? Thx!
Thank you for iterating! 💛
src/transformers/generation/utils.py (Outdated)

else:
    input_ids_len = assistant_inputs["input_ids"].shape[-1]

if input_ids_len not in (0, 1):

Suggested change:
-    if input_ids_len not in (0, 1):
+    if input_ids_len not in (1, 2):
@jiqing-feng Ah, actually I have two requests before asking for the green light of a core maintainer:
Hi @gante . I tested it on my CPU device since no GPU is available to me. The new branch is a little faster (around 3%) than the main branch. The test script is as follows; feel free to test it on both GPU and CPU.

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time

prompt = "Speculative decoding is"
checkpoint = "bigscience/bloom-7b1"
assistant_checkpoint = "bigscience/bloom-560m"
device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer(prompt, return_tensors="pt").to(device)

model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to(device)
generation_kwargs = {"do_sample": False, "max_new_tokens": 64, "temperature": 1.0, "top_p": 1.0, "num_beams": 1}
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to(device)

# Time assisted generation over a few runs and report the per-token latency
for i in range(5):
    start = time.time()
    outputs = model.generate(**inputs, assistant_model=assistant_model, **generation_kwargs)
    end = time.time()
    new_tokens = outputs.shape[-1] - inputs["input_ids"].shape[-1]
    print(f"Assistant decoding latency per token is {(end-start)/new_tokens * 1000} ms")

print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
Hi @gante . Could you have a look at this? Thx!
Hi @jiqing-feng Running on my end (…)
i.e. the newly generated masks that are appended must be created on the same device as the existing mask :)
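For context, a minimal sketch of what device-consistent mask extension looks like; this is not the code in the PR, and extend_attention_mask / num_new_tokens are assumed names used only for illustration.

import torch

def extend_attention_mask(attention_mask: torch.Tensor, num_new_tokens: int) -> torch.Tensor:
    # The new entries must be created on the same device (and with the same dtype)
    # as the existing mask; otherwise torch.cat fails with a device mismatch on GPU.
    new_entries = torch.ones(
        (attention_mask.shape[0], num_new_tokens),
        dtype=attention_mask.dtype,
        device=attention_mask.device,
    )
    return torch.cat([attention_mask, new_entries], dim=-1)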
Hi @gante . Would you please try it again? It should be fixed, and I also tested it on an A100; the results and performance are exactly the same. BTW, the failed test seems unrelated to my changes.
@jiqing-feng perfect, all works well on my end. Two related notes:
👉 you will need to rebase your changes to fix both issues, but only after the PR linked above gets merged. You may get minor rebase issues due to 2., but they should be trivial to fix. After that is done, I'll tag a core maintainer for a final quick check :)
Hi @gante . I removed …
🤦 my apologies, you're absolutely right. In that case, rebasing to get the CI green is all you need to do. Tagging a core maintainer for a quick final check :)
Good to go, thank you for iterating with me 💛
(Note: results also validated on my end, no slowdown nor generative performance drop)
Hi @gante . I see that the PR you mentioned has been merged and my PR is already up to date, but some of the CI checks are still red.
@jiqing-feng There were some unexpected failures because of new package releases - thankfully not related to this PR! They should now be resolved on main - rebasing should fix them here.
Yes, I meant to add a test to the CI runs. It looks like it should be tested in tests/generation/test_utils.py - but I'll let @gante confirm
(woops, wrong button)
@amyeroberts not sure if we can test this feature reliably: there are no output differences, since assisted generation always outputs what the main model dictates, and this PR only modifies the assistant model's inputs to be more aligned with the main model's. What we should see on average is a higher speedup with masked inputs, as the assistant model receives the same inputs and thus has a higher chance of matching the main model, but that is far from guaranteed for all calls. A speed test would be very flaky 🤔
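As a rough illustration of why only speed (and not values) changes, here is a sketch reusing the bloom checkpoints from the benchmark above; with greedy decoding, the assisted output is expected to match plain generation token for token.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "bigscience/bloom-7b1"
assistant_checkpoint = "bigscience/bloom-560m"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer("Speculative decoding is", return_tensors="pt")
model = AutoModelForCausalLM.from_pretrained(checkpoint)
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)

plain = model.generate(**inputs, do_sample=False, max_new_tokens=32)
assisted = model.generate(**inputs, do_sample=False, max_new_tokens=32, assistant_model=assistant_model)

# The main model dictates every accepted token, so the sequences should be identical;
# only wall-clock time differs, which is why a speed-based CI test would be flaky.
assert torch.equal(plain, assisted)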
@gante I understand - I wasn't clear enough before. Really all I was looking for is to make sure that this can be safely used with different assistant models, i.e. can I pass in a decoder-only model? How about an encoder-decoder? So not speed or values, just the API.
@amyeroberts we do have Mixin tests (e.g.), so any issue regarding the API should have been caught there :)
@gante Sweet - in that case it's all good 👍 Re the failing tests - there are some PRs due to be merged which should (hopefully, this time) resolve the issues we've been having
Thanks again for adding!
Hi, @gante @amyeroberts . All CI checks are green. I think it is time to merge :)
@jiqing-feng thank you for iterating with us and making …
@amyeroberts @jiqing-feng There are currently some unexpected CI failures caused by …
Hi, @VsonicV . Sorry for the failed CI. It is weird that I can successfully run pytest in my local repo (which is up to date with origin/main). I see that your CI failed at …
Hi, @jiqing-feng, thanks for the quick check. Exactly the same thing happened for me: I can run …
I submitted a new PR, and all CI checks passed. Would you apply my PR and see if the CI is ok? It is also worth trying to update your repo by merging origin/main and pushing the updates to rerun the CI.
@jiqing-feng Hi, thanks for this prompt fix! I will rebase my PR and re-do the CI checks after your new PR is merged. Fingers crossed!
This PR broke speculative decoding for Whisper; can we maybe revert it for now?
This reverts commit 184f60d.
Issue reported here: https://huggingface.co/openai/whisper-large-v3/discussions/20
… (huggingface#27523)
* Revert "add attention_mask and position_ids in assisted model (huggingface#26892)". This reverts commit 184f60d.
* more debug
Hi @gante
Do you think that we should also add assistant_attention_mask and assistant_position_ids in assisted_decoding? I see that the original model has attention_mask and position_ids (in most models) in its model inputs, but the assistant model has no such inputs. If you think it is okay to align the inputs of the original model and the assistant model, maybe we can find a more elegant way to integrate it. Thx!
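To illustrate the kind of alignment being proposed (a sketch under assumed names, not the exact code merged in this PR), position_ids can be derived from the attention_mask the way most decoder-only models in transformers do, so that the assistant model receives the same kind of inputs as the original model; build_assistant_inputs below is a hypothetical helper.

import torch

def build_assistant_inputs(input_ids: torch.Tensor, attention_mask: torch.Tensor) -> dict:
    # Positions count only attended tokens (cumsum of the mask, shifted to start at 0);
    # padded positions are clamped to 1, a common convention in decoder-only models.
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
    }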