Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Mamba2] Move dt calculations to kernel #33520

Merged
merged 3 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/transformers/models/mamba2/modeling_mamba2.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,6 @@ def cuda_kernels_forward(
dim=-1,
)

time_step = nn.functional.softplus(time_step + self.dt_bias)
# 1D Convolution
if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
hidden_states_B_C = self.act(
Expand Down Expand Up @@ -391,6 +390,8 @@ def cuda_kernels_forward(
z=None,
seq_idx=None,
return_final_states=True,
dt_bias=self.dt_bias,
dt_softplus=True,
**dt_limit_kwargs,
)
if ssm_state is not None and cache_params is not None:
Expand Down
26 changes: 25 additions & 1 deletion tests/models/mamba2/test_modeling_mamba2.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
Mamba2ForCausalLM,
Mamba2Model,
)
from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache
from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache, Mamba2Mixer
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0
else:
is_torch_greater_or_equal_than_2_0 = False
Expand Down Expand Up @@ -378,3 +378,27 @@ def test_batched_equivalence_without_cache(self):
individual_gen = model.generate(**inputs, max_new_tokens=30, use_cache=True)
individual_output = tokenizer.batch_decode(individual_gen, skip_special_tokens=True)[0]
self.assertEqual(individual_output[:100], batched_output[index_gen][:100])

@slow
@require_torch_gpu
def test_mamba2_mixer_train_vs_eval_equivalence(self):
# Based on https://github.com/sustcsonglin/flash-linear-attention/issues/63
# Credit to zhixuan-lin

B, T, D = 4, 512, 768
dtype = torch.bfloat16
config = Mamba2Config(num_heads=24, head_dim=64, hidden_size=768, expand=2, n_groups=1)

torch.manual_seed(42)
with torch.amp.autocast(device_type="cuda", dtype=dtype):
with torch.no_grad():
mixer = Mamba2Mixer(config, layer_idx=0).to("cuda")
hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device="cuda")

mixer.train()
out_train = mixer(hidden_states)

mixer.eval()
out_eval = mixer(hidden_states)

self.assertTrue(torch.allclose(out_train, out_eval, atol=1e-3))
Loading