
[Mamba2] Move dt calculations to kernel #33520

Merged · 3 commits into huggingface:main · Sep 19, 2024

Conversation

@vasqu (Contributor) commented Sep 16, 2024

What does this PR do?

Moves the time_step (dt) calculation into the kernel instead of applying torch's softplus beforehand. It seems to influence the end result even more than expected; see fla-org/flash-linear-attention#63 for reference (the second point made there).
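
For illustration, a minimal sketch of the two code paths, assuming mamba_ssm's `selective_state_update` and its `dt_bias`/`dt_softplus` arguments; the shapes and the commented-out kernel calls are illustrative placeholders, not the exact modeling code:

```python
import torch
import torch.nn.functional as F

batch_size, num_heads = 2, 4
dt = torch.randn(batch_size, num_heads)  # raw time-step projection (illustrative shape)
dt_bias = torch.randn(num_heads)

# Before: bias + softplus applied in PyTorch, outside the kernel.
dt_host = F.softplus(dt + dt_bias)
# selective_state_update(ssm_state, hidden, dt_host, A, B, C, D=D, z=z)

# This PR: the raw dt is handed to the kernel, which applies bias + softplus
# internally, minimizing numerical drift w.r.t. the fully fused path.
# selective_state_update(ssm_state, hidden, dt, A, B, C, D=D, z=z,
#                        dt_bias=dt_bias, dt_softplus=True)
```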

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker @molbap

@vasqu (Contributor, Author) commented Sep 17, 2024

For clarification: the referenced repo's mamba2 implementation is mostly a mirror of the one here.

@molbap (Contributor) commented Sep 17, 2024

Thanks for the PR! Very interesting; I'm not opposed, but what exactly changes between the in-kernel and out-of-kernel computation results? Do train mode and eval mode yield different results? In that case we should also add a minimal test specific to mamba2!

@vasqu (Contributor, Author) commented Sep 17, 2024

It just seems to be more aligned with the completely fused kernel then, or rather: the numerical differences are minimized. Tbh it's a huge nit in and of itself :D

I can add a slow test that requires a GPU and the kernels; would be a copy paste of the test in the repo.
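
A rough sketch of what such a slow test could look like, using the decorators from transformers.testing_utils; the checkpoint, test name, and assertion are placeholders, not the test that was eventually merged:

```python
import torch

from transformers import AutoTokenizer, Mamba2ForCausalLM
from transformers.testing_utils import require_torch_gpu, slow


@slow
@require_torch_gpu
def test_mamba2_generation_with_kernels():
    # Placeholder checkpoint; the real test would pin a specific mamba2 model.
    model_id = "mistralai/Mamba-Codestral-7B-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = Mamba2ForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda")

    inputs = tokenizer("Hello my name is", return_tensors="pt").to("cuda")
    with torch.no_grad():
        # Prefill runs the fused scan; each decoding step hits selective_state_update,
        # so plain greedy generation exercises the changed dt path.
        output = model.generate(**inputs, max_new_tokens=10, do_sample=False)

    assert output.shape[-1] == inputs["input_ids"].shape[-1] + 10
```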

@molbap (Contributor) left a comment

Nice, just needs the copy-pasted test indeed (credit to the other repo) for future code inspectors!

@molbap requested a review from @amyeroberts · September 17, 2024 11:19
@amyeroberts (Collaborator) left a comment

Nice! Thanks for updating.

Final step is running the slow tests for the model before merge, as this might have affected the integration tests. Could you push an empty commit with the message [run_slow] mamba?
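
(For anyone finding this later: such an empty commit is simply

```bash
git commit --allow-empty -m "[run_slow] mamba"
git push
```

which triggers the slow test job for the named model.)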

@vasqu (Contributor, Author) commented Sep 17, 2024

I'll write the test tomorrow and ping you then; I don't have much time this evening.

@vasqu (Contributor, Author) commented Sep 18, 2024

This would be ready, but I guess waiting for #33560 to be merged seems appropriate for now 👀
cc @amyeroberts

@vasqu (Contributor, Author) commented Sep 18, 2024

Does it need a new decorator to check whether the kernels are available, or is that directly included in the slow runs? We might need an updated CI image if that's not the case. @molbap

@molbap (Contributor) commented Sep 19, 2024

Nope, there's no such decorator! Slow run ought to do it :)
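
(For future reference, if such a gate were ever wanted, a plain unittest skip over a kernel-availability check would do it; `is_mamba_2_ssm_available` is assumed here to be the import check used by the Mamba2 modeling code:)

```python
import unittest

from transformers.utils.import_utils import is_mamba_2_ssm_available


class Mamba2KernelGatedTest(unittest.TestCase):
    @unittest.skipUnless(is_mamba_2_ssm_available(), "mamba_ssm kernels are not installed")
    def test_requires_kernels(self):
        # Placeholder body; anything touching the Triton kernels would go here.
        ...
```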

@vasqu force-pushed the mamba2-fast-inference-path-fix branch from e3cfa8a to ff0cb7c · September 19, 2024 15:36
@vasqu (Contributor, Author) commented Sep 19, 2024

Good to go now :) #33567 is to be expected
cc @amyeroberts (sry for the constant pinging)


@amyeroberts (Collaborator) commented

@vasqu Pings are always welcome - it makes sure things aren't lost in the notification inbox :)

@amyeroberts merged commit b50ff59 into huggingface:main · Sep 19, 2024 (16 of 17 checks passed)
@vasqu deleted the mamba2-fast-inference-path-fix branch · September 19, 2024 16:44
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request · Dec 5, 2024
* use kernel for dt calculations

* add small test

* [run-slow] mamba2