[Add Mamba] Adds support for the Mamba models #28094
Conversation
Oops! Still planned, but KVCache will come first.
Alright, I am picking this back up!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Hey, it's great to see that Mamba is being integrated into Transformers! Just wondering, is there a timeline or ETA for this PR? Thanks so much.
I want to merge it ASAP, so probably by the end of next week at the latest!
This draft PR is a work-in-progress implementation of the Mamba model. This PR currently loads weights and produces correct logits after a single pass. This PR still needs to correctly integrate this model so it produces tokens as expected, and apply optimizations to avoid unnecessary copies and operations during runtime.

#### Helpful resources

- [Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Albert Gu and Tri Dao)](https://arxiv.org/abs/2312.00752)
- https://github.com/johnma2006/mamba-minimal
- https://github.com/huggingface/candle/blob/main/candle-examples/examples/mamba-minimal/model.rs
- huggingface/transformers#28094

Notes: this dev work currently targets `state-spaces/mamba-130m`, so if you want to test, please use that model. Additionally, when starting the router, the prefill needs to be limited: `cargo run -- --max-batch-prefill-tokens 768 --max-input-length 768`

## Update / Current State

Integration tests have been added, and basic functionality such as model loading is supported.

```bash
cd integration-tests
pytest -vv models/test_fused_kernel_mamba.py
```

- [x] add tests
- [x] load model
- [x] make simple request
- [ ] resolve warmup issue
- [ ] resolve output issues

Fetching the models tested during dev:

```bash
text-generation-server download-weights state-spaces/mamba-130m
text-generation-server download-weights state-spaces/mamba-1.4b
text-generation-server download-weights state-spaces/mamba-2.8b
```

The server can be run:

```bash
cd server
MASTER_ADDR=127.0.0.1 MASTER_PORT=5555 python text_generation_server/cli.py serve state-spaces/mamba-2.8b
```

Router:

```bash
cargo run
```

Make a request:

```bash
curl -s localhost:3000/generate \
  -X POST \
  -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
  -H 'Content-Type: application/json' | jq
```

Response:

```json
{
  "generated_text": "\n\nDeep learning is a machine learning technique that uses a deep neural network to learn from data."
}
```

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Got sidetracked; done with caching issues!
Done! 🤗
This looks good to me. Please also add the example from the PR description somewhere in the documentation; the current examples don't really show how to use the model, imo.
Co-authored-by: Lysandre Debut <hi@lysand.re>
Force-pushed ee6a9c2 to f963e38
@ArthurZucker Thank you for this amazing addition. Are there any plans to add something equivalent to
Not sure why you would need it?
There is no notion of causal mask or masking in
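To illustrate the point about masking: a state-space model consumes the sequence as a left-to-right recurrence, so the output at step t can only depend on inputs up to t, and causality is built in rather than enforced by a mask. A toy sketch of that property (a plain linear SSM in PyTorch with made-up shapes; not the input-dependent selective scan or the fused kernel):

```python
import torch

def toy_ssm_scan(x, A, B, C):
    """Toy linear state-space scan: y[t] depends only on x[: t + 1]."""
    # x: (seq_len, d_in), A: (d_state, d_state), B: (d_state, d_in), C: (d_out, d_state)
    h = torch.zeros(A.shape[0])
    ys = []
    for t in range(x.shape[0]):
        h = A @ h + B @ x[t]  # state update sees only current and past inputs
        ys.append(C @ h)      # output at step t; future tokens cannot leak in
    return torch.stack(ys)

y = toy_ssm_scan(torch.randn(8, 4), 0.9 * torch.eye(16), torch.randn(16, 4), torch.randn(2, 16))
print(y.shape)  # torch.Size([8, 2])
```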
Hi. There is a problem in the Trainer where the
Full disclosure, I encountered this "bug" in my own @ArthurZucker tagging you :)
Feel free to open a PR for the fix! 🤗
What does this PR do?
Feel free to try this:
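A minimal generation sketch of the kind of usage this PR enables (the `state-spaces/mamba-130m-hf` checkpoint name is an assumption; substitute whichever converted Mamba checkpoint you have on hand):

```python
from transformers import AutoTokenizer, MambaForCausalLM

model_id = "state-spaces/mamba-130m-hf"  # assumed converted checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = MambaForCausalLM.from_pretrained(model_id)

input_ids = tokenizer("What is Deep Learning?", return_tensors="pt")["input_ids"]
out = model.generate(input_ids, max_new_tokens=20)
print(tokenizer.batch_decode(out)[0])
```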
Peft training that works, thanks @younesbelkada : Results: https://huggingface.co/ArthurZ/mamba-2.4b-english-quotes
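A sketch of what such a PEFT run could look like. Everything below is an assumption for illustration, not the exact recipe behind the linked checkpoint: the 130m model id (the linked run fine-tuned a 2.4b model), the LoRA target module names (taken from the Mamba mixer's linear projections in the modeling code), and the hyperparameters.

```python
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    MambaForCausalLM,
    Trainer,
    TrainingArguments,
)

model_id = "state-spaces/mamba-130m-hf"  # assumed checkpoint for a quick test
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # the tokenizer ships without a pad token
model = MambaForCausalLM.from_pretrained(model_id)

# LoRA on Mamba's linear projections; module names are an assumption.
model = get_peft_model(model, LoraConfig(
    r=8,
    target_modules=["in_proj", "x_proj", "out_proj"],
    task_type="CAUSAL_LM",
    bias="none",
))

dataset = load_dataset("Abirate/english_quotes", split="train")
dataset = dataset.map(
    lambda batch: tokenizer(batch["quote"], truncation=True, max_length=128),
    batched=True,
    remove_columns=dataset.column_names,
)

trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir="./mamba-english-quotes",
        per_device_train_batch_size=4,
        num_train_epochs=1,
        learning_rate=2e-3,
        logging_steps=10,
    ),
    train_dataset=dataset,
    # mlm=False makes the collator copy input_ids into labels for causal LM loss
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
```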
Training loss curves (image omitted; original link expired): pink: 360m, full fine-tune; blue: 2.8b, PEFT; red: 2.8b, PEFT.
Fixes #28086