
Commit

fix formatting
rheasukthanker committed Jun 19, 2024
1 parent 6d80a3b commit 4c7a46f
Showing 1 changed file with 42 additions and 12 deletions.
54 changes: 42 additions & 12 deletions test/test_attention.py
@@ -6,13 +6,33 @@
import pytest

attention_configs = {
"mha_fix_head_size": {"config":Config(n_embd=128, n_head=16, n_query_groups=16, head_size=64), "fix_head_size": True},
"gqa_fix_head_size": {"config":Config(n_embd=128, n_head=16, n_query_groups=2, head_size=64), "fix_head_size": True},
"mqa_fix_head_size": {"config":Config(n_embd=128, n_head=16, n_query_groups=1, head_size=64), "fix_head_size": True},
"mha_flexible_head_size": {"config":Config(n_embd=128, n_head=16, n_query_groups=16), "fix_head_size": False},
"gqa_flexible_head_size": {"config":Config(n_embd=128, n_head=16, n_query_groups=2), "fix_head_size": False},
"mqa_flexible_head_size": {"config":Config(n_embd=128, n_head=16, n_query_groups=1), "fix_head_size": False}
"mha_fix_head_size": {
"config": Config(n_embd=128, n_head=16, n_query_groups=16, head_size=64),
"fix_head_size": True,
},
"gqa_fix_head_size": {
"config": Config(n_embd=128, n_head=16, n_query_groups=2, head_size=64),
"fix_head_size": True,
},
"mqa_fix_head_size": {
"config": Config(n_embd=128, n_head=16, n_query_groups=1, head_size=64),
"fix_head_size": True,
},
"mha_flexible_head_size": {
"config": Config(n_embd=128, n_head=16, n_query_groups=16),
"fix_head_size": False,
},
"gqa_flexible_head_size": {
"config": Config(n_embd=128, n_head=16, n_query_groups=2),
"fix_head_size": False,
},
"mqa_flexible_head_size": {
"config": Config(n_embd=128, n_head=16, n_query_groups=1),
"fix_head_size": False,
},
}


def init_attention(config):
attention = CausalSelfAttention(config)
attention.attn.weight.data = torch.ones_like(attention.attn.weight.data)
@@ -21,6 +41,7 @@ def init_attention(config):
attention.proj.weight.data = torch.ones_like(attention.proj.weight.data)
return attention


def init_lit_attention(config):
attention = LitCausalSelfAttention(config)
attention.attn.weight.data = torch.ones_like(attention.attn.weight.data)
@@ -29,6 +50,7 @@ def init_lit_attention(config):
attention.proj.weight.data = torch.ones_like(attention.proj.weight.data)
return attention


@pytest.mark.parametrize("attention_config", attention_configs.keys())
def test_attention(attention_config):
config = attention_configs[attention_config]["config"]
@@ -51,9 +73,15 @@ def test_attention(attention_config):
lit_attention = init_lit_attention(config)
out_lit_large = lit_attention(input, mask=mask, cos=cos, sin=sin)

attention.set_sub_network(sub_network_n_embd=config.n_embd // 2, sub_network_n_head=config.n_head // 4)
cos, sin = build_rope_cache(seq_len, n_elem=int(config.rotary_percentage * attention.sub_network_head_size))
out_small = attention(input[:, :, :config.n_embd//2], mask=mask, cos=cos, sin=sin)
attention.set_sub_network(
sub_network_n_embd=config.n_embd // 2, sub_network_n_head=config.n_head // 4
)
cos, sin = build_rope_cache(
seq_len, n_elem=int(config.rotary_percentage * attention.sub_network_head_size)
)
out_small = attention(
input[:, :, : config.n_embd // 2], mask=mask, cos=cos, sin=sin
)

# check shape of sub-network attention
assert out_small.shape == (8, seq_len, config.n_embd // 2)
@@ -68,8 +96,10 @@ def test_attention(attention_config):
config.rope_n_elem = int(config.rotary_percentage * config.head_size)

lit_attention_small = init_lit_attention(config)

out_lit_small = lit_attention_small(input[:, :, :config.n_embd], mask=mask, cos=cos, sin=sin)

out_lit_small = lit_attention_small(
input[:, :, : config.n_embd], mask=mask, cos=cos, sin=sin
)

# check that our sub-network produces the same output as an equally sized LitGPT attention layer
assert torch.all(out_lit_small == out_small)
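
For reference, each parametrized case above can be selected by its attention_configs key when running the test file. A minimal sketch, assuming pytest and the repository's test dependencies are installed:

import pytest

# Run a single parametrized case; pytest derives the test ID from the
# attention_configs key (e.g. "gqa_fix_head_size"), and -k filters on that ID.
pytest.main(["test/test_attention.py", "-k", "gqa_fix_head_size"])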
