fix assisted decoding #31401
Conversation
The failed CIs seem unrelated to my changes.
Hi @jiqing-feng! Thank you for opening this PR 🤗
To the best of my knowledge, the changes you're suggesting should not be needed. As such, I've asked a few questions below to understand why we need these changes :)
Hi @gante. Sorry for not making it clear. Could you run this script:

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "meta-llama/Llama-2-7b-chat-hf"
assistant_model_id = "Felladrin/Llama-68M-Chat-v1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_model_id, torch_dtype=torch.bfloat16).to("cpu")
prompt = "Assisted decoding is"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
model.generate(**inputs, assistant_model=assistant_model, max_new_tokens=8, min_new_tokens=8, do_sample=False)
```

It will get the error; see the full traceback.
I would like to add a test for this. Do you know where I should add this test? Thx!
This makes sense, thank you for digging deeper and iterating @jiqing-feng ! 💛
Regarding tests: it's a bit tricky to test two devices on our CI AFAIK 🤔 @amyeroberts do you have suggestions on how to test it? [TL;DR @jiqing-feng found that assisted generation fails if the two models are on different devices, because the special tokens are copied from the main model to the assistant model]
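As an illustrative aside, a minimal sketch of the failure mode summarized in the TL;DR above; this is not the actual transformers internals, and the tensor values and names are made up:

```python
import torch

# Sketch of the device mismatch: special-token tensors prepared on the main
# model's device (cuda) meet assistant-side tensors that live on cpu.
eos_token_id = torch.tensor([2], device="cuda")  # prepared alongside the main model
assistant_next_tokens = torch.tensor([[42]])     # produced by the assistant model on cpu

# The first op that combines the two devices raises:
# RuntimeError: Expected all tensors to be on the same device, ...
finished = assistant_next_tokens == eos_token_id
```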
I think we can just run the test on a machine with a GPU; the CPU side is not really a limitation, because we can run a very tiny model on CPU purely for functionality.
@gante There are certain tests in our suite which require multiple devices, e.g. test_model_parallelization, which we can mark with the corresponding decorator. In this case, I'd suggest having two tests: one for the single-accelerator case, and another which only runs in the multi-device case.
derp, ofc a GPU is enough (which has a CPU paired up), what a brain fart on my end :D @jiqing-feng could you add two tests like the script in this comment of yours to this file? More precisely:
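A minimal sketch of what such a test might look like; the test name, the tiny checkpoint, and the decorators are assumptions for illustration, not necessarily what this PR adds:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.testing_utils import require_torch_gpu, slow


@slow
@require_torch_gpu
def test_assisted_decoding_model_and_assistant_on_different_devices():
    # Main model on the accelerator, assistant on CPU: the setup that used to fail.
    checkpoint = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # illustrative tiny checkpoint
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint).to("cuda")
    assistant = AutoModelForCausalLM.from_pretrained(checkpoint).to("cpu")

    inputs = tokenizer("Assisted decoding is", return_tensors="pt").to(model.device)
    output = model.generate(**inputs, assistant_model=assistant, max_new_tokens=8, do_sample=False)

    # Generation should complete without a device-mismatch RuntimeError and produce new tokens.
    assert output.shape[-1] > inputs["input_ids"].shape[-1]
```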
Hi @gante. I have added the tests, could you please review them? Thx! BTW, the failed CIs seem unrelated to my changes.
Hi @amyeroberts. Could you please take a look? The failed CIs are not related to my changes :)
@jiqing-feng Regarding the failing tests, could you rebase on main to include upstream changes? This should resolve the failures on CI. Could you also run and share the output of executing the following in a multi-gpu environment:
@jiqing-feng rebasing the PR should get CI green 🤗
Hi @amyeroberts. I ran the 2 tests individually and they passed. I also ran your command and got the following output:
Hi @amyeroberts. Do you need anything more from me before merging? Please let me know, thx!
Hi @amyeroberts @gante. I think this PR should be ready to merge :)
@jiqing-feng OK, sorry, I think I messed up the pytest command. Could you try this instead:
All passed.
Hi @amyeroberts. The failed CIs are not related to my changes; would you please review them?
Hi @amyeroberts @gante, would you please help merge this PR? Thx!
Thanks for fixing!
Hi @jiqing-feng, we had to wait for some things to be resolved upstream and for a new CI run (which I triggered last night).
Hi @gante. This PR fixes assisted decoding when the model and the assistant model are on different devices.
It can be easily reproduced by the script below:
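Presumably the same reproduction script shared earlier in the thread, with the main model on "cuda" and the assistant model on "cpu":

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "meta-llama/Llama-2-7b-chat-hf"
assistant_model_id = "Felladrin/Llama-68M-Chat-v1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id

# Main model on GPU, assistant model on CPU: the mismatched-device setup this PR fixes.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_model_id, torch_dtype=torch.bfloat16).to("cpu")

prompt = "Assisted decoding is"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
model.generate(**inputs, assistant_model=assistant_model, max_new_tokens=8, min_new_tokens=8, do_sample=False)
```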