feat: add lora supernet (#245)
rheasukthanker authored Feb 12, 2025
1 parent 9fa4b94 commit 3b800f3
Showing 18 changed files with 2,234 additions and 1 deletion.
174 changes: 174 additions & 0 deletions test/test_lora_attention.py
@@ -0,0 +1,174 @@
from __future__ import annotations

import pytest
import torch
from litgpt.model import (
CausalSelfAttention as LitCausalSelfAttention,
build_mask_cache,
build_rope_cache,
)

from whittle.lora.config import LoRAConfig as Config
from whittle.lora.lora_attention import CausalSelfAttention

attention_configs = {
"mha_fix_head_size_sliding": {
"config": Config(
n_embd=64,
n_head=16,
n_query_groups=16,
head_size=64,
sliding_window_size=256,
sliding_window_layer_placing="interleaved",
),
"fix_head_size": True,
},
"mha_fix_head_size": {
"config": Config(n_embd=64, n_head=16, n_query_groups=16, head_size=64),
"fix_head_size": True,
},
"gqa_fix_head_size": {
"config": Config(n_embd=64, n_head=16, n_query_groups=2, head_size=64),
"fix_head_size": True,
},
"mqa_fix_head_size": {
"config": Config(n_embd=64, n_head=16, n_query_groups=1, head_size=64),
"fix_head_size": True,
},
"mha_flexible_head_size": {
"config": Config(n_embd=64, n_head=16, n_query_groups=16),
"fix_head_size": False,
},
"gqa_flexible_head_size": {
"config": Config(n_embd=64, n_head=16, n_query_groups=2),
"fix_head_size": False,
},
"mqa_flexible_head_size": {
"config": Config(n_embd=64, n_head=16, n_query_groups=1),
"fix_head_size": False,
},
}


def init_attention(config):
attention = CausalSelfAttention(config, 2)
torch.manual_seed(0)
attention.attn.linear.linear.weight.data = torch.randn_like(
attention.attn.linear.linear.weight.data
)
attention.attn.linear.linear.bias.data = torch.randn_like(
attention.attn.linear.linear.bias.data
)
attention.proj.linear.bias.data = torch.randn_like(attention.proj.linear.bias.data)
attention.proj.linear.weight.data = torch.randn_like(
attention.proj.linear.weight.data
)
return attention


def init_lit_attention(config):
attention = LitCausalSelfAttention(config, 2)
torch.manual_seed(0)
attention.attn.weight.data = torch.randn_like(attention.attn.weight.data)
attention.attn.bias.data = torch.randn_like(attention.attn.bias.data)
attention.proj.bias.data = torch.randn_like(attention.proj.bias.data)
attention.proj.weight.data = torch.randn_like(attention.proj.weight.data)
return attention


def init_lit_small_attention(config, base_attention, attention_super):
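    # Build a smaller LitGPT attention whose weights are sliced out of the full-size
    # LitGPT attention, using the super-network's qkv_indices / proj_indices when they are set.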
attention = LitCausalSelfAttention(config, 2)
torch.manual_seed(0)
slices = tuple(slice(0, s) for s in attention.attn.weight.data.size())[1]
qkv_indices = (
attention_super.qkv_indices
if attention_super.qkv_indices is not None
else slice(0, attention.attn.weight.data.size()[0])
)
attention.attn.weight.data = base_attention.attn.weight.data[qkv_indices, :][
:, 0 : attention.attn.weight.data.size()[1]
]
attention.attn.bias.data = base_attention.attn.bias.data[qkv_indices]
proj_indices = (
attention_super.proj_indices
if attention_super.proj_indices is not None
else slice(0, attention.proj.weight.data.size()[-1])
)
slices = tuple(slice(0, s) for s in attention.proj.bias.data.size())
attention.proj.bias.data = base_attention.proj.bias.data[slices]

attention.proj.weight.data = base_attention.proj.weight.data[
0 : attention.proj.weight.data.size()[0], :
][:, proj_indices]
return attention


@pytest.mark.parametrize("attention_config", attention_configs.keys())
def test_attention(attention_config):
config = attention_configs[attention_config]["config"]
config.fix_head_size = attention_configs[attention_config]["fix_head_size"]
if not config.fix_head_size:
config.head_size = 32
config.max_seq_len = 512
config.rope_n_elem = int(config.rotary_percentage * config.head_size)

seq_len = config.max_seq_len
cos, sin = build_rope_cache(seq_len, n_elem=config.rope_n_elem)
cos = cos[:seq_len]
sin = sin[:seq_len]
input = torch.rand(8, seq_len, config.n_embd)
mask = build_mask_cache(seq_len)

attention = init_attention(config)
out_large = attention(input, mask=mask, cos=cos, sin=sin)

# check shape of super network attention
assert out_large.shape == (8, seq_len, config.n_embd)
lit_attention = init_lit_attention(config)
out_lit_large = lit_attention(input, mask=mask, cos=cos, sin=sin)
if not config.fix_head_size:
sub_network_head_size = config.head_size // 2
else:
sub_network_head_size = config.head_size
if config.n_query_groups == 1:
sub_network_query_groups = 1
sub_network_n_head = config.n_head // 4
elif config.n_query_groups == config.n_head:
sub_network_n_head = config.n_head // 4
sub_network_query_groups = sub_network_n_head
else:
sub_network_query_groups = config.n_query_groups // 2
sub_network_n_head = config.n_head // 2
attention.set_sub_network(
sub_network_n_embd=config.n_embd // 2,
sub_network_n_head=sub_network_n_head,
sub_network_query_groups=sub_network_query_groups,
sub_network_head_size=sub_network_head_size,
)
cos, sin = build_rope_cache(
seq_len, n_elem=int(config.rotary_percentage * sub_network_head_size)
)
out_small = attention(input[:, :, : config.n_embd // 2], mask=mask, cos=cos, sin=sin)

# check shape of sub-network attention
assert out_small.shape == (8, seq_len, config.n_embd // 2)

# check that our custom model produces the same output as LitGPT
assert torch.all(out_lit_large == out_large)
config.n_embd = attention.sub_network_n_embd
if config.n_query_groups == config.n_head:
config.n_head = attention.sub_network_n_head
else:
config.n_head = (
attention.sub_network_n_head // config.n_query_groups
) * attention.sub_network_query_groups
config.n_query_groups = attention.sub_network_query_groups
config.head_size = attention.sub_network_head_size
config.rope_n_elem = int(config.rotary_percentage * config.head_size)
lit_attention_small = init_lit_small_attention(config, lit_attention, attention)

out_lit_small = lit_attention_small(
input[:, :, : config.n_embd], mask=mask, cos=cos, sin=sin
)
    # check that our sub-network produces the same output as an equally sized LitGPT attention layer
assert torch.all(out_lit_small == out_small)
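
For reference, a minimal usage sketch of the sub-network API this test exercises, mirroring the "mha_fix_head_size" case above (a sketch, not part of the commit):

import torch
from litgpt.model import build_mask_cache, build_rope_cache

from whittle.lora.config import LoRAConfig
from whittle.lora.lora_attention import CausalSelfAttention

# Same configuration as the "mha_fix_head_size" entry above.
config = LoRAConfig(n_embd=64, n_head=16, n_query_groups=16, head_size=64)
config.fix_head_size = True
config.max_seq_len = 512
config.rope_n_elem = int(config.rotary_percentage * config.head_size)

seq_len = config.max_seq_len
cos, sin = build_rope_cache(seq_len, n_elem=config.rope_n_elem)
mask = build_mask_cache(seq_len)
x = torch.rand(8, seq_len, config.n_embd)

attention = CausalSelfAttention(config, 2)  # block index 2, as in the test
out_large = attention(x, mask=mask, cos=cos, sin=sin)  # shape (8, 512, 64)

# Activate a sub-network; subsequent forward passes use only the selected slice of the weights.
attention.set_sub_network(
    sub_network_n_embd=config.n_embd // 2,
    sub_network_n_head=config.n_head // 4,
    sub_network_query_groups=config.n_head // 4,
    sub_network_head_size=config.head_size,
)
out_small = attention(x[:, :, : config.n_embd // 2], mask=mask, cos=cos, sin=sin)  # shape (8, 512, 32)
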
122 changes: 122 additions & 0 deletions test/test_lora_block.py
@@ -0,0 +1,122 @@
from __future__ import annotations

import torch
from litgpt import Config as LitConfig
from litgpt.model import (
Block as LitBlock,
build_mask_cache,
build_rope_cache,
)

from whittle.lora.config import LoRAConfig as Config
from whittle.lora.lora_block import LoRABlock as Block


def test_block():
config = Config()
config.n_embd = 64
config.n_head = 8
config.n_query_groups = 4
config.head_size = 8
config.intermediate_size = 64 * 4
config.fix_head_size = False
config.mlp_class_name = "LLaMAMLP"
config.max_seq_len = 512
config.rotary_percentage = 0.25
config.rope_n_elem = int(config.rotary_percentage * config.head_size)
cos, sin = build_rope_cache(config.max_seq_len, n_elem=config.rope_n_elem)
litconfig = LitConfig()
litconfig.n_embd = 64
litconfig.n_head = 8
litconfig.n_query_groups = 4
litconfig.head_size = 8
litconfig.intermediate_size = 64 * 4
litconfig.fix_head_size = False
litconfig.mlp_class_name = "LLaMAMLP"
litconfig.max_seq_len = 512
litconfig.rotary_percentage = 0.25
litconfig.rope_n_elem = int(litconfig.rotary_percentage * litconfig.head_size)
block = Block(config, 0)
input = torch.rand(8, 512, 64)
mask = build_mask_cache(512)
block.attn.attn.linear.linear.weight.data = torch.ones_like(
block.attn.attn.linear.linear.weight.data
)
block.attn.attn.linear.linear.bias.data = torch.ones_like(
block.attn.attn.linear.linear.bias.data
)
block.attn.proj.linear.bias.data = torch.ones_like(block.attn.proj.linear.bias.data)
block.attn.proj.linear.weight.data = torch.ones_like(
block.attn.proj.linear.weight.data
)
block.mlp.fc_1.linear.weight.data = torch.ones_like(block.mlp.fc_1.linear.weight.data)
block.mlp.fc_1.linear.bias.data = torch.ones_like(block.mlp.fc_1.linear.bias.data)
block.mlp.fc_2.linear.weight.data = torch.ones_like(block.mlp.fc_2.linear.weight.data)
block.mlp.fc_2.linear.bias.data = torch.ones_like(block.mlp.fc_2.linear.bias.data)
block.mlp.proj.linear.weight.data = torch.ones_like(block.mlp.proj.linear.weight.data)
block.mlp.proj.linear.bias.data = torch.ones_like(block.mlp.proj.linear.bias.data)
block.reset_super_network()
out_large = block(input, cos, sin, mask)
assert out_large.shape == (8, 512, 64)
block.set_sub_network(
sub_network_n_embd=32,
sub_network_intermediate_size=32 * 4,
sub_network_num_heads=8,
sub_network_query_groups=config.n_query_groups // 2,
sub_network_head_size=32 // 4,
)
out_small = block(input[:, :, :32], cos, sin, mask)
assert out_small.shape == (8, 512, 32)

lit_block = LitBlock(litconfig, 0)
print(lit_block)
lit_block.attn.attn.weight.data = torch.ones_like(lit_block.attn.attn.weight.data)
lit_block.attn.attn.bias.data = torch.ones_like(lit_block.attn.attn.bias.data)
lit_block.attn.proj.bias.data = torch.ones_like(lit_block.attn.proj.bias.data)
lit_block.attn.proj.weight.data = torch.ones_like(lit_block.attn.proj.weight.data)
lit_block.mlp.fc_1.weight.data = torch.ones_like(lit_block.mlp.fc_1.weight.data)
lit_block.mlp.fc_1.bias.data = torch.ones_like(lit_block.mlp.fc_1.bias.data)
lit_block.mlp.fc_2.weight.data = torch.ones_like(lit_block.mlp.fc_2.weight.data)
lit_block.mlp.fc_2.bias.data = torch.ones_like(lit_block.mlp.fc_2.bias.data)
lit_block.mlp.proj.weight.data = torch.ones_like(lit_block.mlp.proj.weight.data)
lit_block.mlp.proj.bias.data = torch.ones_like(lit_block.mlp.proj.bias.data)
out_lit_large = lit_block(input, cos, sin, mask)
assert torch.all(out_lit_large == out_large)

litconfig.n_embd = 32
litconfig.n_head = 4
litconfig.n_query_groups = 2
litconfig.intermediate_size = 32 * 4
lit_block_small = LitBlock(litconfig, 0)
lit_block_small.attn.attn.weight.data = torch.ones_like(
lit_block_small.attn.attn.weight.data
)
lit_block_small.attn.attn.bias.data = torch.ones_like(
lit_block_small.attn.attn.bias.data
)
lit_block_small.attn.proj.bias.data = torch.ones_like(
lit_block_small.attn.proj.bias.data
)
lit_block_small.attn.proj.weight.data = torch.ones_like(
lit_block_small.attn.proj.weight.data
)
lit_block_small.mlp.fc_1.weight.data = torch.ones_like(
lit_block_small.mlp.fc_1.weight.data
)
lit_block_small.mlp.fc_1.bias.data = torch.ones_like(
lit_block_small.mlp.fc_1.bias.data
)
lit_block_small.mlp.fc_2.weight.data = torch.ones_like(
lit_block_small.mlp.fc_2.weight.data
)
lit_block_small.mlp.fc_2.bias.data = torch.ones_like(
lit_block_small.mlp.fc_2.bias.data
)
lit_block_small.mlp.proj.weight.data = torch.ones_like(
lit_block_small.mlp.proj.weight.data
)
lit_block_small.mlp.proj.bias.data = torch.ones_like(
lit_block_small.mlp.proj.bias.data
)
out_lit_small = lit_block_small(input[:, :, :32], cos, sin, mask)
assert torch.all(out_lit_small == out_small)
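
Likewise, a minimal usage sketch of the block-level API exercised above; note that LoRABlock.set_sub_network takes sub_network_num_heads and sub_network_intermediate_size rather than the attention layer's keyword names (a sketch mirroring the test, not part of the commit):

import torch
from litgpt.model import build_mask_cache, build_rope_cache

from whittle.lora.config import LoRAConfig
from whittle.lora.lora_block import LoRABlock

# Same configuration as in test_block above.
config = LoRAConfig()
config.n_embd = 64
config.n_head = 8
config.n_query_groups = 4
config.head_size = 8
config.intermediate_size = 64 * 4
config.fix_head_size = False
config.mlp_class_name = "LLaMAMLP"
config.max_seq_len = 512
config.rotary_percentage = 0.25
config.rope_n_elem = int(config.rotary_percentage * config.head_size)

cos, sin = build_rope_cache(config.max_seq_len, n_elem=config.rope_n_elem)
mask = build_mask_cache(512)
x = torch.rand(8, 512, 64)

block = LoRABlock(config, 0)
block.reset_super_network()
out_large = block(x, cos, sin, mask)  # shape (8, 512, 64)

block.set_sub_network(
    sub_network_n_embd=32,
    sub_network_intermediate_size=32 * 4,
    sub_network_num_heads=8,
    sub_network_query_groups=config.n_query_groups // 2,
    sub_network_head_size=32 // 4,
)
out_small = block(x[:, :, :32], cos, sin, mask)  # shape (8, 512, 32)
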
38 changes: 38 additions & 0 deletions test/test_lora_embedding.py
@@ -0,0 +1,38 @@
from __future__ import annotations

import torch

from whittle.lora.lora_embedding import LoRAEmbedding as Embedding


def test_embedding():
input_features = torch.randint(low=1, high=64, size=(4, 8))
emb = Embedding(64, 32)

out = emb(input_features)
assert out.shape == (4, 8, 32)
emb.set_sub_network(16)
out = emb(input_features)
assert out.shape == (4, 8, 16)
emb.set_sub_network(32)
out = emb(input_features)
assert out.shape == (4, 8, 32)

emb.embedding.weight.data = torch.randn_like(emb.embedding.weight.data)
emb.set_sub_network(16)
out_small = emb(input_features)
emb.set_sub_network(32)
out_large = emb(input_features)

small_layer = torch.nn.Embedding(64, 16)

small_layer.weight.data = emb.embedding.weight.data[:, :16]

out_small_layer = small_layer(input_features)

large_layer = torch.nn.Embedding(64, 32)
large_layer.weight.data = emb.embedding.weight.data[:, :32]
out_large_layer = large_layer(input_features)

assert torch.all(out_small == out_small_layer)
assert torch.all(out_large == out_large_layer)
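
Finally, a minimal sketch of the embedding behaviour checked above: activating a sub-network truncates the embedding dimension, so the output matches a plain torch.nn.Embedding built from the leading columns of the super-network weight (a sketch, not part of the commit):

import torch

from whittle.lora.lora_embedding import LoRAEmbedding

tokens = torch.randint(low=1, high=64, size=(4, 8))

emb = LoRAEmbedding(64, 32)  # vocabulary size 64, super-network embedding dim 32
emb.set_sub_network(16)      # activate a 16-dimensional sub-network
out_small = emb(tokens)      # shape (4, 8, 16)

# The sub-network output equals a plain embedding that uses the first 16 columns.
reference = torch.nn.Embedding(64, 16)
reference.weight.data = emb.embedding.weight.data[:, :16]
assert torch.all(out_small == reference(tokens))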