add attention_mask and position_ids in assisted model #26892

Merged — 14 commits merged into huggingface:main on Nov 10, 2023

Conversation

@jiqing-feng (Contributor) commented on Oct 18, 2023

Hi @gante

Do you think that we should also add assistant_attention_mask and assistant_position_ids in assisted_decoding? I see that the original model has attention_mask and position_ids (in most models) in its model inputs, but the assistant model does not receive these kinds of inputs.

If you think it is okay to align the inputs of the original model and the assistant model, maybe we can find a more elegant way to integrate it. Thx!
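For context, a minimal illustrative sketch (not code from this PR) of how position_ids are typically derived from an attention_mask in decoder-only models; this is the kind of input the assistant model was not receiving:

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])  # left-padded example
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)  # tensor([[1, 1, 0, 1, 2]])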

@gante (Member) commented on Oct 25, 2023

Hi @jiqing-feng 👋

I agree in principle with the changes that you are proposing, but you will probably need to make a few changes to get our CI to go green :)

@jiqing-feng (Contributor, Author) commented on Oct 29, 2023

Hi @gante. I use assistant_model.prepare_inputs_for_generation to get the inputs of the assistant model. The CI is all green, and I also tested several of my own examples to make sure the outputs are correct. Would you please help me review it? Thx!
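For reference, a small runnable sketch (using a tiny public checkpoint chosen only for illustration, not the PR's code) of what prepare_inputs_for_generation returns for a decoder-only model, including attention_mask and, for GPT-2-style models, position_ids:

from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "sshleifer/tiny-gpt2"  # tiny model, an assumption for illustration
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

inputs = tokenizer("Speculative decoding is", return_tensors="pt")
prepared = model.prepare_inputs_for_generation(
    inputs["input_ids"], attention_mask=inputs["attention_mask"]
)
# For GPT-2 the returned dict includes input_ids, attention_mask and position_ids, among others
print(sorted(prepared.keys()))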

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@gante (Member) left a review:


Added a few nits -- after those are addressed, we're ready to merge :)

Review threads on src/transformers/generation/utils.py (resolved)
@jiqing-feng (Contributor, Author) commented:

Hi @gante. Would you please review it again? Thx!

@gante (Member) left a review:


Thank you for iterating! 💛

@gante (Member) commented on src/transformers/generation/utils.py:

else:
    input_ids_len = assistant_inputs["input_ids"].shape[-1]

if input_ids_len not in (0, 1):

Suggested change:
-if input_ids_len not in (0, 1):
+if input_ids_len not in (1, 2):

@gante gante requested review from amyeroberts and removed request for amyeroberts November 2, 2023 15:54
@gante (Member) commented on Nov 2, 2023

@jiqing-feng Ah, actually I have two requests before asking a core maintainer for the green light:

  1. There is a merge conflict, due to recent changes for a new model. If you're not able to sort it out, let me know :)
  2. Let's confirm that we haven't lost throughput with the changes (e.g. the assertion might be producing slowdowns). To test it, feel free to clone this folder, move there, and then run python benchmark_decoder_open.py facebook/opt-6.7b --aux-model facebook/opt-125m --dtype fp16 --num-samples 20 on main and on your branch. The execution times should be nearly identical! 🤗 If you have your own test script, feel free to use it instead -- just let us know of the numbers :)

@jiqing-feng (Contributor, Author) commented on Nov 3, 2023

Hi @gante. I tested it on my CPU device since a GPU is unavailable to me. The new branch is a little faster (around 3%) than the main branch. The test script is as follows; feel free to test it on both GPU and CPU.

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time

prompt = "Speculative decoding is"
checkpoint = "bigscience/bloom-7b1"
assistant_checkpoint = "bigscience/bloom-560m"
device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer(prompt, return_tensors="pt").to(device)

model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to(device)

generation_kwargs = {"do_sample": False, "max_new_tokens": 64, "temperature": 1.0, "top_p": 1.0, "num_beams": 1}

assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to(device)

for i in range(5):
    start = time.time()
    outputs = model.generate(**inputs, assistant_model=assistant_model, **generation_kwargs)
    end = time.time()
    new_tokens = outputs.shape[-1] - inputs["input_ids"].shape[-1]
    print(f"Assistant decoding latency per token is {(end-start)/new_tokens * 1000} ms")
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

@jiqing-feng jiqing-feng requested a review from gante November 7, 2023 01:36
@jiqing-feng (Contributor, Author) commented:

> [quoted the benchmark results and test script from the previous comment]

Hi @gante. Could you have a look at this? Thx!

@gante (Member) commented on Nov 7, 2023

Hi @jiqing-feng

Running on my end (python benchmark_decoder_open.py facebook/opt-6.7b --aux-model facebook/opt-125m --dtype fp16 from this folder), I got

Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/joao/huggingface-demos/experiments/faster_generation/utils.py", line 99, in run_new_model
    new_outputs = run_prediction_loop(model, tokenizer, args.num_samples, args.temperature, aux_model)
  File "/home/joao/huggingface-demos/experiments/faster_generation/benchmark_decoder_open.py", line 35, in run_prediction_loop
    gen_out = model.generate(
  File "/home/joao/venvs/hf/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/joao/transformers/src/transformers/generation/utils.py", line 1736, in generate
    return self.assisted_decoding(
  File "/home/joao/transformers/src/transformers/generation/utils.py", line 4594, in assisted_decoding
    assistant_attention_mask = torch.cat(
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument tensors in method wrapper_CUDA_cat)

i.e. the newly generated mask values that are appended must be created on the same device as the existing mask :)
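A minimal sketch of the kind of fix needed (illustrative, not the exact PR diff): create the ones that extend the attention mask with the same dtype and device as the existing mask before concatenating.

import torch

attention_mask = torch.ones((1, 8), dtype=torch.long)  # stand-in for the existing mask (may live on CUDA)
num_new_tokens = 3
new_mask = torch.ones(
    (attention_mask.shape[0], num_new_tokens),
    dtype=attention_mask.dtype,
    device=attention_mask.device,  # match the existing mask's device to avoid the torch.cat error above
)
attention_mask = torch.cat([attention_mask, new_mask], dim=-1)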

@jiqing-feng (Contributor, Author) commented on Nov 8, 2023

Hi @gante. Would you please try it again? It should be fixed now. I also tested it on an A100; the results and performance are exactly the same. BTW, the failed test seems unrelated to my changes.

@gante (Member) commented on Nov 8, 2023

@jiqing-feng perfect, all works well on my end.

Two related notes:

  1. The CI is indeed red for external reasons, waiting for this PR to get merged
  2. The diff shows that assistant_accepts_encoder_outputs (a recent addition in assisted generation for distil-whisper, to support assistants with shared encoders) is removed, which means your changes are not built on top of the latest version.

👉 you will need to rebase your changes to fix both issues, but only after the PR linked above gets merged. You may get minor rebase issues due to 2., but they should be trivial to fix

After that is done, I'll tag a core maintainer for a final quick check :)

@jiqing-feng (Contributor, Author) commented:

> 2. assistant_accepts_encoder_outputs

Hi @gante. I removed assistant_accepts_encoder_outputs because it is no longer needed with my changes: all inputs are now generated by assistant_model.prepare_inputs_for_generation.
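As an aside, a small runnable sketch (the tiny model name is an assumption for illustration) of why a separate flag is redundant here: for an encoder-decoder model, prepare_inputs_for_generation already threads encoder_outputs through to the model inputs.

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

checkpoint = "hf-internal-testing/tiny-random-t5"  # tiny test model, used only for illustration
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

enc = tokenizer("translate English to German: hello", return_tensors="pt")
encoder_outputs = model.get_encoder()(**enc)
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

prepared = model.prepare_inputs_for_generation(
    decoder_input_ids, encoder_outputs=encoder_outputs, attention_mask=enc["attention_mask"]
)
print("encoder_outputs" in prepared)  # True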

@gante (Member) commented on Nov 8, 2023

> Hi @gante. I removed assistant_accepts_encoder_outputs because it is no longer needed with my changes: all inputs are now generated by assistant_model.prepare_inputs_for_generation.

🤦 my apologies, you're absolutely right.

In that case, rebasing to get the CI green is all you need to do. Tagging a core maintainer for a quick final check :)

@gante (Member) left a review:


Good to go, thank you for iterating with me 💛

(Note: results also validated on my end, no slowdown nor generative performance drop)

@jiqing-feng (Contributor, Author) commented on Nov 9, 2023

> [quoted @gante's earlier notes about the red CI and rebasing]

Hi @gante. I see the PR you mentioned has been merged and my PR is already up to date, but some of the CI checks are still red.

@amyeroberts (Collaborator) commented:

@jiqing-feng There were some unexpected failures because of new package releases - thankfully not related to this PR! They should now be resolved on main - rebasing should fix them here.

> @jiqing-feng: Hi @amyeroberts. #26892 (comment) can be used to test the outputs before and after my changes. I guess you mean adding a test to the test files so it runs in the CI; if so, would you please tell me which file should be modified to add this test? Thx

Yes, I meant to add a test to the CI runs. It looks like it should be tested in tests/generation/test_utils.py - but I'll let @gante confirm

@gante (Member) commented on Nov 9, 2023

(woops, wrong button)

@gante gante closed this Nov 9, 2023
@gante gante reopened this Nov 9, 2023
@gante (Member) commented on Nov 9, 2023

@amyeroberts not sure if we can test this feature reliably: there are no output differences, since assisted generation always outputs what the main model dictates and this PR only modifies the assistant model's inputs to be more aligned with the main model's.

What we should see on average is a higher speedup with masked inputs, as the assistant model will receive the same inputs as the main model and thus has a higher chance of matching it, but that is far from guaranteed for all calls. A speed test would be very flaky 🤔
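A sketch of the property being described (the tiny model name is an assumption for illustration): with greedy decoding, assisted generation returns exactly what the main model would generate on its own, so an output-comparison test would not exercise this change.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "sshleifer/tiny-gpt2"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)
assistant = AutoModelForCausalLM.from_pretrained(checkpoint)  # the same tiny model acting as its own assistant

inputs = tokenizer("Speculative decoding is", return_tensors="pt")
plain = model.generate(**inputs, do_sample=False, max_new_tokens=16)
assisted = model.generate(**inputs, do_sample=False, max_new_tokens=16, assistant_model=assistant)
print(torch.equal(plain, assisted))  # expected: True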

@amyeroberts (Collaborator) commented:

@gante I understand - I wasn't clear enough before. Really all I was looking for is to make sure that this can be safely used with different assistant models, i.e. can I pass in a decoder-only model? How about an encoder-decoder? So not speed or values, just the API.

@gante (Member) commented on Nov 9, 2023

@amyeroberts we do have Mixin tests (e.g.), so any issue regarding API should have been caught there :)

@amyeroberts (Collaborator) commented:

@gante Sweet - in that case it's all good 👍 Re the failing tests - there are some PRs due to be merged which should (hopefully, this time) resolve the issues we've been having.

@amyeroberts (Collaborator) left a review:


Thanks again for adding!

@jiqing-feng (Contributor, Author) commented:

Hi @gante @amyeroberts. All CI checks are green. I think it is time to merge :)

@gante gante merged commit 184f60d into huggingface:main Nov 10, 2023
2 checks passed
@gante (Member) commented on Nov 10, 2023

@jiqing-feng thank you for iterating with us and making transformers better 💛

@VsonicV (Contributor) commented on Nov 14, 2023

@amyeroberts @jiqing-feng There are currently some unexpected CI failures caused by test_assisted_decoding_sample (see #27351 and #27450 ). Are they related to this recently merged PR? I can see from the testing log that this PR did not run those tests involving test_assisted_decoding_sample during CI checking. Thanks!

@jiqing-feng jiqing-feng deleted the assist branch November 15, 2023 01:19
@jiqing-feng (Contributor, Author) commented on Nov 15, 2023

> [quoted @VsonicV's comment above]

Hi, @VsonicV. Sorry for the failed CI. It is strange: I can successfully run pytest in my local repo (which is updated to origin/main). I see that your CI failed on blenderbot and pegasus, but those tests pass when I run pytest locally. Would you please update your repo and rerun the CI? Thx!
If that still doesn't solve the problem, revert my changes to check whether this PR caused the error.

@VsonicV (Contributor) commented on Nov 15, 2023

Hi, @jiqing-feng, thanks for the quick check. Exactly the same thing happened for me: I can run pytest tests/models/blenderbot/test_modeling_blenderbot.py, etc., successfully in my local up-to-date repo, but it fails in the CI checks. Moreover, these CI failures happen not only for blenderbot and pegasus; they also happened for umt5 (in one of my previous CI runs), and for switch_transformers and t5 in another recent PR (see #27450). I asked here because this is the only recent PR that seems related to test_assisted_decoding_sample, but maybe the problem is somewhere else. Thanks for the help anyway!

@jiqing-feng (Contributor, Author) commented on Nov 15, 2023

> [quoted @VsonicV's comment above]

I submitted a new PR, and all CI passed. Would you apply my PR and see whether the CI is OK?

Furthermore, it is worth trying to update your branch by merging origin/main and pushing the update to rerun the CI.

@VsonicV (Contributor) commented on Nov 15, 2023

@jiqing-feng Hi, thanks for this prompt fix! I will rebase my PR and re-do the CI checks after your new PR is merged. Fingers crossed!

@patrickvonplaten (Contributor) commented:

This PR broke speculative decoding for Whisper; can we maybe revert it for now?

@patrickvonplaten added a commit that referenced this pull request on Nov 16, 2023:
* Revert "add attention_mask and position_ids in assisted model (#26892)"

This reverts commit 184f60d.

* more debug
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request on Nov 19, 2023:

* add attention_mask and position_ids in assisted model

* fix bug

* fix attention mask

* fix attention_mask

* check assist inputs

* check assist input ids length

* fix assist model type

* set assist attention mask device
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 19, 2023
…ngface#27523)

* Revert "add attention_mask and position_ids in assisted model (huggingface#26892)"

This reverts commit 184f60d.

* more debug