From 61e976c993fc6b9d367e83ecb5364a973e197985 Mon Sep 17 00:00:00 2001
From: Vasqu
Date: Tue, 17 Sep 2024 00:31:25 +0200
Subject: [PATCH 1/3] use kernel for dt calculations

---
 src/transformers/models/mamba2/modeling_mamba2.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py
index 69390ea9ad2b..7b414ff9570d 100644
--- a/src/transformers/models/mamba2/modeling_mamba2.py
+++ b/src/transformers/models/mamba2/modeling_mamba2.py
@@ -358,7 +358,6 @@ def cuda_kernels_forward(
                 dim=-1,
             )
 
-            time_step = nn.functional.softplus(time_step + self.dt_bias)
             # 1D Convolution
             if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
                 hidden_states_B_C = self.act(
@@ -391,6 +390,8 @@ def cuda_kernels_forward(
                 z=None,
                 seq_idx=None,
                 return_final_states=True,
+                dt_bias=self.dt_bias,
+                dt_softplus=True,
                 **dt_limit_kwargs,
             )
             if ssm_state is not None and cache_params is not None:

From fe01458a0a1759bf97a3cd1908a0fefd80424ac2 Mon Sep 17 00:00:00 2001
From: Vasqu
Date: Wed, 18 Sep 2024 18:46:23 +0200
Subject: [PATCH 2/3] add small test

---
 tests/models/mamba2/test_modeling_mamba2.py | 26 ++++++++++++++++++++-
 1 file changed, 25 insertions(+), 1 deletion(-)

diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py
index a1e2138d4d6d..55c18abe6b96 100644
--- a/tests/models/mamba2/test_modeling_mamba2.py
+++ b/tests/models/mamba2/test_modeling_mamba2.py
@@ -35,7 +35,7 @@
         Mamba2ForCausalLM,
         Mamba2Model,
     )
-    from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache
+    from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache, Mamba2Mixer
     from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0
 else:
     is_torch_greater_or_equal_than_2_0 = False
@@ -378,3 +378,27 @@ def test_batched_equivalence_without_cache(self):
             individual_gen = model.generate(**inputs, max_new_tokens=30, use_cache=True)
             individual_output = tokenizer.batch_decode(individual_gen, skip_special_tokens=True)[0]
             self.assertEqual(individual_output[:100], batched_output[index_gen][:100])
+
+    @slow
+    @require_torch_gpu
+    def test_mamba2_mixer_train_vs_eval_equivalence(self):
+        # Based on https://github.com/sustcsonglin/flash-linear-attention/issues/63
+        # Credit to zhixuan-lin
+
+        B, T, D = 4, 512, 768
+        dtype = torch.bfloat16
+        config = Mamba2Config(num_heads=24, head_dim=64, hidden_size=768, expand=2, n_groups=1)
+
+        torch.manual_seed(42)
+        with torch.amp.autocast(device_type="cuda", dtype=dtype):
+            with torch.no_grad():
+                mixer = Mamba2Mixer(config, layer_idx=0).to("cuda")
+                hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device="cuda")
+
+                mixer.train()
+                out_train = mixer(hidden_states)
+
+                mixer.eval()
+                out_eval = mixer(hidden_states)
+
+        self.assertTrue(torch.allclose(out_train, out_eval, atol=1e-3))

From ff0cb7cdbfd41cd07967a509e7efba41bc518a43 Mon Sep 17 00:00:00 2001
From: Vasqu
Date: Wed, 18 Sep 2024 18:46:47 +0200
Subject: [PATCH 3/3] [run-slow] mamba2
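
Note for reviewers on what patch 1/3 changes: instead of computing softplus(time_step + self.dt_bias) eagerly in modeling_mamba2.py, the raw dt is now handed to the fused scan kernel together with dt_bias=self.dt_bias and dt_softplus=True, so the bias add and softplus happen inside the kernel. The sketch below is a plain-PyTorch reference of that dt preprocessing, under the assumption that the kernel's dt_bias/dt_softplus arguments implement exactly softplus(dt + dt_bias); the shapes and the helper name kernel_dt_preprocess are illustrative only and not part of the library.

    # Reference sketch (assumption): the fused kernel applies softplus(dt + dt_bias)
    # internally when dt_bias=... and dt_softplus=True are passed, matching the
    # eager line removed from cuda_kernels_forward. Shapes are illustrative.
    import torch
    import torch.nn.functional as F

    batch_size, seq_len, num_heads = 2, 16, 24
    dt = torch.randn(batch_size, seq_len, num_heads)  # raw time-step from the input projection
    dt_bias = torch.randn(num_heads)                   # learned per-head bias (self.dt_bias)

    # Previous behaviour: applied eagerly in modeling_mamba2.py before the kernel call.
    dt_eager = F.softplus(dt + dt_bias)

    # New behaviour (reference re-implementation, not the actual CUDA kernel):
    # the same transform is expected to happen inside the kernel, so the raw dt
    # is now passed through unchanged.
    def kernel_dt_preprocess(dt, dt_bias, dt_softplus=True):
        dt = dt + dt_bias
        return F.softplus(dt) if dt_softplus else dt

    assert torch.allclose(dt_eager, kernel_dt_preprocess(dt, dt_bias))

The test added in patch 2/3 then checks that the Mamba2Mixer output agrees between train() and eval() mode on GPU, presumably because the two modes previously disagreed on where the dt transform was applied, per the linked flash-linear-attention issue.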