Add Zamba2 #34517
base: main
Conversation
Rebase zamba2
Hey @Arthur, thank you again for your help in getting Zamba2 into `transformers`. A few remarks, mostly related to […]. I carefully compared […]. Looking forward to your feedback. Thanks so much!
rebase on upstream
Hi @Cyrilvallez and @ArthurZucker, I updated the attention forward to the new standard of […]. I ran all final tests, including […].
Nice work for the refactor! Almost ready, left some final comments but overall quite nice! 🤗
"ZambaModelTester", | ||
"Zamba2ModelTester", | ||
"RwkvModelTester", |
cc @ydshieh here to ensure this change is necessary, as I'm not familiar with this new part!
@ydshieh for context: when running this test, the config of the model is forced to have `num_hidden_layers=1`, but the other config parameters are not updated accordingly, so the model errors out at initialization because these params are no longer consistent. It's probably also the reason why Zamba was added to this list.
Thank you @Cyrilvallez for the review. I addressed the comments above, although there are a couple of pending points. All zamba-related tests appear to pass.
Hello @Cyrilvallez, I ran all model tests on two GPUs, and after a couple of minor fixes everything appears to work now. I'm skipping this test as it gives an error related to the mamba2 kernels; I verified that mamba2 indeed skips that test here. Separately, when running […] I get […], which I was not getting before despite this part being identical.
LGTM! Let's just wait for #35795 which will get rid of the CI failure for modular conversion! Sorry about that, and thanks for being so patient with us 🙏🙏🤗
Great work!
Awesome, sounds good!
Thanks! A few comments about the code paths and the regex init, and then it should be good!
Zamba2 requires you use `transformers` version 4.46.0 or higher:
```bash
pip install transformers>=4.46.0
```
Suggested change (bump the minimum version from 4.46.0 to 4.48.0):

Zamba2 requires you use `transformers` version 4.48.0 or higher:
```bash
pip install transformers>=4.48.0
```
    def layer_type_list(config: Zamba2Config):
        """
        Returns list of layer ids containing hybrid layers
        """
        output_list = []
        for index, type in enumerate(config.layers_block_type):
            if type == "hybrid":
                output_list.append(index)
        return output_list
I don't understand why we have this when we can simply store the explicit list in the config?
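For illustration only, here is a minimal standalone sketch of what storing the list in the config could look like; it is not the actual `Zamba2Config`, and the attribute name `hybrid_layer_ids` is hypothetical:

```python
# Hypothetical sketch: store the explicit list of hybrid-layer indices on the config
# once, instead of recomputing it with a standalone helper function.
class Zamba2Config:
    def __init__(self, layers_block_type=None, **kwargs):
        self.layers_block_type = layers_block_type if layers_block_type is not None else []
        # Precompute the indices of hybrid layers so callers can read them directly.
        self.hybrid_layer_ids = [
            i for i, block_type in enumerate(self.layers_block_type) if block_type == "hybrid"
        ]
```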
    def count_mem_blocks_in_config(config: Zamba2Config):
        """
        Count number of shared blocks
        """
        num_gs = 0
        for val in config.layers_block_type:
            if val == "hybrid":
                num_gs += 1
        return num_gs
same for this + it's only used once, not sure it's worth doing this
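A hedged sketch of the inline alternative, assuming `config.layers_block_type` is a list of strings as in the snippet above:

```python
# Count the shared ("hybrid") blocks inline at the single call site, instead of via a helper.
num_gs = sum(block_type == "hybrid" for block_type in config.layers_block_type)
```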
        self.conv_states = {
            i: torch.zeros(
                batch_size,
                self.intermediate_size + 2 * config.mamba_ngroups * config.mamba_d_state,
                self.conv_kernel_size,
                device=device,
                dtype=dtype,
            )
            for i in range(config.num_hidden_layers)
        }
        self.ssm_states = {
            i: torch.zeros(
                batch_size, self.n_mamba_heads, config.mamba_headdim, self.ssm_state_size, device=device, dtype=dtype
            )
            for i in range(config.num_hidden_layers)
        }
        for i in range(config.num_hidden_layers):
            if self.layers_block_type[i] == "hybrid":
                self.transformer_layers.append(i)
a single for loop should suffice here
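A minimal sketch of the single-loop version, reusing the attribute names from the snippet above (illustrative only, not the final implementation):

```python
# Hypothetical sketch: build both per-layer caches and collect the hybrid-layer
# indices in a single pass over the layers.
self.conv_states, self.ssm_states = {}, {}
for i in range(config.num_hidden_layers):
    self.conv_states[i] = torch.zeros(
        batch_size,
        self.intermediate_size + 2 * config.mamba_ngroups * config.mamba_d_state,
        self.conv_kernel_size,
        device=device,
        dtype=dtype,
    )
    self.ssm_states[i] = torch.zeros(
        batch_size, self.n_mamba_heads, config.mamba_headdim, self.ssm_state_size, device=device, dtype=dtype
    )
    if self.layers_block_type[i] == "hybrid":
        self.transformer_layers.append(i)
```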
        self.self_attn = Zamba2Attention(config, layer_idx=-1, num_fwd_mem_blocks=num_gs, block_id=block_id)
        self.feed_forward = Zamba2MLP(config, num_fwd_mem_blocks=num_gs, block_id=block_id)

    def forward(
forward can be the same as ZambaAttentionDecoderLayer, no?
we can probably remove the cache_positions as they are not used in either modeling file
        self, shared_transformer: Zamba2AttentionDecoderLayer, linear: nn.Linear, mamba: Zamba2MambaDecoderLayer
    ):
        super().__init__(shared_transformer, linear, mamba)
        del self.shared_transf
wow I wish I caught this when reviewing the original model 🤣
ZAMBA2_START_DOCSTRING = r""" | ||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the | ||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads | ||
etc.) | ||
|
||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. | ||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage | ||
and behavior. | ||
|
||
Parameters: | ||
config ([`Zamba2Config`]): | ||
Model configuration class with all the parameters of the model. Initializing with a config file does not | ||
load the weights associated with the model, only the configuration. Check out the | ||
[`~PreTrainedModel.from_pretrained`] method to load the model weights. | ||
""" | ||
|
||
|
||
@add_start_docstrings( | ||
"The bare Zamba2 Model outputting raw hidden-states without any specific head on top.", | ||
ZAMBA2_START_DOCSTRING, | ||
) |
pretty sure removing them can work with auto renaming!
    @add_start_docstrings(
        "The bare Zamba2 Model outputting raw hidden-states without any specific head on top.",
        ZAMBA2_START_DOCSTRING,
    )
same, if they take the same inputs as Zamba, you don't need to explicitly write these
"shared_transformer.pre_ff_layernorm.weight", | ||
] | ||
self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] | ||
if self.config.use_shared_mlp_adapter: |
same comment about code path, which models have this set to true / false?
- tied keys support regex patterns, so we should never have to add all of them manually like this
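For illustration only, a hedged sketch of what a regex-based version could look like, assuming `_tied_weights_keys` entries are matched as patterns (as the reviewer notes); the exact pattern strings below are made up for this example:

```python
# Hypothetical sketch: cover all shared-transformer weights with a few regex patterns
# instead of enumerating every prefixed key by hand.
self._tied_weights_keys = [
    *self._tied_weights_keys,
    r"shared_transformer\.self_attn\..*\.weight",
    r"shared_transformer\.feed_forward\..*\.weight",
    r"shared_transformer\.(input_layernorm|pre_ff_layernorm)\.weight",
]
```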
            , dtype=torch.float32)  # fmt: skip

        torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_0, rtol=1e-3, atol=1e-3)
        torch.testing.assert_close(logits[1, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_1, rtol=1e-3, atol=1e-3)
It's missing a test on CPU with the slow forward!
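A hedged sketch of what such a CPU test could look like; the checkpoint name and prompt are placeholders, no expected logits are hard-coded, and this is not the test that was actually added:

```python
# Hypothetical sketch of an integration test exercising the slow (pure PyTorch) forward path on CPU.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def test_zamba2_cpu_slow_forward():
    model_id = "Zyphra/Zamba2-1.2B"  # assumption: any small Zamba2 checkpoint would do
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32).to("cpu").eval()

    inputs = tokenizer("Hello, world!", return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    # Sanity checks only: correct shape and finite values on the slow CPU code path.
    assert logits.shape[:2] == inputs["input_ids"].shape
    assert torch.isfinite(logits).all()
```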
What does this PR do?
Please include support for the Zamba2 architecture created by Zyphra Technologies.
Who can review?
@ArthurZucker