Assisted decoding results are not correct #30413
Comments
Related to #30042.
@jiqing-feng, the fix was merged on main; you can update transformers to pick it up. Closing issue as resolved :)
It's not exactly the same in the last few tokens, but it is better. Is such a small difference reasonable?
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

prompt = """
You are chatbot. The conversion history is given between ``` ```. Each interlocutor starts with "gpt: " or "human: " and ends with "@@@". You play "gpt". You need to reply to "human". conversation history:```system: *This chat conversation is shared from [**TypingMind.com**](https://typingmind.com)* @@@ human: Create a travel plan for a Family with small kids from London to Belgrade tra
"""

device = "cuda:1"
model_id = "meta-llama/Llama-2-7b-chat-hf"
as_model_id = "Felladrin/Llama-68M-Chat-v1"

# target model and the small assistant model used for assisted decoding
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(device)
as_model = AutoModelForCausalLM.from_pretrained(as_model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer(prompt, return_tensors="pt").to(device)
generate_kwargs = {"do_sample": False, "num_beams": 1, "max_new_tokens": 256}

# greedy search (reference)
print("greedy search")
outputs = model.generate(**inputs, **generate_kwargs)
print(outputs)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

# assisted decoding with the small model as the candidate generator
print("assisted decoding")
outputs = model.generate(**inputs, assistant_model=as_model, **generate_kwargs)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
print(outputs)

Output:
Found a mismatch when the output length is long.
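A quick way to locate where the two generations first diverge (an editorial sketch, not from the original thread; it re-runs the two calls from the script above and keeps the outputs in separate variables):

greedy_out = model.generate(**inputs, **generate_kwargs)
assisted_out = model.generate(**inputs, assistant_model=as_model, **generate_kwargs)
min_len = min(greedy_out.shape[-1], assisted_out.shape[-1])
diff = (greedy_out[0, :min_len] != assisted_out[0, :min_len]).nonzero()
print("first divergent position:", diff[0].item() if len(diff) else "none")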
@jiqing-feng After a bit of exploration I do not see any bugs in the way assisted decoding passes in its arguments. My guess is that the problem comes from small numerical precision errors that accumulate over generation timesteps. In other words, greedy decoding always generates one token at a time, so the key/value computation is effectively a vector-matrix multiplication. Assisted generation, by contrast, always does a matrix-matrix multiplication, because a large number of candidate tokens are verified at once. So my opinion is that torch internally handles those with a slightly different order of operations, which leads to error accumulation. cc @gante do you have any other ideas why this happens?
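As an illustration of the effect described above (an editorial sketch, not code from the issue): the same projection computed one token at a time versus for a block of tokens can differ slightly in bf16, because the reduction order inside the matmul kernels is not guaranteed to be identical.

import torch

torch.manual_seed(0)
hidden = torch.randn(8, 4096, dtype=torch.bfloat16)    # 8 candidate tokens
weight = torch.randn(4096, 4096, dtype=torch.bfloat16)
block = hidden @ weight                                  # matrix-matrix path (assisted decoding)
per_token = torch.stack([h @ weight for h in hidden])    # vector-matrix path (one token per step)
print((block - per_token).abs().max())                   # may be small but non-zero, depending on hardware/kernels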
It is reasonable, thanks :)
@jiqing-feng Yes, numerical issues will cause assisted generation to pick a different token from time to time. It's the exact same issue as with batched generation or the use of KV caches :) 👉 you can read more about the issue here
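To see why an occasional different token is expected (again an editorial sketch, not from the thread): when the top two logits are nearly tied, noise on the order of the accumulated numerical error is enough to flip the greedy argmax, and the two generations diverge from that point on.

import torch

logits = torch.tensor([10.0000, 9.9999])
perturbed = logits + torch.tensor([0.0, 2e-4])             # tiny numerical perturbation
print(logits.argmax().item(), perturbed.argmax().item())   # 0 vs 1: the picked token flips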
System Info
- transformers version: 4.40.0.dev0
- distributed_type: MULTI_CPU
- mixed_precision: bf16
- use_cpu: True
- debug: False
- num_processes: 2
- machine_rank: 0
- num_machines: 1
- rdzv_backend: static
- same_network: True
- main_training_function: main
- ipex_config: {'ipex': False}
- downcast_bf16: no
- tpu_use_cluster: False
- tpu_use_sudo: False
- tpu_env: []
Who can help?
@gante

Information

Tasks
examples folder (such as GLUE/SQuAD, ...)

Reproduction
The outputs:
Expected behavior
Hi @gante

The outputs should be the same, but the assisted decoding result is incorrect. I suspect a mistake in how some arguments are passed causes this issue. I've checked it and found that the candidate generator produces the same output as greedy search, but the target model's (self) forward results are incorrect. Would you please help me figure out the issue? Thanks!

BTW, I see that the cache_position is inconsistent, but I don't know the correct format.
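One hypothetical way to inspect the cache_position values that each forward call receives during greedy vs. assisted generation (not part of the original report; it assumes the script above has run and that cache_position is passed to forward as a keyword argument):

def log_cache_position(module, args, kwargs):
    # print whatever cache_position the generation loop passed in for this step
    print("cache_position:", kwargs.get("cache_position"))

handle = model.register_forward_pre_hook(log_cache_position, with_kwargs=True)
model.generate(**inputs, **generate_kwargs)                             # greedy
model.generate(**inputs, assistant_model=as_model, **generate_kwargs)   # assisted
handle.remove()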