Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Zamba2 #34517

Draft
wants to merge 76 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
76 commits
Select commit Hold shift + click to select a range
acd25b7
First commit
pglorio Oct 24, 2024
70639b8
Finish model implementation
pglorio Oct 28, 2024
d111b98
First commit
pglorio Oct 24, 2024
8f36dba
Finish model implementation
pglorio Oct 28, 2024
f0c547c
Merge branch 'zamba2' of https://github.com/Zyphra/transformers_zamba…
pglorio Oct 29, 2024
700fbf0
Register zamba2
pglorio Oct 30, 2024
70a6021
generated modeling and configuration
pglorio Nov 4, 2024
88c4b26
Merge pull request #2 from Zyphra/main
pglorio Nov 5, 2024
685906a
generated modeling and configuration
pglorio Nov 5, 2024
4da8d5f
added hybrid cache
pglorio Nov 5, 2024
6b5a9be
fix attention_mask in mamba
pglorio Nov 5, 2024
248350d
dropped unused loras
pglorio Nov 5, 2024
d1d2c66
fix flash2
pglorio Nov 5, 2024
eb6063e
Merge pull request #3 from Zyphra/main
pglorio Nov 5, 2024
5f5d01e
config docstrings
Nov 6, 2024
c1b7647
fix config and fwd pass
pglorio Nov 7, 2024
979b99b
make fixup fixes
pglorio Nov 7, 2024
9d9b2eb
text_modeling_zamba2
pglorio Nov 9, 2024
3a457f5
Merge pull request #4 from Zyphra/main
pglorio Nov 9, 2024
549d4cb
small fixes
pglorio Nov 9, 2024
987bba9
make fixup fixes
pglorio Nov 11, 2024
ffc2a58
Merge pull request #5 from Zyphra/main
pglorio Nov 11, 2024
9adf85e
Fix modular model converter
pglorio Nov 11, 2024
904da4e
added inheritances in modular, renamed zamba cache
pglorio Nov 19, 2024
4725983
Merge pull request #6 from Zyphra/main
pglorio Nov 19, 2024
0be27d7
modular rebase
pglorio Nov 19, 2024
cc0c549
Rebase
pglorio Nov 19, 2024
ac77a09
new modular conversion
pglorio Nov 20, 2024
e59980e
fix generated modeling file
pglorio Nov 20, 2024
73a647a
fixed import for Zamba2RMSNormGated
pglorio Nov 20, 2024
c2b72a5
modular file cleanup
pglorio Nov 21, 2024
0eb39a5
rebase
pglorio Nov 21, 2024
10a0b1e
make fixup and model tests
pglorio Nov 21, 2024
0270667
dropped inheritance for Zamba2PreTrainedModel
pglorio Nov 23, 2024
189c8c5
make fixup and unit tests
pglorio Nov 23, 2024
fa5f79e
Add inheritance of rope from GemmaRotaryEmbedding
pglorio Dec 5, 2024
8079ae0
moved rope to model init
pglorio Dec 5, 2024
d6206eb
drop del self.self_attn and del self.feed_forward
pglorio Dec 5, 2024
f832699
Rebase onto upstream
pglorio Dec 5, 2024
cf613b7
fix tests
pglorio Dec 5, 2024
337faed
renamed lora -> adapter
pglorio Dec 7, 2024
f1b31a1
rewrote adapter implementation
pglorio Dec 7, 2024
8925c15
rebase
pglorio Dec 7, 2024
11fdd47
fixed tests
pglorio Dec 7, 2024
02dd042
Merge branch 'main' into zamba2
pglorio Dec 18, 2024
5d0a5d4
Fix torch_forward in mamba2 layer
pglorio Dec 19, 2024
ef055c9
Fix torch_forward in mamba2 layer
pglorio Dec 19, 2024
b993a78
Fix torch_forward in mamba2 layer
pglorio Dec 19, 2024
bf93251
Dropped adapter in-place sum
pglorio Dec 19, 2024
99708af
removed rope from attention init
pglorio Dec 19, 2024
d9b4a50
updated rope
pglorio Dec 19, 2024
095d853
created get_layers method
pglorio Dec 19, 2024
10ebad5
rebase
pglorio Dec 20, 2024
99e343e
make fixup fix
pglorio Dec 20, 2024
4e40975
make fixup fixes
pglorio Dec 20, 2024
61bb32f
make fixup fixes
pglorio Dec 20, 2024
bb9b24b
fix merge conflicts
pglorio Jan 7, 2025
cb90bb4
update to new attention standard
pglorio Jan 13, 2025
8ed701e
fixes for merge
pglorio Jan 13, 2025
1dbc8c7
update to new attention standard
pglorio Jan 13, 2025
f24e452
make fixup fixes
pglorio Jan 13, 2025
676f862
rebase
pglorio Jan 16, 2025
2b29338
minor fixes
pglorio Jan 16, 2025
b212cb2
cache_position
pglorio Jan 16, 2025
1e3b51e
removed cache_position postion_ids use_cache
pglorio Jan 16, 2025
5ace701
remove config from modular
pglorio Jan 16, 2025
535b631
removed config from modular (2)
pglorio Jan 16, 2025
5a16aa9
rebase
pglorio Jan 16, 2025
1c92266
import apply_rotary_pos_emb from llama
pglorio Jan 16, 2025
99bde93
fixed rope_kwargs
pglorio Jan 16, 2025
baf2ed3
Instantiate cache in Zamba2Model
pglorio Jan 16, 2025
9afb57e
fix cache
pglorio Jan 17, 2025
d1687f9
fix @slow decorator
pglorio Jan 17, 2025
4299889
rebase
pglorio Jan 20, 2025
a0545bf
rebase
pglorio Jan 21, 2025
903f6dc
small fix in modular file
pglorio Jan 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/en/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ Flax), PyTorch, and/or TensorFlow.
| [YOLOS](model_doc/yolos) | ✅ | ❌ | ❌ |
| [YOSO](model_doc/yoso) | ✅ | ❌ | ❌ |
| [Zamba](model_doc/zamba) | ✅ | ❌ | ❌ |
| [Zamba2](model_doc/zamba2) | ✅ | ❌ | ❌ |
| [ZoeDepth](model_doc/zoedepth) | ✅ | ❌ | ❌ |

