
[Add Mamba] Adds support for the Mamba models #28094

Merged (123 commits, Mar 5, 2024)
Conversation

@ArthurZucker (Collaborator) commented Dec 16, 2023

What does this PR do?

  • Implement CPU ops
  • Add integration tests
  • Implement the fast path
  • Check training + PEFT
  • Convert all checkpoints: just need to make sure the config is correct

Feel free to try this:

from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("ArthurZ/mamba-130m")
tokenizer.pad_token = tokenizer.eos_token

model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-130m", vocab_size=50280, num_hidden_layers=24, torch_dtype=torch.float32)
model.config.use_cache = True
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"]

out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))

PEFT training that works, thanks @younesbelkada. Results: https://huggingface.co/ArthurZ/mamba-2.4b-english-quotes

from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
model_id = "ArthurZ/mamba-2.8b"
tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token="<s>")
model = AutoModelForCausalLM.from_pretrained(model_id)
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    logging_dir='./logs',
    logging_steps=10,
    learning_rate=2e-3
)
lora_config = LoraConfig(
        r=8,
        target_modules="all-linear",
        task_type="CAUSAL_LM",
        bias="none"
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    dataset_text_field="quote",
)
trainer.train()

[image: training-loss curves] pink: 360m full fine-tune; blue: 2.8b PEFT; red: 2.8b PEFT

fixes #28086

@ArthurZucker ArthurZucker linked an issue Dec 16, 2023 that may be closed by this pull request
@huggingface huggingface deleted a comment from github-actions bot Jan 16, 2024
@ArthurZucker (Collaborator, Author)

Oops! Still planned, but the KVCache will come first.

@ArthurZucker (Collaborator, Author)

Alright I am picking this back up!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@apoorvkh (Contributor) commented Feb 1, 2024

Hey, it's great to see that Mamba is being integrated into Transformers! Just wondering, is there a timeline or ETA for this PR? Thanks so much.

@ArthurZucker (Collaborator, Author)

I want to merge it ASAP, so by the end of next week at the latest!

Narsil added a commit to huggingface/text-generation-inference that referenced this pull request Feb 8, 2024
This draft PR is a work in progress implementation of the mamba model.
This PR currently loads weights, and produces correct logits after a
single pass.

This PR still needs to integrate the model so that it produces
tokens as expected, and to apply optimizations that avoid unnecessary
copies and operations at runtime.

#### Helpful resources
[Mamba: Linear-Time Sequence Modeling with Selective State Spaces
(Albert Gu and Tri Dao)](https://arxiv.org/abs/2312.00752)
https://github.com/johnma2006/mamba-minimal

https://github.com/huggingface/candle/blob/main/candle-examples/examples/mamba-minimal/model.rs
huggingface/transformers#28094

Notes: this dev work is currently targeting `state-spaces/mamba-130m`,
so if you want to test please use that model. Additionally when starting
the router the prefill needs to be limited: `cargo run --
--max-batch-prefill-tokens 768 --max-input-length 768`


## Update / Current State

Integration tests have been added and basic functionality such as model
loading is supported.

```bash
cd integration-tests
pytest -vv models/test_fused_kernel_mamba.py
```
- [x] add tests
- [x] load model
- [x] make simple request 
- [ ] resolve warmup issue
- [ ] resolve output issues


fetching models tested during dev
```bash
text-generation-server download-weights state-spaces/mamba-130m
text-generation-server download-weights state-spaces/mamba-1.4b
text-generation-server download-weights state-spaces/mamba-2.8b
```

The server can be run with:
```bash
cd server
MASTER_ADDR=127.0.0.1 MASTER_PORT=5555 python text_generation_server/cli.py serve state-spaces/mamba-2.8b
```

router
```bash
cargo run
```

make a request
```bash
curl -s localhost:3000/generate \
    -X POST \
    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
    -H 'Content-Type: application/json' | jq
```

response
```json
{
  "generated_text": "\n\nDeep learning is a machine learning technique that uses a deep neural network to learn from data."
}
```

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
@ArthurZucker (Collaborator, Author)

Got sidetracked, but done with the caching issues!
I was weighing the stateful vs. stateless approach we want to take to support torch.compile and graphs without extra complexity, similarly to #27931.
It was advised that for Mamba the cache should work in a stateless manner.
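A toy contrast between the two designs (hypothetical plain Python, not the transformers API): in the stateful design the module mutates hidden state between calls, while in the stateless design the caller owns the state and threads it through explicitly, which is generally friendlier to tracing and compilation.

```python
# Stateful: the module mutates its own attribute between calls,
# so the "history" is hidden from the caller (and from a tracer).
class StatefulModel:
    def __init__(self):
        self.state = 0

    def step(self, x):
        self.state += x
        return self.state

# Stateless: the state is an explicit input and output of each call.
def stateless_step(x, state):
    new_state = state + x
    return new_state, new_state

# Both produce the same running results.
m = StatefulModel()
assert [m.step(i) for i in (1, 2, 3)] == [1, 3, 6]

state, outs = 0, []
for i in (1, 2, 3):
    out, state = stateless_step(i, state)
    outs.append(out)
assert outs == [1, 3, 6]
```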

@ArthurZucker ArthurZucker requested a review from LysandreJik March 5, 2024 01:38
@ArthurZucker (Collaborator, Author)

Done! 🤗

@LysandreJik (Member) left a comment

This looks good to me. Please also add the example from the PR description somewhere in the documentation; the current examples don't really show how to use the model, imo.

docs/source/en/model_doc/mamba.md
Co-authored-by: Lysandre Debut <hi@lysand.re>
@ArthurZucker ArthurZucker force-pushed the add-mamba branch 2 times, most recently from ee6a9c2 to f963e38 Compare March 5, 2024 10:38
@abdulfatir

@ArthurZucker Thank you for this amazing addition. Are there any plans to add something equivalent to attention_mask for Mamba?

@ArthurZucker (Collaborator, Author)

Not sure; why would you need it?

@abdulfatir

  • For batched inference with inputs of different lengths.
  • For pretraining with masking schemes other than a causal mask.
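The batched-inference point can be illustrated with a toy recurrence (hypothetical; not Mamba's actual state update): in a recurrent model, right-padding tokens keep updating the hidden state on every step, so sequences of different lengths end up with corrupted final states unless the padding positions are masked out.

```python
# Toy recurrent state update: each step mixes the previous state with
# the current token. The `mask` plays the role an attention_mask would.
def final_state(tokens, mask):
    state = 0.0
    for tok, keep in zip(tokens, mask):
        if keep:  # skip padding positions
            state = 0.5 * state + tok
    return state

seq = [1.0, 2.0, 3.0]
padded = seq + [0.0, 0.0]  # right-padded to a batch-wide length

# With masking, the pads do not change the final state.
assert final_state(padded, [1, 1, 1, 0, 0]) == final_state(seq, [1, 1, 1])

# Without masking, each pad step keeps decaying the state.
assert final_state(padded, [1, 1, 1, 1, 1]) != final_state(seq, [1, 1, 1])
```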

@ArthurZucker (Collaborator, Author)

There is no notion of a causal mask or masking in Mamba, as it is not based on attention. That's why I am not sure I follow.

@lkurlandski

Hi.

There is a problem in the Trainer: the logits returned by Trainer.prediction_step will be a tuple[Tensor, MambaCache] object. This causes a host of issues when accelerate tries to move the logits to the same device, change data types, etc. The solution is to add "cache_params" to the keys_to_ignore_at_inference field of the associated Config class. The change is simple:

```python
class MambaConfig:
    keys_to_ignore_at_inference = ["cache_params"]
```

Full disclosure: I encountered this "bug" in my own MambaForSequenceClassification class, not in a module from transformers itself, and I have not tested thoroughly whether it is present in the classes from transformers.
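As a simplified illustration of why the config flag helps (plain Python, not the actual Trainer code): keys listed in keys_to_ignore_at_inference are dropped from the model outputs before the remainder is treated as logits, so the cache object is never handed to accelerate's tensor-moving logic.

```python
# Simplified sketch of the output-key filtering; names here are
# stand-ins, the real logic lives inside Trainer.prediction_step.
model_outputs = {
    "logits": [0.1, 0.9],      # stand-in for a Tensor
    "cache_params": object(),  # stand-in for a MambaCache
}
keys_to_ignore_at_inference = ["cache_params"]

logits = tuple(
    v for k, v in model_outputs.items()
    if k not in keys_to_ignore_at_inference
)

# Only the logits survive; the cache never reaches accelerate.
assert logits == ([0.1, 0.9],)
```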

@ArthurZucker tagging you :)

@ArthurZucker (Collaborator, Author)

Feel free to open a PR for the fix! 🤗

@ArthurZucker (Collaborator, Author)

Also, use_cache=False should prevent this as well, no?
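A minimal sketch of that alternative (hypothetical code mirroring only the output structure, not the real model): with use_cache=False the forward pass returns no cache entry at all, so there is nothing for accelerate to trip over in the first place.

```python
# Hypothetical forward pass that mirrors only the output shape.
def forward(input_ids, use_cache=True):
    outputs = {"logits": [0.2, 0.8]}  # stand-in for a Tensor
    if use_cache:
        outputs["cache_params"] = object()  # stand-in for a MambaCache
    return tuple(outputs.values())

# With the cache disabled, the output is just (logits,).
assert len(forward([1, 2, 3], use_cache=False)) == 1
assert len(forward([1, 2, 3], use_cache=True)) == 2
```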

cr313 added a commit to cr313/text-generation-inference-load-test that referenced this pull request Apr 19, 2024
kdamaszk pushed a commit to kdamaszk/tgi-gaudi that referenced this pull request Apr 29, 2024
alfredgui2 pushed a commit to mlsys-io/kv.run that referenced this pull request Jul 6, 2024
Successfully merging this pull request may close these issues.

Add [Mamba] model
7 participants