Commit 3b800f3 (1 parent: 9fa4b94). Showing 18 changed files with 2,234 additions and 1 deletion.
@@ -0,0 +1,174 @@
from __future__ import annotations

import pytest
import torch
from litgpt.model import (
    CausalSelfAttention as LitCausalSelfAttention,
    build_mask_cache,
    build_rope_cache,
)

from whittle.lora.config import LoRAConfig as Config
from whittle.lora.lora_attention import CausalSelfAttention

attention_configs = {
    "mha_fix_head_size_sliding": {
        "config": Config(
            n_embd=64,
            n_head=16,
            n_query_groups=16,
            head_size=64,
            sliding_window_size=256,
            sliding_window_layer_placing="interleaved",
        ),
        "fix_head_size": True,
    },
    "mha_fix_head_size": {
        "config": Config(n_embd=64, n_head=16, n_query_groups=16, head_size=64),
        "fix_head_size": True,
    },
    "gqa_fix_head_size": {
        "config": Config(n_embd=64, n_head=16, n_query_groups=2, head_size=64),
        "fix_head_size": True,
    },
    "mqa_fix_head_size": {
        "config": Config(n_embd=64, n_head=16, n_query_groups=1, head_size=64),
        "fix_head_size": True,
    },
    "mha_flexible_head_size": {
        "config": Config(n_embd=64, n_head=16, n_query_groups=16),
        "fix_head_size": False,
    },
    "gqa_flexible_head_size": {
        "config": Config(n_embd=64, n_head=16, n_query_groups=2),
        "fix_head_size": False,
    },
    "mqa_flexible_head_size": {
        "config": Config(n_embd=64, n_head=16, n_query_groups=1),
        "fix_head_size": False,
    },
}


def init_attention(config):
    # initialize the whittle LoRA attention with fixed random weights
    attention = CausalSelfAttention(config, 2)
    torch.manual_seed(0)
    attention.attn.linear.linear.weight.data = torch.randn_like(
        attention.attn.linear.linear.weight.data
    )
    attention.attn.linear.linear.bias.data = torch.randn_like(
        attention.attn.linear.linear.bias.data
    )
    attention.proj.linear.bias.data = torch.randn_like(attention.proj.linear.bias.data)
    attention.proj.linear.weight.data = torch.randn_like(
        attention.proj.linear.weight.data
    )
    return attention


def init_lit_attention(config):
    # initialize the LitGPT attention with the same seed, so weights match init_attention
    attention = LitCausalSelfAttention(config, 2)
    torch.manual_seed(0)
    attention.attn.weight.data = torch.randn_like(attention.attn.weight.data)
    attention.attn.bias.data = torch.randn_like(attention.attn.bias.data)
    attention.proj.bias.data = torch.randn_like(attention.proj.bias.data)
    attention.proj.weight.data = torch.randn_like(attention.proj.weight.data)
    return attention


def init_lit_small_attention(config, base_attention, attention_super):
    # build a LitGPT attention of sub-network size by slicing the super-network weights
    attention = LitCausalSelfAttention(config, 2)
    torch.manual_seed(0)
    qkv_indices = (
        attention_super.qkv_indices
        if attention_super.qkv_indices is not None
        else slice(0, attention.attn.weight.data.size()[0])
    )
    attention.attn.weight.data = base_attention.attn.weight.data[qkv_indices, :][
        :, 0 : attention.attn.weight.data.size()[1]
    ]
    attention.attn.bias.data = base_attention.attn.bias.data[qkv_indices]
    proj_indices = (
        attention_super.proj_indices
        if attention_super.proj_indices is not None
        else slice(0, attention.proj.weight.data.size()[-1])
    )
    slices = tuple(slice(0, s) for s in attention.proj.bias.data.size())
    attention.proj.bias.data = base_attention.proj.bias.data[slices]

    attention.proj.weight.data = base_attention.proj.weight.data[
        0 : attention.proj.weight.data.size()[0], :
    ][:, proj_indices]
    return attention


@pytest.mark.parametrize("attention_config", attention_configs.keys())
def test_attention(attention_config):
    config = attention_configs[attention_config]["config"]
    config.fix_head_size = attention_configs[attention_config]["fix_head_size"]
    if not config.fix_head_size:
        config.head_size = 32
    config.max_seq_len = 512
    config.rope_n_elem = int(config.rotary_percentage * config.head_size)

    seq_len = config.max_seq_len
    cos, sin = build_rope_cache(seq_len, n_elem=config.rope_n_elem)
    cos = cos[:seq_len]
    sin = sin[:seq_len]
    input = torch.rand(8, seq_len, config.n_embd)
    mask = build_mask_cache(seq_len)

    attention = init_attention(config)
    out_large = attention(input, mask=mask, cos=cos, sin=sin)

    # check shape of super-network attention
    assert out_large.shape == (8, seq_len, config.n_embd)
    lit_attention = init_lit_attention(config)
    out_lit_large = lit_attention(input, mask=mask, cos=cos, sin=sin)
    if not config.fix_head_size:
        sub_network_head_size = config.head_size // 2
    else:
        sub_network_head_size = config.head_size
    # choose sub-network sizes for the MQA, MHA, and GQA cases, respectively
    if config.n_query_groups == 1:
        sub_network_query_groups = 1
        sub_network_n_head = config.n_head // 4
    elif config.n_query_groups == config.n_head:
        sub_network_n_head = config.n_head // 4
        sub_network_query_groups = sub_network_n_head
    else:
        sub_network_query_groups = config.n_query_groups // 2
        sub_network_n_head = config.n_head // 2
    attention.set_sub_network(
        sub_network_n_embd=config.n_embd // 2,
        sub_network_n_head=sub_network_n_head,
        sub_network_query_groups=sub_network_query_groups,
        sub_network_head_size=sub_network_head_size,
    )
    cos, sin = build_rope_cache(
        seq_len, n_elem=int(config.rotary_percentage * sub_network_head_size)
    )
    out_small = attention(input[:, :, : config.n_embd // 2], mask=mask, cos=cos, sin=sin)

    # check shape of sub-network attention
    assert out_small.shape == (8, seq_len, config.n_embd // 2)

    # check that our custom model produces the same output as LitGPT
    assert torch.all(out_lit_large == out_large)
    config.n_embd = attention.sub_network_n_embd
    if config.n_query_groups == config.n_head:
        config.n_head = attention.sub_network_n_head
    else:
        config.n_head = (
            attention.sub_network_n_head // config.n_query_groups
        ) * attention.sub_network_query_groups
    config.n_query_groups = attention.sub_network_query_groups
    config.head_size = attention.sub_network_head_size
    config.rope_n_elem = int(config.rotary_percentage * config.head_size)
    lit_attention_small = init_lit_small_attention(config, lit_attention, attention)

    out_lit_small = lit_attention_small(
        input[:, :, : config.n_embd], mask=mask, cos=cos, sin=sin
    )
    # check that our sub-network produces the same output as an equally sized LitGPT attention layer
    assert torch.all(out_lit_small == out_small)
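To make the equivalence check at the end of test_attention easier to follow, here is a small standalone sketch of how a chosen sub-network maps onto the dimensions of the equally sized LitGPT attention layer. The helper name litgpt_n_head_for_subnetwork is illustrative and not part of this commit; the arithmetic simply mirrors the assignments to config.n_head and config.n_query_groups in the test above.

def litgpt_n_head_for_subnetwork(n_head, n_query_groups, sub_n_head, sub_query_groups):
    # Illustrative only: head count of the LitGPT layer that matches the sub-network.
    if n_query_groups == n_head:
        # MHA: every head has its own KV group, so the head count shrinks directly
        return sub_n_head
    # GQA/MQA: heads per group stay fixed, only the number of groups is reduced
    return (sub_n_head // n_query_groups) * sub_query_groups

# Example for "gqa_fix_head_size" (n_head=16, n_query_groups=2): the test selects
# sub_network_n_head=8 and sub_network_query_groups=1, which corresponds to a
# LitGPT attention with (8 // 2) * 1 = 4 heads.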
@@ -0,0 +1,122 @@
from __future__ import annotations

import torch
from litgpt import Config as LitConfig
from litgpt.model import (
    Block as LitBlock,
    build_mask_cache,
    build_rope_cache,
)

from whittle.lora.config import LoRAConfig as Config
from whittle.lora.lora_block import LoRABlock as Block


def test_block():
    config = Config()
    config.n_embd = 64
    config.n_head = 8
    config.n_query_groups = 4
    config.head_size = 8
    config.intermediate_size = 64 * 4
    config.fix_head_size = False
    config.mlp_class_name = "LLaMAMLP"
    config.max_seq_len = 512
    config.rotary_percentage = 0.25
    config.rope_n_elem = int(config.rotary_percentage * config.head_size)
    cos, sin = build_rope_cache(config.max_seq_len, n_elem=config.rope_n_elem)
    litconfig = LitConfig()
    litconfig.n_embd = 64
    litconfig.n_head = 8
    litconfig.n_query_groups = 4
    litconfig.head_size = 8
    litconfig.intermediate_size = 64 * 4
    litconfig.fix_head_size = False
    litconfig.mlp_class_name = "LLaMAMLP"
    litconfig.max_seq_len = 512
    litconfig.rotary_percentage = 0.25
    litconfig.rope_n_elem = int(litconfig.rotary_percentage * litconfig.head_size)
    block = Block(config, 0)
    input = torch.rand(8, 512, 64)
    mask = build_mask_cache(512)
    # set all weights and biases to ones so both implementations are directly comparable
    block.attn.attn.linear.linear.weight.data = torch.ones_like(
        block.attn.attn.linear.linear.weight.data
    )
    block.attn.attn.linear.linear.bias.data = torch.ones_like(
        block.attn.attn.linear.linear.bias.data
    )
    block.attn.proj.linear.bias.data = torch.ones_like(block.attn.proj.linear.bias.data)
    block.attn.proj.linear.weight.data = torch.ones_like(
        block.attn.proj.linear.weight.data
    )
    block.mlp.fc_1.linear.weight.data = torch.ones_like(block.mlp.fc_1.linear.weight.data)
    block.mlp.fc_1.linear.bias.data = torch.ones_like(block.mlp.fc_1.linear.bias.data)
    block.mlp.fc_2.linear.weight.data = torch.ones_like(block.mlp.fc_2.linear.weight.data)
    block.mlp.fc_2.linear.bias.data = torch.ones_like(block.mlp.fc_2.linear.bias.data)
    block.mlp.proj.linear.weight.data = torch.ones_like(block.mlp.proj.linear.weight.data)
    block.mlp.proj.linear.bias.data = torch.ones_like(block.mlp.proj.linear.bias.data)
    block.reset_super_network()
    out_large = block(input, cos, sin, mask)
    assert out_large.shape == (8, 512, 64)
    block.set_sub_network(
        sub_network_n_embd=32,
        sub_network_intermediate_size=32 * 4,
        sub_network_num_heads=8,
        sub_network_query_groups=config.n_query_groups // 2,
        sub_network_head_size=32 // 4,
    )
    out_small = block(input[:, :, :32], cos, sin, mask)
    assert out_small.shape == (8, 512, 32)

    lit_block = LitBlock(litconfig, 0)
    lit_block.attn.attn.weight.data = torch.ones_like(lit_block.attn.attn.weight.data)
    lit_block.attn.attn.bias.data = torch.ones_like(lit_block.attn.attn.bias.data)
    lit_block.attn.proj.bias.data = torch.ones_like(lit_block.attn.proj.bias.data)
    lit_block.attn.proj.weight.data = torch.ones_like(lit_block.attn.proj.weight.data)
    lit_block.mlp.fc_1.weight.data = torch.ones_like(lit_block.mlp.fc_1.weight.data)
    lit_block.mlp.fc_1.bias.data = torch.ones_like(lit_block.mlp.fc_1.bias.data)
    lit_block.mlp.fc_2.weight.data = torch.ones_like(lit_block.mlp.fc_2.weight.data)
    lit_block.mlp.fc_2.bias.data = torch.ones_like(lit_block.mlp.fc_2.bias.data)
    lit_block.mlp.proj.weight.data = torch.ones_like(lit_block.mlp.proj.weight.data)
    lit_block.mlp.proj.bias.data = torch.ones_like(lit_block.mlp.proj.bias.data)
    out_lit_large = lit_block(input, cos, sin, mask)
    assert torch.all(out_lit_large == out_large)

    # LitGPT block with the dimensions of the sub-network selected above
    litconfig.n_embd = 32
    litconfig.n_head = 4
    litconfig.n_query_groups = 2
    litconfig.intermediate_size = 32 * 4
    lit_block_small = LitBlock(litconfig, 0)
    lit_block_small.attn.attn.weight.data = torch.ones_like(
        lit_block_small.attn.attn.weight.data
    )
    lit_block_small.attn.attn.bias.data = torch.ones_like(
        lit_block_small.attn.attn.bias.data
    )
    lit_block_small.attn.proj.bias.data = torch.ones_like(
        lit_block_small.attn.proj.bias.data
    )
    lit_block_small.attn.proj.weight.data = torch.ones_like(
        lit_block_small.attn.proj.weight.data
    )
    lit_block_small.mlp.fc_1.weight.data = torch.ones_like(
        lit_block_small.mlp.fc_1.weight.data
    )
    lit_block_small.mlp.fc_1.bias.data = torch.ones_like(
        lit_block_small.mlp.fc_1.bias.data
    )
    lit_block_small.mlp.fc_2.weight.data = torch.ones_like(
        lit_block_small.mlp.fc_2.weight.data
    )
    lit_block_small.mlp.fc_2.bias.data = torch.ones_like(
        lit_block_small.mlp.fc_2.bias.data
    )
    lit_block_small.mlp.proj.weight.data = torch.ones_like(
        lit_block_small.mlp.proj.weight.data
    )
    lit_block_small.mlp.proj.bias.data = torch.ones_like(
        lit_block_small.mlp.proj.bias.data
    )
    out_lit_small = lit_block_small(input[:, :, :32], cos, sin, mask)
    assert torch.all(out_lit_small == out_small)
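For reference, the hard-coded dimensions of lit_block_small follow from the set_sub_network call above via the same head mapping used in the attention test. A brief sketch of that arithmetic, with illustrative intermediate variable names that are not part of the commit:

# Super-network block: n_embd=64, n_head=8, n_query_groups=4, head_size=8
sub_network_num_heads = 8
sub_network_query_groups = 4 // 2              # config.n_query_groups // 2 -> 2
heads_per_group = sub_network_num_heads // 4   # 8 heads over 4 groups -> 2
lit_small_n_head = heads_per_group * sub_network_query_groups  # 2 * 2 = 4
assert lit_small_n_head == 4  # matches litconfig.n_head used for lit_block_small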
@@ -0,0 +1,38 @@
from __future__ import annotations

import torch

from whittle.lora.lora_embedding import LoRAEmbedding as Embedding


def test_embedding():
    input_features = torch.randint(low=1, high=64, size=(4, 8))
    emb = Embedding(64, 32)

    out = emb(input_features)
    assert out.shape == (4, 8, 32)
    emb.set_sub_network(16)
    out = emb(input_features)
    assert out.shape == (4, 8, 16)
    emb.set_sub_network(32)
    out = emb(input_features)
    assert out.shape == (4, 8, 32)

    emb.embedding.weight.data = torch.randn_like(emb.embedding.weight.data)
    emb.set_sub_network(16)
    out_small = emb(input_features)
    emb.set_sub_network(32)
    out_large = emb(input_features)

    small_layer = torch.nn.Embedding(64, 16)

    small_layer.weight.data = emb.embedding.weight.data[:, :16]

    out_small_layer = small_layer(input_features)

    large_layer = torch.nn.Embedding(64, 32)
    large_layer.weight.data = emb.embedding.weight.data[:, :32]
    out_large_layer = large_layer(input_features)

    assert torch.all(out_small == out_small_layer)
    assert torch.all(out_large == out_large_layer)
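The two final assertions verify that selecting a sub-network embedding dimension k behaves exactly like a plain torch.nn.Embedding whose weight keeps only the first k columns. A minimal sketch of that reference construction, using an illustrative helper name not taken from the commit:

def reference_embedding(weight, k):
    # Plain nn.Embedding equivalent to LoRAEmbedding.set_sub_network(k):
    # same vocabulary rows, only the first k embedding columns are kept.
    ref = torch.nn.Embedding(weight.size(0), k)
    ref.weight.data = weight[:, :k]
    return ref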