From b72e876b06ff7df87b4d1a99ed87d81391f14fbc Mon Sep 17 00:00:00 2001
From: Vasqu
Date: Wed, 14 Aug 2024 17:05:32 +0200
Subject: [PATCH 1/3] mamba2 uses norm_before_gate=False

---
 .../models/mamba2/configuration_mamba2.py |  4 ++--
 .../models/mamba2/modeling_mamba2.py      | 16 ++++++++++++----
 2 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/src/transformers/models/mamba2/configuration_mamba2.py b/src/transformers/models/mamba2/configuration_mamba2.py
index e3dcb63011d2..4b3a95b44693 100644
--- a/src/transformers/models/mamba2/configuration_mamba2.py
+++ b/src/transformers/models/mamba2/configuration_mamba2.py
@@ -83,7 +83,7 @@ class Mamba2Config(PretrainedConfig):
             Whether or not to rescale `out_proj` weights when initializing.
         use_cache (`bool`, *optional*, defaults to `True`):
             Whether or not the cache should be used.
-        norm_before_gate (`bool`, *optional*, defaults to `True`):
+        norm_before_gate (`bool`, *optional*, defaults to `False`):
             Option of cuda kernels -whether to normalize before the gate or not.
         rms_norm (`bool`, *optional*, defaults to `True`):
             Whether to use RMS norm or not.
@@ -137,7 +137,7 @@ def __init__(
         time_step_limit=(0.0, float("inf")),
         rescale_prenorm_residual=False,
         use_cache=True,
-        norm_before_gate=True,
+        norm_before_gate=False,
         rms_norm=True,
         chunk_size=256,
         tie_word_embeddings=False,
diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py
index bf993ad2f311..d7dfb9b20a2a 100644
--- a/src/transformers/models/mamba2/modeling_mamba2.py
+++ b/src/transformers/models/mamba2/modeling_mamba2.py
@@ -170,21 +170,27 @@ def reset(self):
 
 
 class MambaRMSNormGated(torch.nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, hidden_size, eps=1e-6, norm_before_gate=False):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
+        self.norm_before_gate = norm_before_gate
 
     def forward(self, hidden_states, gate=None):
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
 
-        if gate is not None:
+        if gate is not None and not self.norm_before_gate:
             hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
+
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        hidden_states = self.weight * hidden_states
 
-        return self.weight * hidden_states.to(input_dtype)
+        if gate is not None and self.norm_before_gate:
+            hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
+
+        return hidden_states.to(input_dtype)
 
 
 class Mamba2Mixer(nn.Module):
@@ -248,7 +254,9 @@ def __init__(self, config: Mamba2Config, layer_idx: int):
         A = torch.arange(1, self.num_heads + 1)
         self.A_log = nn.Parameter(torch.log(A))
         self.A_log._no_weight_decay = True
-        self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
+        self.norm = MambaRMSNormGated(
+            self.intermediate_size, eps=self.layer_norm_epsilon, norm_before_gate=config.norm_before_gate
+        )
         self.D = nn.Parameter(torch.ones(self.num_heads))
         self.D._no_weight_decay = True
 

From 4f0ce72f29de2d454cff3dcd5810c5367dd72a5c Mon Sep 17 00:00:00 2001
From: Vasqu
Date: Wed, 14 Aug 2024 17:41:15 +0200
Subject: [PATCH 2/3] small nit

---
 src/transformers/models/mamba2/modeling_mamba2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py
index d7dfb9b20a2a..ce2efb60025e 100644
--- a/src/transformers/models/mamba2/modeling_mamba2.py
+++ b/src/transformers/models/mamba2/modeling_mamba2.py
@@ -255,7 +255,7 @@ def __init__(self, config: Mamba2Config, layer_idx: int):
         self.A_log = nn.Parameter(torch.log(A))
         self.A_log._no_weight_decay = True
         self.norm = MambaRMSNormGated(
-            self.intermediate_size, eps=self.layer_norm_epsilon, norm_before_gate=config.norm_before_gate
+            self.intermediate_size, eps=self.layer_norm_epsilon, norm_before_gate=self.norm_before_gate
         )
         self.D = nn.Parameter(torch.ones(self.num_heads))
         self.D._no_weight_decay = True

From 7d01af04218a074727c304d474a94847c7003f12 Mon Sep 17 00:00:00 2001
From: Vasqu
Date: Tue, 20 Aug 2024 15:57:19 +0200
Subject: [PATCH 3/3] remove norm_before_gate flag and follow False path only

---
 .../models/mamba2/configuration_mamba2.py |  4 ----
 .../models/mamba2/modeling_mamba2.py      | 19 +++++--------------
 2 files changed, 5 insertions(+), 18 deletions(-)

diff --git a/src/transformers/models/mamba2/configuration_mamba2.py b/src/transformers/models/mamba2/configuration_mamba2.py
index 4b3a95b44693..7a690dceb1c4 100644
--- a/src/transformers/models/mamba2/configuration_mamba2.py
+++ b/src/transformers/models/mamba2/configuration_mamba2.py
@@ -83,8 +83,6 @@ class Mamba2Config(PretrainedConfig):
             Whether or not to rescale `out_proj` weights when initializing.
         use_cache (`bool`, *optional*, defaults to `True`):
             Whether or not the cache should be used.
-        norm_before_gate (`bool`, *optional*, defaults to `False`):
-            Option of cuda kernels -whether to normalize before the gate or not.
         rms_norm (`bool`, *optional*, defaults to `True`):
             Whether to use RMS norm or not.
         chunk_size (`int`, *optional*, defaults to 256):
@@ -137,7 +135,6 @@ def __init__(
         time_step_limit=(0.0, float("inf")),
         rescale_prenorm_residual=False,
         use_cache=True,
-        norm_before_gate=False,
         rms_norm=True,
         chunk_size=256,
         tie_word_embeddings=False,
@@ -168,7 +165,6 @@ def __init__(
         self.n_groups = n_groups
         self.num_heads = num_heads
         self.head_dim = head_dim
-        self.norm_before_gate = norm_before_gate
         self.rms_norm = rms_norm
         self.state_size = state_size
         self.chunk_size = chunk_size
diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py
index ce2efb60025e..69390ea9ad2b 100644
--- a/src/transformers/models/mamba2/modeling_mamba2.py
+++ b/src/transformers/models/mamba2/modeling_mamba2.py
@@ -170,27 +170,21 @@ def reset(self):
 
 
 class MambaRMSNormGated(torch.nn.Module):
-    def __init__(self, hidden_size, eps=1e-6, norm_before_gate=False):
+    def __init__(self, hidden_size, eps=1e-6):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
-        self.norm_before_gate = norm_before_gate
 
     def forward(self, hidden_states, gate=None):
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
 
-        if gate is not None and not self.norm_before_gate:
+        if gate is not None:
             hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
-
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        hidden_states = self.weight * hidden_states
-
-        if gate is not None and self.norm_before_gate:
-            hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
 
-        return hidden_states.to(input_dtype)
+        return self.weight * hidden_states.to(input_dtype)
 
 
 class Mamba2Mixer(nn.Module):
@@ -214,7 +208,6 @@ def __init__(self, config: Mamba2Config, layer_idx: int):
         self.activation = config.hidden_act
         self.act = ACT2FN[config.hidden_act]
 
-        self.norm_before_gate = config.norm_before_gate
         self.layer_norm_epsilon = config.layer_norm_epsilon
         self.rms_norm = config.rms_norm
 
@@ -254,9 +247,7 @@ def __init__(self, config: Mamba2Config, layer_idx: int):
         A = torch.arange(1, self.num_heads + 1)
         self.A_log = nn.Parameter(torch.log(A))
         self.A_log._no_weight_decay = True
-        self.norm = MambaRMSNormGated(
-            self.intermediate_size, eps=self.layer_norm_epsilon, norm_before_gate=self.norm_before_gate
-        )
+        self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
         self.D = nn.Parameter(torch.ones(self.num_heads))
         self.D._no_weight_decay = True
 
@@ -355,7 +346,7 @@ def cuda_kernels_forward(
                 outproj_bias=self.out_proj.bias,
                 headdim=self.head_dim,
                 ngroups=self.n_groups,
-                norm_before_gate=self.norm_before_gate,
+                norm_before_gate=False,
                 return_final_states=True,
                 **dt_limit_kwargs,
             )
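
For reference only, and not part of the patches above: a minimal self-contained sketch of the two gating orders this series decides between. The `gated_rms_norm` helper below is hypothetical (it mirrors MambaRMSNormGated as it looked after patch 1, before patch 3 dropped the flag); the behaviour that is kept applies the SiLU gate before the RMS normalization, which is what the hard-coded `norm_before_gate=False` also selects in the CUDA kernel call.

# Illustrative sketch only -- not part of the patch series.
import torch
import torch.nn as nn


def gated_rms_norm(hidden_states, gate, weight, eps=1e-6, norm_before_gate=False):
    # Hypothetical helper showing both orderings side by side.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    if gate is not None and not norm_before_gate:
        # Path kept by patch 3: gate first, then normalize.
        hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = weight * hidden_states * torch.rsqrt(variance + eps)
    if gate is not None and norm_before_gate:
        # Path removed by patch 3: normalize first, then gate.
        hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
    return hidden_states.to(input_dtype)


x, g = torch.randn(2, 4, 8), torch.randn(2, 4, 8)
w = torch.ones(8)
# The two orderings generally disagree, which is why the default mattered and why the
# kernel argument is pinned to False once the config flag is removed.
print(torch.allclose(gated_rms_norm(x, g, w, norm_before_gate=False),
                     gated_rms_norm(x, g, w, norm_before_gate=True)))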