<!-- End table-->
93 changes: 93 additions & 0 deletions docs/source/en/model_doc/zamba2.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->
# Zamba2

Zamba2 is a large language model (LLM) trained by Zyphra, and made available under an Apache 2.0 license. Please see the [Zyphra Hugging Face](https://huggingface.co/collections/zyphra/) repository for model weights.

This model was contributed by [pglo](https://huggingface.co/pglo).


## Model details

Zamba2-1.2B, Zamba2-2.7B and Zamba2-7B are hybrid models combining state-space models (Specifically [Mamba](https://github.com/state-spaces/mamba)) and transformer, and were trained using next-token prediction. Zamba2 uses shared transformer layers after every 6 mamba blocks. It uses the [Mistral v0.1 tokenizer](https://huggingface.co/mistralai/Mistral-7B-v0.1). We came to this architecture after a series of ablations at small scales. Zamba2-1.2B, Zamba2-2.7B and Zamba2-7B were pre-trained on 2T and 3T tokens, respectively.

<img src=https://github.com/user-attachments/assets/c2cff209-b901-483c-87aa-774b82a0769f width=30% height=40% />

## Quick start


### Presequities

Zamba2 requires you use `transformers` version 4.46.0 or higher:
```bash
pip install transformers>=4.46.0
```
Comment on lines +34 to +37
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Zamba2 requires you use `transformers` version 4.46.0 or higher:
```bash
pip install transformers>=4.46.0
```
Zamba2 requires you use `transformers` version 4.48.0 or higher:
```bash
pip install transformers>=4.48.0


## Inference

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B")
model = AutoModelForCausalLM.from_pretrained("Zyphra/Zamba2-7B", device_map="cuda", torch_dtype=torch.bfloat16)

input_text = "What factors contributed to the fall of the Roman Empire?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=100)
print(tokenizer.decode(outputs[0]))
```


## Model card

The model cards can be found at:
* [Zamba2-1.2B](https://huggingface.co/Zyphra/Zamba2-1.2B)
* [Zamba2-2.7B](https://huggingface.co/Zyphra/Zamba2-2.7B)
* [Zamba2-7B](https://huggingface.co/Zyphra/Zamba2-7B)


## Issues
For issues with model output, or community discussion, please use the Hugging Face community [forum](https://huggingface.co/Zyphra/Zamba2-7B/discussions)


## License

The model weights are open-sourced via an Apache 2.0 license.


## Zamba2Config

[[autodoc]] Zamba2Config


## Zamba2Model

[[autodoc]] Zamba2Model
- forward


## Zamba2ForCausalLM

[[autodoc]] Zamba2ForCausalLM
- forward


## Zamba2ForSequenceClassification

[[autodoc]] transformers.Zamba2ForSequenceClassification
- forward
3 changes: 2 additions & 1 deletion docs/source/en/perf_infer_gpu_one.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip)
* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)
* [Zamba2](https://huggingface.co/docs/transformers/model_doc/zamba2)

You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.

Expand Down Expand Up @@ -304,7 +305,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaModel)
* [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/model_doc/xlm-roberta-xl#transformers.XLMRobertaXLModel)
* [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel)

* [Zamba2](https://huggingface.co/docs/transformers/model_doc/zamba2)
<Tip>

FlashAttention can only be used for models with the `fp16` or `bf16` torch type, so make sure to cast your model to the appropriate type first. The memory-efficient attention backend is able to handle `fp32` models.
Expand Down
16 changes: 16 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,7 @@
"models.yolos": ["YolosConfig"],
"models.yoso": ["YosoConfig"],
"models.zamba": ["ZambaConfig"],
"models.zamba2": ["Zamba2Config"],
"models.zoedepth": ["ZoeDepthConfig"],
"onnx": [],
"pipelines": [
Expand Down Expand Up @@ -3804,6 +3805,14 @@
"ZambaPreTrainedModel",
]
)
_import_structure["models.zamba2"].extend(
[
"Zamba2ForCausalLM",
"Zamba2ForSequenceClassification",
"Zamba2Model",
"Zamba2PreTrainedModel",
]
)
_import_structure["models.zoedepth"].extend(
[
"ZoeDepthForDepthEstimation",
Expand Down Expand Up @@ -5781,6 +5790,7 @@
from .models.yolos import YolosConfig
from .models.yoso import YosoConfig
from .models.zamba import ZambaConfig
from .models.zamba2 import Zamba2Config
from .models.zoedepth import ZoeDepthConfig

# Pipelines
Expand Down Expand Up @@ -8209,6 +8219,12 @@
ZambaModel,
ZambaPreTrainedModel,
)
from .models.zamba2 import (
Zamba2ForCausalLM,
Zamba2ForSequenceClassification,
Zamba2Model,
Zamba2PreTrainedModel,
)
from .models.zoedepth import (
ZoeDepthForDepthEstimation,
ZoeDepthPreTrainedModel,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,5 +285,6 @@
yolos,
yoso,
zamba,
zamba2,
zoedepth,
)
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@
("yolos", "YolosConfig"),
("yoso", "YosoConfig"),
("zamba", "ZambaConfig"),
("zamba2", "Zamba2Config"),
("zoedepth", "ZoeDepthConfig"),
]
)
Expand Down Expand Up @@ -637,6 +638,7 @@
("yolos", "YOLOS"),
("yoso", "YOSO"),
("zamba", "Zamba"),
("zamba2", "Zamba2"),
("zoedepth", "ZoeDepth"),
]
)
Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@
("yolos", "YolosModel"),
("yoso", "YosoModel"),
("zamba", "ZambaModel"),
("zamba2", "Zamba2Model"),
]
)

Expand Down Expand Up @@ -552,6 +553,7 @@
("xlnet", "XLNetLMHeadModel"),
("xmod", "XmodForCausalLM"),
("zamba", "ZambaForCausalLM"),
("zamba2", "Zamba2ForCausalLM"),
]
)

