
Improve greedy search memory usage #32895

Merged

gante merged 1 commit into huggingface:main from the enhance_greedy_search_memory branch on Aug 22, 2024

Conversation

@regisss (Contributor) commented Aug 20, 2024

What does this PR do?

When doing greedy search, inputs go through _expand_inputs_for_generation, where they are expanded with torch.repeat_interleave. Since the expand size is always 1 in the case of greedy search, torch.repeat_interleave does not modify the inputs. However, it does increase memory usage, because the input to torch.repeat_interleave is cloned.

Here is a code snippet to check this behaviour:

import torch

# 1000 * 1000 * 1000 float32 elements * 4 bytes = 4 GB on the GPU
a = torch.ones(1000, 1000, 1000, device="cuda")
print(torch.cuda.max_memory_allocated())

# Even though expand_size == 1 leaves the tensor unchanged,
# repeat_interleave still materializes a full copy, so peak memory doubles
expand_size = 1
a = a.repeat_interleave(expand_size, dim=0)
print(torch.cuda.max_memory_allocated())

which prints

4000000000
8000000000

As the output shows, peak allocated memory doubles from 4 GB to 8 GB: the tensor is cloned even though the result is identical to the input. Thus, if the expand size is 1, we can return the model inputs before calling torch.repeat_interleave. That's the change introduced in this PR.
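
For illustration, here is a minimal sketch of that early-return guard. This is a simplification, not the exact diff: the real _expand_inputs_for_generation in transformers also handles encoder-decoder models and nested inputs, and the function name and signature below are simplified.

import torch

def expand_inputs_for_generation(input_ids, expand_size=1, **model_kwargs):
    # Early return: with expand_size == 1, repeat_interleave would return
    # an identical tensor while still allocating a full copy of its input.
    if expand_size == 1:
        return input_ids, model_kwargs

    # Otherwise, expand input_ids and any tensor kwargs along the batch dim.
    input_ids = input_ids.repeat_interleave(expand_size, dim=0)
    for key, value in model_kwargs.items():
        if isinstance(value, torch.Tensor):
            model_kwargs[key] = value.repeat_interleave(expand_size, dim=0)
    return input_ids, model_kwargs

With such a guard in place, greedy search (expand size 1) never reaches torch.repeat_interleave, so peak memory stays at the size of the original inputs instead of doubling.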

More context in this Slack thread: https://huggingface.slack.com/archives/C01N44FJDHT/p1723827436938589

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@amyeroberts (Collaborator) commented:

cc @gante @zucchini-nlp

@regisss (Contributor, Author) commented Aug 20, 2024

CI failed, but it doesn't seem to be related to this PR.

@regisss regisss marked this pull request as ready for review August 20, 2024 09:45
@regisss regisss requested a review from gante August 20, 2024 09:45
@zucchini-nlp (Member) left a comment

Wow, interesting finding! Thanks for handling

@gante (Member) left a comment

Thank you for the fix 🙏

(CI is failing for reasons known to us; will take care of rebasing and merging when the root cause is fixed. cc @amyeroberts)

@amyeroberts (Collaborator) left a comment

Thanks for adding this improvement!

@gante gante force-pushed the enhance_greedy_search_memory branch from c0169a1 to 77d8384 Compare August 22, 2024 12:08
@gante gante force-pushed the enhance_greedy_search_memory branch from 77d8384 to 23f74a1 Compare August 22, 2024 13:18
@gante gante merged commit 99d67f1 into huggingface:main Aug 22, 2024
21 checks passed
@regisss regisss deleted the enhance_greedy_search_memory branch August 22, 2024 16:37
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
Do not call torch.repeat_interleave if expand_size is 1