
Reducing memory usage: removing useless logits computation in generate() #31292

Merged: 35 commits merged into huggingface:main from the logits-dtype branch on Aug 23, 2024

Conversation

@Cyrilvallez (Member) commented Jun 6, 2024

What does this PR do?

This is the PR related to the discussion in #30860.
I followed what has been done in Jamba and added support for the num_logits_to_keep argument in forward(). However, even if this argument is None, the logits are only upcast to float if labels are passed (in order to accurately compute the loss). Otherwise, the upcasting only happens in the generate() functions.

For now, I have only modified Llama and Mistral, but if you agree with the changes I will add support for more models.
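
To illustrate the idea, here is a minimal, hypothetical sketch of the slicing and the conditional upcast (this is not the actual transformers modeling code, whose forward() handles many more arguments; names and sizes are made up for the example):

import torch
import torch.nn as nn

class TinyCausalLMHead(nn.Module):
    # Toy LM head showing the num_logits_to_keep slicing (illustrative only).

    def __init__(self, hidden_dim=64, vocab_size=128256):
        super().__init__()
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)

    def forward(self, hidden_states, labels=None, num_logits_to_keep=0):
        # hidden_states: (batch, seq_len, hidden_dim) coming out of the decoder.
        # With num_logits_to_keep=1 (the generate() case), only a
        # (batch, 1, vocab_size) logits tensor is materialized instead of
        # (batch, seq_len, vocab_size); 0 keeps every position.
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
        if labels is not None:
            # Upcast to float32 only when a loss is actually computed.
            logits = logits.float()
        return logits

head = TinyCausalLMHead()
hidden = torch.randn(2, 256, 64)
print(head(hidden, num_logits_to_keep=1).shape)  # torch.Size([2, 1, 128256])
print(head(hidden).shape)                        # torch.Size([2, 256, 128256])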

Benchmarks

Here I provide some benchmarks of the peak memory usage. For each input size, I generated 10 additional tokens.
Of course, since for a small number of additional tokens the memory peak is dominated by the first forward pass (at least when the whole logits matrix is computed), and since that first forward scales linearly with input size and batch size (with the new attention algorithms), the gain factor is effectively constant across input sizes and generation methods (except for contrastive search, which artificially increases the batch size after the first forward, so its memory usage is slightly different). However, I still provide results for all generation methods for completeness.

Basically, we get:
Llama3 8B -> a MIND-BLOWING 3.62x reduction factor in peak memory usage (due to the large vocabulary)
Llama2 7B -> 1.17x reduction factor
Mistral 7B -> 1.32x reduction factor

Note that the memory reduction shown here comes on top of whatever gains #30536 already provides, as I am comparing against the main transformers branch after that PR was merged. The two integrate very nicely: that PR provides most of its benefits when generating many tokens, while this one provides gains when generating only a small number of new tokens.

Attached benchmark plots (one per generation method):
- greedy.pdf
- sample.pdf
- beam sample.pdf
- beam search.pdf
- group beam search.pdf
- contrastive search.pdf

Here is a link to the benchmark script: https://gist.github.com/Cyrilvallez/92f48e402aa2968c854a8128796f50c3
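
Roughly, the measurement boils down to the sketch below (the model name, prompt length, and generation settings are placeholders; see the linked gist for the actual script):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Meta-Llama-3-8B"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# A long prompt so that the first forward pass dominates the memory peak
inputs = tokenizer("Hello " * 2000, return_tensors="pt").to("cuda")

torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    model.generate(**inputs, max_new_tokens=10, do_sample=False)
print(f"Peak memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GiB")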

Who can review?

@ArthurZucker @gante Let me know what you think about the proposed changes!

@amyeroberts (Collaborator)

cc @ArthurZucker @gante

@ArthurZucker (Collaborator) left a comment

Thanks! Looks quite alright IMO!

Comment on lines 1184 to 1191
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
if num_logits_to_keep is None:
    logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
else:
    logits = [
        F.linear(hidden_states[:, -num_logits_to_keep:, :], lm_head_slices[i])
        for i in range(self.config.pretraining_tp)
    ]

Let's not update this; pretraining TP is really never used in practice. I'll deprecate it.


@gante (Member) left a comment

LGTM, thank you for further reducing the memory needs @Cyrilvallez 💛

num_logits_to_keep is not the prettiest interface, but I can't think of a better one (as discussed in the PR that introduced it).

I'm happy with the PR, with the exception of the backward-compatibility (BC) handling.

@gante (Member) commented Jun 18, 2024

btw, a ratio of 3x lower peak memory consumption is 🔥 🔥 🔥

@gante (Member) left a comment

LGTM 👍

@Cyrilvallez (Member, Author) commented Jun 21, 2024

I just added the change to more models and rebased to avoid conflicts with new commits in main!
For Cohere-based models, most notably, I measured a memory reduction factor of 6.68 due to the very large 256k vocabulary size 🚀🔥

The last thing to take into account is your comment about the signature, @ArthurZucker, but I am not sure I understood correctly what you wanted to do 🤓

@ArthurZucker (Collaborator) left a comment

Overall LGTM

Comment on lines 1141 to 1144
if num_logits_to_keep is None:
    logits = self.lm_head(hidden_states).float()
else:
    logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()

We can default num_logits_to_keep to 0 to always slice (no separate code paths).

A Member replied:

Uhmmm, does self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() work with the default value 0? 🤔
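
For reference, -0 is just 0 in Python, so a slice starting at -0 keeps the full sequence dimension; a quick illustration (the tensor shapes are arbitrary):

import torch

hidden_states = torch.randn(2, 5, 16)  # (batch, seq_len, hidden_dim)

num_logits_to_keep = 0
# -0 == 0, so [:, -0:, :] selects every position
assert hidden_states[:, -num_logits_to_keep:, :].shape == (2, 5, 16)

num_logits_to_keep = 1
# keeps only the last position, which is all generate() needs per step
assert hidden_states[:, -num_logits_to_keep:, :].shape == (2, 1, 16)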

@ArthurZucker (Collaborator)

Make sure to rebase, as the state of the main branch has changed quite a bit!

@Cyrilvallez (Member, Author)

Will do! However, when playing with torch.compile, I noticed that adding a logger.warning_once() in the forward breaks the graph with the following error: Unsupported: call_method UserDefinedObjectVariable(Logger) warning_once [ConstantVariable()] {}. This is with the latest PyTorch version (2.3.1), so I will make sure to change that / make it compile-compatible as well.

@Cyrilvallez (Member, Author)

DO NOT MERGE YET
Everything else is good, but I still need to sort out the logger.warning_once / torch.compile issue.

@Cyrilvallez force-pushed the logits-dtype branch 3 times, most recently from 9855f62 to f4da824 on July 17, 2024 at 11:46
@Cyrilvallez (Member, Author)

@ArthurZucker @gante everything is now ready.
From my tests, it seems that compile does not support any print-like functionality at the moment, whether from print, logger, or warnings.
I first wanted to add a logger.warning_once_compile_safe function, which I thought would simplify things and come in handy in the future as well, but couldn't, because it would require importing torch in the logging module, which breaks things.
So I just added a compile check everywhere.
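
For reference, the guard boils down to the pattern below (a simplified sketch; the helper name _warn_default_dtype_change is made up, but is_torchdynamo_compiling and logger.warning_once are the utilities quoted in the review below):

from transformers.utils import is_torchdynamo_compiling, logging

logger = logging.get_logger(__name__)

def _warn_default_dtype_change(labels):
    # Skip the warning while torch.compile is tracing: dynamo cannot trace
    # through logger.warning_once and would otherwise break the graph.
    if labels is None and not is_torchdynamo_compiling():
        logger.warning_once(
            "Starting from v4.45, the `logits` model output will have the same "
            "type as the model (except at train time, where it will always be FP32)"
        )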

@ringohoffman (Contributor)

@ArthurZucker is this planned for review this week? I’m pretty eager to consume this PR.

@ArthurZucker (Collaborator) commented Jul 26, 2024

Yes! Reviewing asap!

@Oxi84 commented Aug 1, 2024

Looking forward to testing this out; Gemma2 otherwise uses a lot of memory, and it is a top model.

@ArthurZucker (Collaborator) left a comment

Well, LGTM! The one thing missing is a test in the mixins.

logits = self.lm_head(hidden_states)
if labels is None and not is_torchdynamo_compiling():
    logger.warning_once(
        "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"

Suggested change:
-        "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
+        "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"


Maybe let's add a comment that we need .float() for a full-precision softmax.

@ringohoffman (Contributor)

Hey @Cyrilvallez, thanks for your work. Just checking in regarding this PR. Do you have a plan to finish it up some time soon? I'm very excited for it to land!

@Cyrilvallez (Member, Author)

Hi @ringohoffman, don't worry, I am not forgetting about this 😉 I'm currently on vacation, so I will try to wrap it up quickly at the end of August when I come back, if I have time. Worst-case scenario, it will be ready mid-September.

In the meantime, you can install transformers from my fork if you want to already benefit from it (pip install git+https://github.com/Cyrilvallez/transformers@logits-dtype). Or even better, you can clone my fork and rebase it onto transformers/main to get all the new stuff + this PR.

@Boubou78000

Does this PR actually fix Gemma2, or just Gemma?

@Cyrilvallez (Member, Author)

Gemma2 had not been released yet when I started this, but don't worry, I will add it as well; it's on the roadmap 🤗

@Cyrilvallez (Member, Author) commented Aug 21, 2024

@ArthurZucker I added support for Gemma2 as well as tests, ready for a last review 🤗
The red CIs are not related to this PR.

@Cyrilvallez (Member, Author)

No worries! All good on the CIs and ready to be merged 🤗

@gante merged commit 22e6f14 into huggingface:main on Aug 23, 2024 (23 checks passed)
@ringohoffman (Contributor)

Congrats, @Cyrilvallez!

When is this planned to be released? @ArthurZucker @gante

@gante (Member) commented Aug 23, 2024

@ringohoffman our rule of thumb is to release every month, so it should be in ~2 weeks 🤗

@ringohoffman mentioned this pull request on Oct 2, 2024
ringohoffman added a commit to ringohoffman/transformers that referenced this pull request Oct 14, 2024
ArthurZucker pushed a commit that referenced this pull request Oct 18, 2024
* Only cast logits to float when computing loss

Some misses from #31292 and #33902

* Move logits.float() into existing if labels is not None branch
stevhliu pushed a commit to stevhliu/transformers that referenced this pull request Oct 21, 2024
* Only cast logits to float when computing loss

Some misses from huggingface#31292 and huggingface#33902

* Move logits.float() into existing if labels is not None branch
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
…e() (huggingface#31292)

* Add .float() in all generation methods logit outputs

* Switch float-casting of logits to training only for main models

* Add `num_logits_to_keep` in Llama and add it by default in generate

* Apply style

* Add num_logits_to_keep as arg in prepare_input_for_generation

* Add support for Mistral

* Revert models except llama and mistral

* Fix default None value in _supports_num_logits_to_keep()

* Fix dimension of dummy input

* Add exception for prophetnet in _supports_num_logits_to_keep()

* Update _supports_num_logits_to_keep() to use inspect.signature()

* Add deprecation cycle + remove modification with pretraining_tp

* Apply style

* Add most used models

* Apply style

* Make `num_logits_to_keep` an int in all cases to remove if-else clause

* Add compile check for the warning

* Fix torch versions

* style

* Add gemma2

* Update warning version

* Add comment about .float operations in generation utils

* Add tests in GenerationTesterMixin and ModelTesterMixin

* Fix batch size for assisted decoding in tests

* fix small issues in test

* refacor test

* fix slicing removing dim issue

* Add nemotron support (should fix check-copy issue in CIs)

* Trigger new CIs

* Trigger new CIs

* Bump version

* Bump version in TODO

* Trigger CIs

* remove blank space

* Trigger CIs
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
* Only cast logits to float when computing loss

Some misses from huggingface#31292 and huggingface#33902

* Move logits.float() into existing if labels is not None branch