Expand Down Expand Up @@ -1008,6 +1010,7 @@
("xmod", "XmodForSequenceClassification"),
("yoso", "YosoForSequenceClassification"),
("zamba", "ZambaForSequenceClassification"),
("zamba2", "Zamba2ForSequenceClassification"),
]
)

Expand Down
7 changes: 7 additions & 0 deletions src/transformers/models/auto/tokenization_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,13 @@
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"zamba2",
(
"LlamaTokenizer" if is_sentencepiece_available() else None,
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
]
)

Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/zamba/modeling_zamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,7 +922,7 @@ def forward(
return outputs


class HybridLayer(nn.Module):
class ZambaHybridLayer(nn.Module):
def __init__(self, shared_transf: ZambaAttentionDecoderLayer, linear: nn.Linear, mamba: ZambaMambaDecoderLayer):
super().__init__()
self.shared_transf = shared_transf
Expand Down Expand Up @@ -1201,7 +1201,7 @@ def __init__(self, config: ZambaConfig):
"shared_transf.pre_ff_layernorm.weight",
]
self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]]
layers.append(HybridLayer(block, next(linear_layers), next(mamba_layers)))
layers.append(ZambaHybridLayer(block, next(linear_layers), next(mamba_layers)))
else:
layers.append(next(mamba_layers))
self.layers = nn.ModuleList(layers)
Expand Down
57 changes: 57 additions & 0 deletions src/transformers/models/zamba2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available


_import_structure = {
"configuration_zamba2": ["Zamba2Config"],
}


try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_zamba2"] = [
"Zamba2ForCausalLM",
"Zamba2ForSequenceClassification",
"Zamba2Model",
"Zamba2PreTrainedModel",
]

if TYPE_CHECKING:
from .configuration_zamba2 import Zamba2Config

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_zamba2 import (
Zamba2ForCausalLM,
Zamba2ForSequenceClassification,
Zamba2Model,
Zamba2PreTrainedModel,
)


else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
pglorio marked this conversation as resolved.
Show resolved Hide resolved
Loading