Reducing memory usage: removing useless logits computation in generate() #31292
Conversation
Thanks! Looks quite alright IMO!
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
if num_logits_to_keep is None:
    logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
else:
    logits = [
        F.linear(hidden_states[:, -num_logits_to_keep:, :], lm_head_slices[i])
        for i in range(self.config.pretraining_tp)
    ]
let's not update this, pretraining TP is really never used in practice. I'll deprecate it
LGTM, thank you for further reducing the memory needs @Cyrilvallez 💛
num_logits_to_keep is not the prettiest interface, but I can't think of a better one (as discussed in the PR that introduced it).
I'm happy with the PR with the exception of BC handling.
btw, a ratio of 3x lower peak memory consumption is 🔥 🔥 🔥
LGTM 👍
Force-pushed from 8cfce7b to 4236c05.
I just added the change to more models and rebased to avoid conflicts with new commits in main! The last thing to take into account is your comment about the signature @ArthurZucker, but I'm not sure I understood correctly what you wanted to do 🤓
Overall LGTM
if num_logits_to_keep is None:
    logits = self.lm_head(hidden_states).float()
else:
    logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
we can default num_logits_to_keep to 0 to always slice (no separate codepaths)
uhmmm does self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
work with the default value 0? 🤔
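For context, a minimal standalone check of the slicing behaviour under discussion (plain PyTorch; the tensor sizes are arbitrary): since -0 == 0, the slice [:, -0:, :] keeps every position, so a default of 0 means keeping all logits.

import torch

hidden_states = torch.randn(2, 6, 8)  # (batch, sequence_length, hidden_size), sizes are arbitrary

# num_logits_to_keep = 0: since -0 == 0, the slice keeps the full sequence
print(hidden_states[:, -0:, :].shape)  # torch.Size([2, 6, 8])

# num_logits_to_keep = 1: keep only the last position, which is all generate() needs per step
print(hidden_states[:, -1:, :].shape)  # torch.Size([2, 1, 8])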
Make sure to rebase, as the state of the main branch has changed quite a bit!
Will do! However, when playing with …
Force-pushed from dc4a1bb to 9623111.
DO NOT MERGE YET
Force-pushed from 9855f62 to f4da824.
@ArthurZucker @gante everything is now ready.
@ArthurZucker is this planned for review this week? I'm pretty eager to start using this PR.
Yes! Reviewing asap!
Looking forward to testing this out, Gemma2 uses a lot of memory otherwise and is a top model.
Well, LGTM! One thing missing is a test in the mixins!
logits = self.lm_head(hidden_states)
if labels is None and not is_torchdynamo_compiling():
    logger.warning_once(
        "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
"Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" | |
"Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" |
Maybe let's comment that we need .float() for full precision softmax
Hey @Cyrilvallez, thanks for your work. Just checking in regarding this PR. Do you have a plan to finish it up some time soon? I'm very excited for it to land!
Hi @ringohoffman, don't worry, I am not forgetting about this 😉 I'm currently on vacation, so I will try to wrap it up quickly at the end of August when I come back, if I have time. Worst case scenario, it will be ready mid-September. In the meantime, you can install transformers from my fork if you want to already benefit from it (pip install git+https://github.com/Cyrilvallez/transformers@logits-dtype). Or even better, you can clone my fork and rebase it on transformers/main to get all the new stuff + this PR.
Does this PR actually fix Gemma2 or just Gemma?
Gemma2 was not released yet when I started this, but don't worry, I will add it as well; it's on the roadmap 🤗
Force-pushed from 1141198 to e34d512.
@ArthurZucker I added support for Gemma2 as well as tests, ready for last review 🤗
Force-pushed from 1b62c0e to 7b1a26c.
No worries! All good on the CIs and ready to be merged 🤗
Congrats, @Cyrilvallez! When is this planned to be released? @ArthurZucker @gante
@ringohoffman our rule of thumb is to release every month, so it should be in ~2 weeks 🤗
Some misses from huggingface#31292 and huggingface#33902:
* Only cast logits to float when computing loss
* Move logits.float() into existing if labels is not None branch
Reducing memory usage: removing useless logits computation in generate() (huggingface#31292)
* Add .float() in all generation methods logit outputs
* Switch float-casting of logits to training only for main models
* Add `num_logits_to_keep` in Llama and add it by default in generate
* Apply style
* Add num_logits_to_keep as arg in prepare_input_for_generation
* Add support for Mistral
* Revert models except llama and mistral
* Fix default None value in _supports_num_logits_to_keep()
* Fix dimension of dummy input
* Add exception for prophetnet in _supports_num_logits_to_keep()
* Update _supports_num_logits_to_keep() to use inspect.signature()
* Add deprecation cycle + remove modification with pretraining_tp
* Apply style
* Add most used models
* Apply style
* Make `num_logits_to_keep` an int in all cases to remove if-else clause
* Add compile check for the warning
* Fix torch versions
* style
* Add gemma2
* Update warning version
* Add comment about .float operations in generation utils
* Add tests in GenerationTesterMixin and ModelTesterMixin
* Fix batch size for assisted decoding in tests
* fix small issues in test
* refacor test
* fix slicing removing dim issue
* Add nemotron support (should fix check-copy issue in CIs)
* Trigger new CIs
* Trigger new CIs
* Bump version
* Bump version in TODO
* Trigger CIs
* remove blank space
* Trigger CIs
What does this PR do?
This is the PR related to the discussion in #30860.
I followed what has been done in Jamba and added support for the num_logits_to_keep argument in forward(). However, even if this argument is None, the logits will only be upcast to float if labels are passed (in order to accurately compute the loss). Otherwise, the upcasting only happens in the generate() functions.
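As a rough sketch of the intended call pattern (hypothetical usage; the checkpoint name is an assumption, num_logits_to_keep is the argument added by this PR): passing num_logits_to_keep=1 makes forward() return logits only for the last position, which is all generate() needs at each decoding step.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)

inputs = tokenizer("The capital of France is", return_tensors="pt")
with torch.no_grad():
    # Only materialize the logits of the last position instead of the full (batch, seq_len, vocab) tensor
    out = model(**inputs, num_logits_to_keep=1)

print(out.logits.shape)  # (batch_size, 1, vocab_size)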
For now, I only modified Llama and Mistral, but if you agree on the changes I will add support for more models.
Benchmarks
Here I provide some benchmarks of the peak memory usage. For each input size, I generated 10 additional tokens.
Of course, since for a few additional tokens the memory peak is dominated by the first forward pass (at least when computing the whole logits matrix), and since the first forward scales linearly with input size and batch size (with the new attention algorithms), the gain is actually constant across all input sizes and generation methods (except for contrastive search, which artificially increases the batch size after the first forward, so its memory usage is slightly different). However, I still provide results for all generation methods here for completeness.
Basically we get:
Llama3 8B -> a mind-blowing 3.62x memory usage reduction factor (due to the large vocabulary)
Llama2 7B -> 1.17x reduction factor
Mistral 7B -> 1.32x reduction factor
Note that the memory reduction shown here is on top of whatever gains #30536 already provides for small numbers of new tokens, as I am comparing memory against the main transformers branch after that PR was merged. The two integrate very nicely: that PR provides most of its benefits when generating many tokens, while this one provides gains for small numbers of new tokens.
greedy.pdf
sample.pdf
beam sample.pdf
beam search.pdf
group beam search.pdf
contrastive search.pdf
Here is a link to the benchmark script: https://gist.github.com/Cyrilvallez/92f48e402aa2968c854a8128796f50c3
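For a quick local check without the full gist, a minimal peak-memory measurement could look like the sketch below (assumptions: a CUDA device and an illustrative checkpoint; this is not the linked benchmark script).

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda")

prompt = "Hello " * 2000  # long prompt so the prefill dominates the memory peak
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    model.generate(**inputs, max_new_tokens=10, do_sample=False)
print(f"Peak memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GiB")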
Who can review?
@ArthurZucker @gante Let me know what you think about the proposed changes!