
Add Zamba2 #34517

Draft · pglorio wants to merge 76 commits into main
Conversation

pglorio (Contributor) commented Oct 30, 2024

What does this PR do?

This PR adds support for the Zamba2 architecture created by Zyphra Technologies.

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker

pglorio marked this pull request as draft October 30, 2024 17:57
pglorio (Contributor, Author) commented Nov 11, 2024

Hey @Arthur,

Thank you again for your help in getting Zamba2 into transformers! The PR is now finally ready to be reviewed. I added the documentation and all unit tests pass, including slow tests.

A few remarks, mostly related to modular transformers:

  1. To generate the modeling and configuration files I used utils/modular_model_converter.py from a previous commit, because the most recent version of the script (which followed a large refactoring) produces an error that I was not able to fix:
```
Converting src/transformers/models/zamba2/modular_zamba2.py to a single model single file format
Traceback (most recent call last):
  File "/workspace/transformers_zamba/utils/modular_model_converter.py", line 1510, in <module>
    converted_files = convert_modular_file(file_name, args.old_model_name, args.new_model_name)
  File "/workspace/transformers_zamba/utils/modular_model_converter.py", line 1447, in convert_modular_file
    for file, module in create_modules(cst_transformers).items():
  File "/workspace/transformers_zamba/utils/modular_model_converter.py", line 1387, in create_modules
    nodes_to_add, file_type, new_imports = get_class_node_and_dependencies(modular_mapper, class_name, node, files)
  File "/workspace/transformers_zamba/utils/modular_model_converter.py", line 1337, in get_class_node_and_dependencies
    new_node_dependencies, new_imports = check_dependencies_and_create_import_node(
  File "/workspace/transformers_zamba/utils/modular_model_converter.py", line 1283, in check_dependencies_and_create_import_node
    class_dependencies = {dep for dep in new_dependencies if m.matches(mapper.global_nodes[dep], m.ClassDef())}
  File "/workspace/transformers_zamba/utils/modular_model_converter.py", line 1283, in <setcomp>
    class_dependencies = {dep for dep in new_dependencies if m.matches(mapper.global_nodes[dep], m.ClassDef())}
KeyError: 'Zamba2Config'
```

I carefully compared Zamba2Config with the configuration classes of other models that also use modular (such as Gemma2Config) and they appear to have a consistent format. Relatedly, the utils/modular_model_converter.py in the current PR (path) is the version from the previous commit mentioned above.

  2. After running utils/modular_model_converter.py, the generated modeling and configuration files contain unintended code that I had to update. All these modifications are in this commit. In particular, the produced modeling file contains Zamba2DynamicCache, which is the correct cache for Zamba2, as well as HybridMambaAttentionDynamicCache, which is the cache of Zamba and is not relevant to Zamba2, so I deleted HybridMambaAttentionDynamicCache and related references.

  3. I ran make fixup and all zamba-related tests pass, with the exception of python utils/check_modular_conversion.py. This check doesn't pass because of the modifications mentioned in the previous point.

  4. I slightly edited the Zamba2MambaMixer compared to the original Mamba2Mixer of mamba2; the main difference is that I added these lines, which were necessary to appropriately process the mamba2 cache (note this step already existed in the torch forward in these lines).

Looking forward to your feedback. Thanks so much!

pglorio mentioned this pull request Jan 7, 2025
pglorio (Contributor, Author) commented Jan 14, 2025

Hi @Cyrilvallez and @ArthurZucker,

I updated the attention forward to the new standard of transformers here and here.

I ran all final tests, including @slow tests, and everything appears to pass!

@Cyrilvallez (Member) left a comment

Nice work for the refactor! Almost ready, left some final comments but overall quite nice! 🤗

Comment on lines 1415 to 1417
```python
"ZambaModelTester",
"Zamba2ModelTester",
"RwkvModelTester",
```
Member:

cc @ydshieh here to ensure this change is necessary, as I'm not familiar with this new part!

Contributor Author (pglorio):

@ydshieh for context: when running this test, the config of the model is forced to have num_hidden_layers=1, but other parameters of the config are not updated accordingly, so the model errors out at initialization because these params are left inconsistent. I imagine that's also the reason why Zamba was added to this list.
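As a purely illustrative sketch (not the actual test code) of the mismatch described above, assuming this PR's branch is installed and that Zamba2Config sizes its depth-dependent fields from the original number of layers:

```python
from transformers import Zamba2Config  # assumes this PR's branch is installed

config = Zamba2Config()            # layers_block_type is sized for the full-depth model
config.num_hidden_layers = 1       # forced by the test harness
# Other depth-dependent fields (e.g. layers_block_type) are not shrunk to match,
# so building the model with this config fails during initialization.
assert len(config.layers_block_type) == config.num_hidden_layers  # no longer holds
```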

pglorio (Contributor, Author) commented Jan 16, 2025

Thank you @Cyrilvallez for the review. I addressed the comments above, although there are a couple of pending points.

All zamba-related tests appear to pass.

pglorio (Contributor, Author) commented Jan 17, 2025

Hello @Cyrilvallez, I ran all model tests on two GPUs and after a couple of minor fixes everything appears to work now. I'm skipping this test as it gives an error related to mamba2 kernels. I indeed verified that mamba2 skips that test here.

Separately, when running utils/check_modular_conversion.py I get the following error:

Differences found between the generated code and src/transformers/models/zamba2/modeling_zamba2.py:

```diff
--- src/transformers/models/zamba2/modeling_zamba2.py_generated
+++ src/transformers/models/zamba2/modeling_zamba2.py
@@ -313,6 +313,13 @@
     return attn_output, attn_weights
 
 
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
 
@@ -338,13 +345,6 @@
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
-
-
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
```

which I was not getting before, even though this part of the code was unchanged.

@Cyrilvallez (Member) left a comment

LGTM! Let's just wait for #35795 which will get rid of the CI failure for modular conversion! Sorry about that, and thanks for being so patient with us 🙏🙏🤗
Great work!

pglorio (Contributor, Author) commented Jan 21, 2025

Awesome, sounds good!

@ArthurZucker (Collaborator) left a comment

Thanks! A few comments about the code paths and the regex init, and it should be good!

Comment on lines +34 to +37
Zamba2 requires you use `transformers` version 4.46.0 or higher:
```bash
pip install transformers>=4.46.0
```
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change:

Zamba2 requires you use `transformers` version 4.48.0 or higher:
```bash
pip install transformers>=4.48.0
```

Comment on lines +93 to +101
```python
def layer_type_list(config: Zamba2Config):
    """
    Returns list of layer ids containing hybrid layers
    """
    output_list = []
    for index, type in enumerate(config.layers_block_type):
        if type == "hybrid":
            output_list.append(index)
    return output_list
```
Collaborator:

I don't understand why we have this when we can simply store the explicit list in the config?
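For illustration only (a sketch of the reviewer's suggestion, not code from the PR), the hybrid-layer ids could be derived inline or stored on the config directly, assuming config.layers_block_type is the list of strings shown in the snippet above:

```python
# Hedged sketch: compute the hybrid layer ids where they are needed
# instead of going through a free-standing helper function.
hybrid_layer_ids = [i for i, block_type in enumerate(config.layers_block_type) if block_type == "hybrid"]
```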

Comment on lines +82 to +89
def count_mem_blocks_in_config(config: Zamba2Config):
"""
Count number of shared blocks
"""
num_gs = 0
for val in config.layers_block_type:
if val == "hybrid":
num_gs += 1
Collaborator:

Same for this one, plus it's only used once; not sure it's worth a helper.
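As a hedged sketch of the inlining the comment hints at (assuming layers_block_type is a plain list of strings, as in the snippet above):

```python
# Count the shared ("hybrid") blocks at the single call site instead of using a helper.
num_gs = config.layers_block_type.count("hybrid")
```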

Comment on lines +132 to +150
```python
        self.conv_states = {
            i: torch.zeros(
                batch_size,
                self.intermediate_size + 2 * config.mamba_ngroups * config.mamba_d_state,
                self.conv_kernel_size,
                device=device,
                dtype=dtype,
            )
            for i in range(config.num_hidden_layers)
        }
        self.ssm_states = {
            i: torch.zeros(
                batch_size, self.n_mamba_heads, config.mamba_headdim, self.ssm_state_size, device=device, dtype=dtype
            )
            for i in range(config.num_hidden_layers)
        }
        for i in range(config.num_hidden_layers):
            if self.layers_block_type[i] == "hybrid":
                self.transformer_layers.append(i)
```
Collaborator:

a single for loop should suffice here
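A minimal sketch of what the single-loop version could look like, using only names that appear in the snippet above (illustrative, meant as a drop-in for the cache's __init__ body, not the PR's actual code):

```python
# Build the per-layer cache tensors and the hybrid-layer index list in one pass.
self.conv_states, self.ssm_states, self.transformer_layers = {}, {}, []
for i in range(config.num_hidden_layers):
    self.conv_states[i] = torch.zeros(
        batch_size,
        self.intermediate_size + 2 * config.mamba_ngroups * config.mamba_d_state,
        self.conv_kernel_size,
        device=device,
        dtype=dtype,
    )
    self.ssm_states[i] = torch.zeros(
        batch_size, self.n_mamba_heads, config.mamba_headdim, self.ssm_state_size, device=device, dtype=dtype
    )
    if self.layers_block_type[i] == "hybrid":
        self.transformer_layers.append(i)
```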

```python
        self.self_attn = Zamba2Attention(config, layer_idx=-1, num_fwd_mem_blocks=num_gs, block_id=block_id)
        self.feed_forward = Zamba2MLP(config, num_fwd_mem_blocks=num_gs, block_id=block_id)

    def forward(
```
Collaborator:

forward can be the same as ZambaAttentionDecoderLayer no?
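For illustration, in the modular format the suggestion would roughly look like the sketch below; the surrounding classes and helpers (ZambaAttentionDecoderLayer, Zamba2Attention, Zamba2MLP, count_mem_blocks_in_config) and the parent __init__ signature are assumed from the PR, not reproduced here:

```python
# Hedged sketch, not the PR's actual code: by not redefining forward, the modular
# converter reuses ZambaAttentionDecoderLayer.forward in the generated file.
class Zamba2AttentionDecoderLayer(ZambaAttentionDecoderLayer):
    def __init__(self, config: Zamba2Config, block_id: int = None, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)  # assumed signature of the Zamba parent __init__
        num_gs = count_mem_blocks_in_config(config)
        self.block_id = block_id
        self.self_attn = Zamba2Attention(config, layer_idx=-1, num_fwd_mem_blocks=num_gs, block_id=block_id)
        self.feed_forward = Zamba2MLP(config, num_fwd_mem_blocks=num_gs, block_id=block_id)
    # no forward defined here -> inherited from ZambaAttentionDecoderLayer
```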

Collaborator:

We can probably remove the cache_positions, as they are not used in either modeling file.

```python
        self, shared_transformer: Zamba2AttentionDecoderLayer, linear: nn.Linear, mamba: Zamba2MambaDecoderLayer
    ):
        super().__init__(shared_transformer, linear, mamba)
        del self.shared_transf
```
Collaborator:

wow I wish I caught this when reviewing the original model 🤣

Comment on lines +980 to +1000
```python
ZAMBA2_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`Zamba2Config`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare Zamba2 Model outputting raw hidden-states without any specific head on top.",
    ZAMBA2_START_DOCSTRING,
)
```
Collaborator:

pretty sure removing them can work with auto renaming!

Comment on lines +1109 to +1112
```python
@add_start_docstrings(
    "The bare Zamba2 Model outputting raw hidden-states without any specific head on top.",
    ZAMBA2_START_DOCSTRING,
)
```
Collaborator:

Same here: if they take the same inputs as Zamba, you don't need to write these explicitly.

"shared_transformer.pre_ff_layernorm.weight",
]
self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]]
if self.config.use_shared_mlp_adapter:
Collaborator:

Same comment about the code path: which models have this set to True / False?

Collaborator:

  • Tied weight keys support regex patterns; we should never have to add all of them manually like this (see the sketch below).
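A hedged sketch of what the regex-based version might look like, assuming _tied_weights_keys entries are matched as patterns against parameter names (as the reviewer indicates); the exact pattern is illustrative, not from the PR:

```python
# Illustrative only: one pattern covering every shared-block copy of the key,
# instead of enumerating prefix_name + key for each block.
self._tied_weights_keys = [
    *self._tied_weights_keys,
    r"shared_transformer\.pre_ff_layernorm\.weight",
]
```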

```python
        , dtype=torch.float32)  # fmt: skip

        torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_0, rtol=1e-3, atol=1e-3)
        torch.testing.assert_close(logits[1, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_1, rtol=1e-3, atol=1e-3)
```
Collaborator:

It's missing a test on CPU with the slow forward!
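For illustration, a test along these lines could exercise the pure-torch (slow) mamba path on CPU; the checkpoint id, prompt, and expected length check are placeholders, not values from this PR:

```python
# Hedged sketch of the kind of CPU integration test the reviewer asks for.
import torch
from transformers import AutoTokenizer, Zamba2ForCausalLM  # assumes this PR's branch is installed

def test_generate_cpu_slow_forward():
    model_id = "Zyphra/Zamba2-1.2B"  # placeholder checkpoint id
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = Zamba2ForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32).to("cpu")
    inputs = tokenizer("Hey how are you doing on this lovely evening?", return_tensors="pt")
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=10, do_sample=False)
    # On CPU the optimized mamba2 kernels are unavailable, so this runs the slow forward.
    assert out.shape[-1] == inputs.input_ids.shape[-1] + 10
```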
