diff --git a/test/test_attention.py b/test/test_attention.py
index cb3b7d65..19e5c2d6 100644
--- a/test/test_attention.py
+++ b/test/test_attention.py
@@ -6,13 +6,33 @@
 import pytest
 
 attention_configs = {
-    "mha_fix_head_size": {"config":Config(n_embd=128, n_head=16, n_query_groups=16, head_size=64), "fix_head_size": True},
-    "gqa_fix_head_size": {"config":Config(n_embd=128, n_head=16, n_query_groups=2, head_size=64), "fix_head_size": True},
-    "mqa_fix_head_size": {"config":Config(n_embd=128, n_head=16, n_query_groups=1, head_size=64), "fix_head_size": True},
-    "mha_flexible_head_size": {"config":Config(n_embd=128, n_head=16, n_query_groups=16), "fix_head_size": False},
-    "gqa_flexible_head_size": {"config":Config(n_embd=128, n_head=16, n_query_groups=2), "fix_head_size": False},
-    "mqa_flexible_head_size": {"config":Config(n_embd=128, n_head=16, n_query_groups=1), "fix_head_size": False}
+    "mha_fix_head_size": {
+        "config": Config(n_embd=128, n_head=16, n_query_groups=16, head_size=64),
+        "fix_head_size": True,
+    },
+    "gqa_fix_head_size": {
+        "config": Config(n_embd=128, n_head=16, n_query_groups=2, head_size=64),
+        "fix_head_size": True,
+    },
+    "mqa_fix_head_size": {
+        "config": Config(n_embd=128, n_head=16, n_query_groups=1, head_size=64),
+        "fix_head_size": True,
+    },
+    "mha_flexible_head_size": {
+        "config": Config(n_embd=128, n_head=16, n_query_groups=16),
+        "fix_head_size": False,
+    },
+    "gqa_flexible_head_size": {
+        "config": Config(n_embd=128, n_head=16, n_query_groups=2),
+        "fix_head_size": False,
+    },
+    "mqa_flexible_head_size": {
+        "config": Config(n_embd=128, n_head=16, n_query_groups=1),
+        "fix_head_size": False,
+    },
 }
+
+
 def init_attention(config):
     attention = CausalSelfAttention(config)
     attention.attn.weight.data = torch.ones_like(attention.attn.weight.data)
@@ -21,6 +41,7 @@ def init_attention(config):
     attention.proj.weight.data = torch.ones_like(attention.proj.weight.data)
     return attention
 
+
 def init_lit_attention(config):
     attention = LitCausalSelfAttention(config)
     attention.attn.weight.data = torch.ones_like(attention.attn.weight.data)
@@ -29,6 +50,7 @@ def init_lit_attention(config):
     attention.proj.weight.data = torch.ones_like(attention.proj.weight.data)
     return attention
 
+
 @pytest.mark.parametrize("attention_config", attention_configs.keys())
 def test_attention(attention_config):
     config = attention_configs[attention_config]["config"]
@@ -51,9 +73,15 @@ def test_attention(attention_config):
 
     lit_attention = init_lit_attention(config)
     out_lit_large = lit_attention(input, mask=mask, cos=cos, sin=sin)
-    attention.set_sub_network(sub_network_n_embd=config.n_embd // 2, sub_network_n_head=config.n_head // 4)
-    cos, sin = build_rope_cache(seq_len, n_elem=int(config.rotary_percentage * attention.sub_network_head_size))
-    out_small = attention(input[:, :, :config.n_embd//2], mask=mask, cos=cos, sin=sin)
+    attention.set_sub_network(
+        sub_network_n_embd=config.n_embd // 2, sub_network_n_head=config.n_head // 4
+    )
+    cos, sin = build_rope_cache(
+        seq_len, n_elem=int(config.rotary_percentage * attention.sub_network_head_size)
+    )
+    out_small = attention(
+        input[:, :, : config.n_embd // 2], mask=mask, cos=cos, sin=sin
+    )
 
     # check shape of sub-network attention
     assert out_small.shape == (8, seq_len, config.n_embd // 2)
@@ -68,8 +96,10 @@ def test_attention(attention_config):
 
     config.rope_n_elem = int(config.rotary_percentage * config.head_size)
     lit_attention_small = init_lit_attention(config)
-
-    out_lit_small = lit_attention_small(input[:, :, :config.n_embd], mask=mask, cos=cos, sin=sin)
+
+    out_lit_small = lit_attention_small(
+        input[:, :, : config.n_embd], mask=mask, cos=cos, sin=sin
+    )
 
     # check that our sub-networks the same output as equally sized LitGPT attention layer
-    assert torch.all(out_lit_small == out_small)
\ No newline at end of file
+    assert torch.all(out_lit_small == out_small)