Commit

add small test
vasqu committed Sep 19, 2024
1 parent 61e976c commit fe01458
Showing 1 changed file with 25 additions and 1 deletion.
tests/models/mamba2/test_modeling_mamba2.py (25 additions, 1 deletion)
@@ -35,7 +35,7 @@
         Mamba2ForCausalLM,
         Mamba2Model,
     )
-    from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache
+    from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache, Mamba2Mixer
     from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0
 else:
     is_torch_greater_or_equal_than_2_0 = False
@@ -378,3 +378,27 @@ def test_batched_equivalence_without_cache(self):
         individual_gen = model.generate(**inputs, max_new_tokens=30, use_cache=True)
         individual_output = tokenizer.batch_decode(individual_gen, skip_special_tokens=True)[0]
         self.assertEqual(individual_output[:100], batched_output[index_gen][:100])
+
+    @slow
+    @require_torch_gpu
+    def test_mamba2_mixer_train_vs_eval_equivalence(self):
+        # Based on https://github.com/sustcsonglin/flash-linear-attention/issues/63
+        # Credit to zhixuan-lin
+
+        B, T, D = 4, 512, 768
+        dtype = torch.bfloat16
+        config = Mamba2Config(num_heads=24, head_dim=64, hidden_size=768, expand=2, n_groups=1)
+
+        torch.manual_seed(42)
+        with torch.amp.autocast(device_type="cuda", dtype=dtype):
+            with torch.no_grad():
+                mixer = Mamba2Mixer(config, layer_idx=0).to("cuda")
+                hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device="cuda")
+
+                mixer.train()
+                out_train = mixer(hidden_states)
+
+                mixer.eval()
+                out_eval = mixer(hidden_states)
+
+                self.assertTrue(torch.allclose(out_train, out_eval, atol=1e-3))
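
The same parity check can be reproduced outside the test harness. Below is a minimal standalone sketch under the same assumptions as the test: a CUDA-capable machine, and a transformers build that exports Mamba2Config publicly and ships Mamba2Mixer as an internal module of modeling_mamba2.

    # Standalone sketch of the train-vs-eval parity check from the test above.
    # Assumes a CUDA device; Mamba2Mixer is internal and imported the same way
    # the test does.
    import torch

    from transformers import Mamba2Config
    from transformers.models.mamba2.modeling_mamba2 import Mamba2Mixer

    config = Mamba2Config(num_heads=24, head_dim=64, hidden_size=768, expand=2, n_groups=1)
    torch.manual_seed(42)

    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16), torch.no_grad():
        mixer = Mamba2Mixer(config, layer_idx=0).to("cuda")
        hidden_states = torch.rand(4, 512, 768, dtype=torch.bfloat16, device="cuda")

        mixer.train()   # run the mixer in training mode
        out_train = mixer(hidden_states)

        mixer.eval()    # run the same mixer and input in eval mode
        out_eval = mixer(hidden_states)

    print(torch.allclose(out_train, out_eval, atol=1e-3))  # expected: True

Note that the committed test is gated behind @slow and @require_torch_gpu, so within the transformers test suite it only runs on a GPU machine with RUN_SLOW=1 set in the environment.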
