diff --git a/test/test_api.py b/test/test_api.py index f462384..4cc5fed 100644 --- a/test/test_api.py +++ b/test/test_api.py @@ -28,35 +28,47 @@ def copy_subnetwork_weights(sub_network, super_network): """ # Ensure both networks are in evaluation mode to avoid any changes in weights during copying. - sub_network.eval() - super_network.eval() - - # Get the state dictionaries of both networks - sub_state_dict = sub_network.state_dict() - super_state_dict = super_network.state_dict() - - # Iterate over the subnetwork's state dictionary and copy weights to the supernetwork - for layer_name, sub_weight in sub_state_dict.items(): - if layer_name in super_state_dict: - super_weight = super_state_dict[layer_name] - - # Ensure the subnetwork's weight can fit into the supernetwork's weight tensor - if sub_weight.size() == super_weight.size(): - super_state_dict[layer_name] = sub_weight - else: - # Copy the sub_weight values into the corresponding part of super_weight - if len(sub_weight.shape) == 1: - super_weight[0 : sub_weight.shape[0]] = sub_weight - elif len(sub_weight.shape) == 2: - super_weight[0 : sub_weight.shape[0], 0 : sub_weight.shape[1]] = ( - sub_weight - ) - super_state_dict[layer_name] = super_weight - else: - raise KeyError(f"Layer {layer_name} not found in super-network.") - - # Load the modified state dictionary back into the supernetwork - super_network.load_state_dict(super_state_dict) + embd = super_network.sub_network_n_embd + intermediate = super_network.sub_network_intermediate_size + super_network.lm_head.weight.data[:, :embd] = sub_network.lm_head.weight.data + if super_network.lm_head.bias is not None: + super_network.lm_head.bias.data = sub_network.lm_head.bias.data + super_network.transformer.wte.weight.data[:, :embd] = ( + sub_network.transformer.wte.weight.data + ) + super_network.transformer.ln_f.weight.data[:embd] = ( + sub_network.transformer.ln_f.weight.data + ) + if super_network.transformer.ln_f.bias is not None: + super_network.transformer.ln_f.bias.data[:embd] = ( + sub_network.transformer.ln_f.bias.data + ) + for i, block_orig in enumerate(sub_network.transformer.h): + block = super_network.transformer.h[i] + block.attn.attn.weight.data[block.attn.qkv_indices, :][:, :embd] = ( + block_orig.attn.attn.weight.data + ) + if block.attn.attn.bias is not None: + block.attn.attn.bias.data[block.attn.qkv_indices] = ( + block_orig.attn.attn.bias.data + ) + block.attn.proj.weight.data[:, block.attn.proj_indices][:embd, :] = ( + block_orig.attn.proj.weight.data + ) + if block.attn.proj.bias is not None: + block.attn.proj.bias.data[:embd] = block_orig.attn.proj.bias.data + block.mlp.fc.weight.data[:intermediate, :embd] = block_orig.mlp.fc.weight.data + if block.mlp.fc.bias is not None: + block.mlp.fc.bias.data[:intermediate] = block_orig.mlp.fc.bias.data + block.mlp.proj.weight.data[:embd, :intermediate] = block_orig.mlp.proj.weight.data + if block.mlp.proj.bias is not None: + block.mlp.proj.bias.data[:embd] = block_orig.mlp.proj.bias.data + block.norm_1.weight.data[:embd] = block_orig.norm_1.weight.data + block.norm_2.weight.data[:embd] = block_orig.norm_2.weight.data + if block.norm_1.bias is not None: + block.norm_1.bias.data[:embd] = block_orig.norm_1.bias.data + if block.norm_2.bias is not None: + block.norm_2.bias.data[:embd] = block_orig.norm_2.bias.data return super_network @@ -344,8 +356,6 @@ def test_compare_litgpt(self, checkpoint_dir, checkpoint_dir_14m, out_dir): gpt_14m = GPT(config_14m).to(device) gpt_14m.name_or_path = "EleutherAI/pythia-14m" 
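# --- Illustrative aside (not part of the patch): the rewritten copy_subnetwork_weights
# above copies each sub-network tensor into the leading slice (or into the rows picked by
# qkv_indices / proj_indices) of the matching super-network tensor. A minimal standalone
# sketch of that slice-copy idea, using plain torch layers with hypothetical sizes rather
# than the whittle/litgpt modules:
import torch
import torch.nn as nn

sub = nn.Linear(16, 32, bias=True)    # stands in for a sub-network layer (hypothetical sizes)
sup = nn.Linear(64, 128, bias=True)   # stands in for the super-network layer (hypothetical sizes)

with torch.no_grad():
    out_f, in_f = sub.weight.shape
    # the top-left block of the super-network weight receives the sub-network weight
    sup.weight[:out_f, :in_f] = sub.weight
    sup.bias[:out_f] = sub.bias
# --- end of illustrative aside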
gpt_14m.load_state_dict(torch.load(str(checkpoint_dir_14m / "lit_model.pth"))) - gpt = copy_subnetwork_weights(gpt_14m, gpt) - gpt.max_seq_length = config_14m.block_size gpt.set_sub_network( sub_network_n_embd=config_14m.n_embd, sub_network_intermediate_size=config_14m.intermediate_size, @@ -354,6 +364,8 @@ def test_compare_litgpt(self, checkpoint_dir, checkpoint_dir_14m, out_dir): sub_network_query_groups=config_14m.n_query_groups, sub_network_head_size=config_14m.head_size, ) + gpt = copy_subnetwork_weights(gpt_14m, gpt) + gpt.max_seq_length = config_14m.block_size convert_and_evaluate( model=gpt, out_dir=out_dir, diff --git a/test/test_attention.py b/test/test_attention.py index 11a84d9..2d7469a 100644 --- a/test/test_attention.py +++ b/test/test_attention.py @@ -70,17 +70,30 @@ def init_lit_attention(config): return attention -def init_lit_small_attention(config, base_attention): +def init_lit_small_attention(config, base_attention, attention_super): attention = LitCausalSelfAttention(config, 2) torch.manual_seed(0) - slices = tuple(slice(0, s) for s in attention.attn.weight.data.size()) - attention.attn.weight.data = base_attention.attn.weight.data[slices] - slices = tuple(slice(0, s) for s in attention.attn.bias.data.size()) - attention.attn.bias.data = base_attention.attn.bias.data[slices] + slices = tuple(slice(0, s) for s in attention.attn.weight.data.size())[1] + qkv_indices = ( + attention_super.qkv_indices + if attention_super.qkv_indices is not None + else slice(0, attention.attn.weight.data.size()[0]) + ) + attention.attn.weight.data = base_attention.attn.weight.data[qkv_indices, :][ + :, 0 : attention.attn.weight.data.size()[1] + ] + attention.attn.bias.data = base_attention.attn.bias.data[qkv_indices] + proj_indices = ( + attention_super.proj_indices + if attention_super.proj_indices is not None + else slice(0, attention.proj.weight.data.size()[-1]) + ) slices = tuple(slice(0, s) for s in attention.proj.bias.data.size()) attention.proj.bias.data = base_attention.proj.bias.data[slices] - slices = tuple(slice(0, s) for s in attention.proj.weight.data.size()) - attention.proj.weight.data = base_attention.proj.weight.data[slices] + + attention.proj.weight.data = base_attention.proj.weight.data[ + 0 : attention.proj.weight.data.size()[0], : + ][:, proj_indices] return attention @@ -89,7 +102,7 @@ def test_attention(attention_config): config = attention_configs[attention_config]["config"] config.fix_head_size = attention_configs[attention_config]["fix_head_size"] if not config.fix_head_size: - config.head_size = config.n_embd // config.n_head + config.head_size = 32 config.max_seq_len = 512 config.rope_n_elem = int(config.rotary_percentage * config.head_size) @@ -108,20 +121,21 @@ def test_attention(attention_config): lit_attention = init_lit_attention(config) out_lit_large = lit_attention(input, mask=mask, cos=cos, sin=sin) if not config.fix_head_size: - sub_network_head_size = config.n_embd // (2 * config.n_head // 4) + sub_network_head_size = config.head_size // 2 else: sub_network_head_size = config.head_size if config.n_query_groups == 1: sub_network_query_groups = 1 - elif (config.n_head // 4) % config.n_query_groups == 0: - sub_network_query_groups = config.n_query_groups + sub_network_n_head = config.n_head // 4 + elif config.n_query_groups == config.n_head: + sub_network_n_head = config.n_head // 4 + sub_network_query_groups = sub_network_n_head else: - sub_network_query_groups = (config.n_head // 4) // ( - config.n_head // config.n_query_groups - ) + 
sub_network_query_groups = config.n_query_groups // 2 + sub_network_n_head = config.n_head // 2 attention.set_sub_network( sub_network_n_embd=config.n_embd // 2, - sub_network_n_head=config.n_head // 4, + sub_network_n_head=sub_network_n_head, sub_network_query_groups=sub_network_query_groups, sub_network_head_size=sub_network_head_size, ) @@ -135,18 +149,31 @@ def test_attention(attention_config): # check that our custom model produces the same output as LitGPT assert torch.all(out_lit_large == out_large) - config.n_embd = attention.sub_network_n_embd - config.n_head = attention.sub_network_n_head - config.n_query_groups = sub_network_query_groups - config.head_size = sub_network_head_size + if config.n_query_groups == config.n_head: + config.n_head = attention.sub_network_n_head + else: + config.n_head = ( + attention.sub_network_n_head // config.n_query_groups + ) * attention.sub_network_query_groups + config.n_query_groups = attention.sub_network_query_groups + config.head_size = attention.sub_network_head_size config.rope_n_elem = int(config.rotary_percentage * config.head_size) - - lit_attention_small = init_lit_small_attention(config, lit_attention) + lit_attention_small = init_lit_small_attention(config, lit_attention, attention) + if attention.qkv_indices is not None: + print(lit_attention_small.attn.weight.data.size()) + print( + attention.attn.weight.data[attention.qkv_indices, :][ + :, 0 : config.n_embd + ].size() + ) + assert torch.all( + lit_attention_small.attn.weight.data + == attention.attn.weight.data[attention.qkv_indices, :][:, 0 : config.n_embd] + ) out_lit_small = lit_attention_small( input[:, :, : config.n_embd], mask=mask, cos=cos, sin=sin ) - # check that our sub-networks the same output as equally sized LitGPT attention layer assert torch.all(out_lit_small == out_small) diff --git a/test/test_checkpoint_loading.py b/test/test_checkpoint_loading.py index db5030a..9d62ef9 100644 --- a/test/test_checkpoint_loading.py +++ b/test/test_checkpoint_loading.py @@ -8,7 +8,7 @@ from litgpt.model import GPT as LitGPT from litgpt.scripts.download import download_from_hub -from whittle.models.gpt.model import GPT as whittleGPT +from whittle.models.gpt.model import GPT as WhittleGPT @pytest.fixture(scope="session") @@ -34,17 +34,18 @@ def test_checkpoint_loading(checkpoint_dir): # litgpt download --repo_id stabilityai/stablelm-base-alpha-3b config = Config.from_file(str(checkpoint_dir / "model_config.yaml")) config.fix_head_size = False - model = whittleGPT(config) # .cuda() + model = WhittleGPT(config) # .cuda() model.load_state_dict(torch.load(str(checkpoint_dir / "lit_model.pth"))) # test output model.eval() - sample_intermediate_size = 4 * config.n_embd + sample_intermediate_size = config.intermediate_size model.set_sub_network( config.n_embd, sample_intermediate_size, config.n_head, config.n_layer, + config.n_query_groups, + config.head_size, ) - output_whittle = model(input_ids) assert torch.allclose(output_lit, output_whittle) diff --git a/test/test_lora_attention.py b/test/test_lora_attention.py new file mode 100644 index 0000000..59ac6b7 --- /dev/null +++ b/test/test_lora_attention.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +import pytest +import torch +from litgpt.model import ( + CausalSelfAttention as LitCausalSelfAttention, + build_mask_cache, + build_rope_cache, +) + +from whittle.lora.config import LoRAConfig as Config +from whittle.lora.lora_attention import CausalSelfAttention + +attention_configs = { + "mha_fix_head_size_sliding": { + "config": 
Config( + n_embd=64, + n_head=16, + n_query_groups=16, + head_size=64, + sliding_window_size=256, + sliding_window_layer_placing="interleaved", + ), + "fix_head_size": True, + }, + "mha_fix_head_size": { + "config": Config(n_embd=64, n_head=16, n_query_groups=16, head_size=64), + "fix_head_size": True, + }, + "gqa_fix_head_size": { + "config": Config(n_embd=64, n_head=16, n_query_groups=2, head_size=64), + "fix_head_size": True, + }, + "mqa_fix_head_size": { + "config": Config(n_embd=64, n_head=16, n_query_groups=1, head_size=64), + "fix_head_size": True, + }, + "mha_flexible_head_size": { + "config": Config(n_embd=64, n_head=16, n_query_groups=16), + "fix_head_size": False, + }, + "gqa_flexible_head_size": { + "config": Config(n_embd=64, n_head=16, n_query_groups=2), + "fix_head_size": False, + }, + "mqa_flexible_head_size": { + "config": Config(n_embd=64, n_head=16, n_query_groups=1), + "fix_head_size": False, + }, +} + + +def init_attention(config): + attention = CausalSelfAttention(config, 2) + torch.manual_seed(0) + attention.attn.linear.linear.weight.data = torch.randn_like( + attention.attn.linear.linear.weight.data + ) + attention.attn.linear.linear.bias.data = torch.randn_like( + attention.attn.linear.linear.bias.data + ) + attention.proj.linear.bias.data = torch.randn_like(attention.proj.linear.bias.data) + attention.proj.linear.weight.data = torch.randn_like( + attention.proj.linear.weight.data + ) + return attention + + +def init_lit_attention(config): + attention = LitCausalSelfAttention(config, 2) + torch.manual_seed(0) + attention.attn.weight.data = torch.randn_like(attention.attn.weight.data) + attention.attn.bias.data = torch.randn_like(attention.attn.bias.data) + attention.proj.bias.data = torch.randn_like(attention.proj.bias.data) + attention.proj.weight.data = torch.randn_like(attention.proj.weight.data) + return attention + + +def init_lit_small_attention(config, base_attention, attention_super): + attention = LitCausalSelfAttention(config, 2) + torch.manual_seed(0) + slices = tuple(slice(0, s) for s in attention.attn.weight.data.size())[1] + qkv_indices = ( + attention_super.qkv_indices + if attention_super.qkv_indices is not None + else slice(0, attention.attn.weight.data.size()[0]) + ) + attention.attn.weight.data = base_attention.attn.weight.data[qkv_indices, :][ + :, 0 : attention.attn.weight.data.size()[1] + ] + attention.attn.bias.data = base_attention.attn.bias.data[qkv_indices] + proj_indices = ( + attention_super.proj_indices + if attention_super.proj_indices is not None + else slice(0, attention.proj.weight.data.size()[-1]) + ) + slices = tuple(slice(0, s) for s in attention.proj.bias.data.size()) + attention.proj.bias.data = base_attention.proj.bias.data[slices] + + attention.proj.weight.data = base_attention.proj.weight.data[ + 0 : attention.proj.weight.data.size()[0], : + ][:, proj_indices] + return attention + + +@pytest.mark.parametrize("attention_config", attention_configs.keys()) +def test_attention(attention_config): + config = attention_configs[attention_config]["config"] + config.fix_head_size = attention_configs[attention_config]["fix_head_size"] + if not config.fix_head_size: + config.head_size = 32 + config.max_seq_len = 512 + config.rope_n_elem = int(config.rotary_percentage * config.head_size) + + seq_len = config.max_seq_len + cos, sin = build_rope_cache(seq_len, n_elem=config.rope_n_elem) + cos = cos[:seq_len] + sin = sin[:seq_len] + input = torch.rand(8, seq_len, config.n_embd) + mask = build_mask_cache(seq_len) + + attention = 
init_attention(config) + out_large = attention(input, mask=mask, cos=cos, sin=sin) + + # check shape of super network attention + assert out_large.shape == (8, seq_len, config.n_embd) + lit_attention = init_lit_attention(config) + out_lit_large = lit_attention(input, mask=mask, cos=cos, sin=sin) + if not config.fix_head_size: + sub_network_head_size = config.head_size // 2 + else: + sub_network_head_size = config.head_size + if config.n_query_groups == 1: + sub_network_query_groups = 1 + sub_network_n_head = config.n_head // 4 + elif config.n_query_groups == config.n_head: + sub_network_n_head = config.n_head // 4 + sub_network_query_groups = sub_network_n_head + else: + sub_network_query_groups = config.n_query_groups // 2 + sub_network_n_head = config.n_head // 2 + attention.set_sub_network( + sub_network_n_embd=config.n_embd // 2, + sub_network_n_head=sub_network_n_head, + sub_network_query_groups=sub_network_query_groups, + sub_network_head_size=sub_network_head_size, + ) + cos, sin = build_rope_cache( + seq_len, n_elem=int(config.rotary_percentage * sub_network_head_size) + ) + out_small = attention(input[:, :, : config.n_embd // 2], mask=mask, cos=cos, sin=sin) + + # check shape of sub-network attention + assert out_small.shape == (8, seq_len, config.n_embd // 2) + + # check that our custom model produces the same output as LitGPT + assert torch.all(out_lit_large == out_large) + config.n_embd = attention.sub_network_n_embd + if config.n_query_groups == config.n_head: + config.n_head = attention.sub_network_n_head + else: + config.n_head = ( + attention.sub_network_n_head // config.n_query_groups + ) * attention.sub_network_query_groups + config.n_query_groups = attention.sub_network_query_groups + config.head_size = attention.sub_network_head_size + config.rope_n_elem = int(config.rotary_percentage * config.head_size) + lit_attention_small = init_lit_small_attention(config, lit_attention, attention) + + out_lit_small = lit_attention_small( + input[:, :, : config.n_embd], mask=mask, cos=cos, sin=sin + ) + # check that our sub-networks the same output as equally sized LitGPT attention layer + assert torch.all(out_lit_small == out_small) diff --git a/test/test_lora_block.py b/test/test_lora_block.py new file mode 100644 index 0000000..7da6bd7 --- /dev/null +++ b/test/test_lora_block.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +import torch +from litgpt import Config as LitConfig +from litgpt.model import ( + Block as LitBlock, + build_mask_cache, + build_rope_cache, +) + +from whittle.lora.config import LoRAConfig as Config +from whittle.lora.lora_block import LoRABlock as Block + + +def test_block(): + config = Config() + config.n_embd = 64 + config.n_head = 8 + config.n_query_groups = 4 + config.head_size = 8 + config.intermediate_size = 64 * 4 + config.fix_head_size = False + config.mlp_class_name = "LLaMAMLP" + config.max_seq_len = 512 + config.rotary_percentage = 0.25 + config.rope_n_elem = int(config.rotary_percentage * config.head_size) + cos, sin = build_rope_cache(config.max_seq_len, n_elem=config.rope_n_elem) + litconfig = LitConfig() + litconfig.n_embd = 64 + litconfig.n_head = 8 + litconfig.n_query_groups = 4 + litconfig.head_size = 8 + litconfig.intermediate_size = 64 * 4 + litconfig.fix_head_size = False + litconfig.mlp_class_name = "LLaMAMLP" + litconfig.max_seq_len = 512 + litconfig.rotary_percentage = 0.25 + litconfig.rope_n_elem = int(litconfig.rotary_percentage * litconfig.head_size) + block = Block(config, 0) + input = torch.rand(8, 512, 64) + mask = 
build_mask_cache(512) + block.attn.attn.linear.linear.weight.data = torch.ones_like( + block.attn.attn.linear.linear.weight.data + ) + block.attn.attn.linear.linear.bias.data = torch.ones_like( + block.attn.attn.linear.linear.bias.data + ) + block.attn.proj.linear.bias.data = torch.ones_like(block.attn.proj.linear.bias.data) + block.attn.proj.linear.weight.data = torch.ones_like( + block.attn.proj.linear.weight.data + ) + block.mlp.fc_1.linear.weight.data = torch.ones_like(block.mlp.fc_1.linear.weight.data) + block.mlp.fc_1.linear.bias.data = torch.ones_like(block.mlp.fc_1.linear.bias.data) + block.mlp.fc_2.linear.weight.data = torch.ones_like(block.mlp.fc_2.linear.weight.data) + block.mlp.fc_2.linear.bias.data = torch.ones_like(block.mlp.fc_2.linear.bias.data) + block.mlp.proj.linear.weight.data = torch.ones_like(block.mlp.proj.linear.weight.data) + block.mlp.proj.linear.bias.data = torch.ones_like(block.mlp.proj.linear.bias.data) + block.reset_super_network() + out_large = block(input, cos, sin, mask) + assert out_large.shape == (8, 512, 64) + block.set_sub_network( + sub_network_n_embd=32, + sub_network_intermediate_size=32 * 4, + sub_network_num_heads=8, + sub_network_query_groups=config.n_query_groups // 2, + sub_network_head_size=32 // 4, + ) + out_small = block(input[:, :, :32], cos, sin, mask) + assert out_small.shape == (8, 512, 32) + + lit_block = LitBlock(litconfig, 0) + print(lit_block) + lit_block.attn.attn.weight.data = torch.ones_like(lit_block.attn.attn.weight.data) + lit_block.attn.attn.bias.data = torch.ones_like(lit_block.attn.attn.bias.data) + lit_block.attn.proj.bias.data = torch.ones_like(lit_block.attn.proj.bias.data) + lit_block.attn.proj.weight.data = torch.ones_like(lit_block.attn.proj.weight.data) + lit_block.mlp.fc_1.weight.data = torch.ones_like(lit_block.mlp.fc_1.weight.data) + lit_block.mlp.fc_1.bias.data = torch.ones_like(lit_block.mlp.fc_1.bias.data) + lit_block.mlp.fc_2.weight.data = torch.ones_like(lit_block.mlp.fc_2.weight.data) + lit_block.mlp.fc_2.bias.data = torch.ones_like(lit_block.mlp.fc_2.bias.data) + lit_block.mlp.proj.weight.data = torch.ones_like(lit_block.mlp.proj.weight.data) + lit_block.mlp.proj.bias.data = torch.ones_like(lit_block.mlp.proj.bias.data) + out_lit_large = lit_block(input, cos, sin, mask) + assert torch.all(out_lit_large == out_large) + + litconfig.n_embd = 32 + litconfig.n_head = 4 + litconfig.n_query_groups = 2 + litconfig.intermediate_size = 32 * 4 + lit_block_small = LitBlock(litconfig, 0) + lit_block_small.attn.attn.weight.data = torch.ones_like( + lit_block_small.attn.attn.weight.data + ) + lit_block_small.attn.attn.bias.data = torch.ones_like( + lit_block_small.attn.attn.bias.data + ) + lit_block_small.attn.proj.bias.data = torch.ones_like( + lit_block_small.attn.proj.bias.data + ) + lit_block_small.attn.proj.weight.data = torch.ones_like( + lit_block_small.attn.proj.weight.data + ) + lit_block_small.mlp.fc_1.weight.data = torch.ones_like( + lit_block_small.mlp.fc_1.weight.data + ) + lit_block_small.mlp.fc_1.bias.data = torch.ones_like( + lit_block_small.mlp.fc_1.bias.data + ) + lit_block_small.mlp.fc_2.weight.data = torch.ones_like( + lit_block_small.mlp.fc_2.weight.data + ) + lit_block_small.mlp.fc_2.bias.data = torch.ones_like( + lit_block_small.mlp.fc_2.bias.data + ) + lit_block_small.mlp.proj.weight.data = torch.ones_like( + lit_block_small.mlp.proj.weight.data + ) + lit_block_small.mlp.proj.bias.data = torch.ones_like( + lit_block_small.mlp.proj.bias.data + ) + out_lit_small = lit_block_small(input[:, :, :32], cos, 
sin, mask) + assert torch.all(out_lit_small == out_small) diff --git a/test/test_lora_embedding.py b/test/test_lora_embedding.py new file mode 100644 index 0000000..0b0d2f7 --- /dev/null +++ b/test/test_lora_embedding.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import torch + +from whittle.lora.lora_embedding import LoRAEmbedding as Embedding + + +def test_embedding(): + input_features = torch.randint(low=1, high=64, size=(4, 8)) + emb = Embedding(64, 32) + + out = emb(input_features) + assert out.shape == (4, 8, 32) + emb.set_sub_network(16) + out = emb(input_features) + assert out.shape == (4, 8, 16) + emb.set_sub_network(32) + out = emb(input_features) + assert out.shape == (4, 8, 32) + + emb.embedding.weight.data = torch.randn_like(emb.embedding.weight.data) + emb.set_sub_network(16) + out_small = emb(input_features) + emb.set_sub_network(32) + out_large = emb(input_features) + + small_layer = torch.nn.Embedding(64, 16) + + small_layer.weight.data = emb.embedding.weight.data[:, :16] + + out_small_layer = small_layer(input_features) + + large_layer = torch.nn.Embedding(64, 32) + large_layer.weight.data = emb.embedding.weight.data[:, :32] + out_large_layer = large_layer(input_features) + + assert torch.all(out_small == out_small_layer) + assert torch.all(out_large == out_large_layer) diff --git a/test/test_lora_gpt.py b/test/test_lora_gpt.py new file mode 100644 index 0000000..a726dc7 --- /dev/null +++ b/test/test_lora_gpt.py @@ -0,0 +1,298 @@ +from __future__ import annotations + +import torch +from litgpt import Config as LitConfig +from litgpt.model import GPT as LitGPT + +from whittle.lora.config import LoRAConfig as Config +from whittle.lora.lora_gpt import GPT + + +def test_gpt(): + torch.manual_seed(0) + config = Config() + config.padded_vocab_size = 128 + config.n_embd = 64 + config.intermediate_size = 64 * 4 + config.n_head = 8 + config.n_query_groups = 4 + config.head_size = 8 + config.n_layer = 2 + config.block_size = 128 + config.norm_class_name = "RMSNorm" + config.mlp_class_name = "LLaMAMLP" + config.rope_n_elem = int(config.rotary_percentage * config.head_size) + config.norm_eps = 1e-5 + config.lm_head_bias = True + config.fix_head_size = False + litconfig = LitConfig() + litconfig.padded_vocab_size = 128 + litconfig.n_embd = 64 + litconfig.intermediate_size = 64 * 4 + litconfig.n_head = 8 + litconfig.n_query_groups = 4 + litconfig.head_size = 8 + litconfig.n_layer = 2 + litconfig.block_size = 128 + litconfig.norm_class_name = "RMSNorm" + litconfig.mlp_class_name = "LLaMAMLP" + litconfig.rope_n_elem = int(litconfig.rotary_percentage * litconfig.head_size) + litconfig.norm_eps = 1e-5 + litconfig.lm_head_bias = True + litconfig.fix_head_size = False + gpt = GPT(config) + gpt.transformer.wte.embedding.weight.data = torch.randn_like( + gpt.transformer.wte.embedding.weight.data + ) + gpt.lm_head.linear.weight.data = torch.randn_like(gpt.lm_head.linear.weight.data) + gpt.lm_head.linear.bias.data = torch.randn_like(gpt.lm_head.linear.bias.data) + gpt.transformer.ln_f.weight.data = torch.randn_like(gpt.transformer.ln_f.weight.data) + + for block in gpt.transformer.h: + block.attn.attn.linear.linear.weight.data = torch.randn_like( + block.attn.attn.linear.linear.weight.data + ) + block.attn.attn.linear.linear.bias.data = torch.randn_like( + block.attn.attn.linear.linear.bias.data + ) + block.attn.proj.linear.bias.data = torch.randn_like( + block.attn.proj.linear.bias.data + ) + block.attn.proj.linear.weight.data = torch.randn_like( + 
block.attn.proj.linear.weight.data + ) + block.mlp.fc_1.linear.weight.data = torch.randn_like( + block.mlp.fc_1.linear.weight.data + ) + block.mlp.fc_1.linear.bias.data = torch.randn_like( + block.mlp.fc_1.linear.bias.data + ) + block.mlp.fc_2.linear.weight.data = torch.randn_like( + block.mlp.fc_2.linear.weight.data + ) + block.mlp.fc_2.linear.bias.data = torch.randn_like( + block.mlp.fc_2.linear.bias.data + ) + block.mlp.proj.linear.weight.data = torch.randn_like( + block.mlp.proj.linear.weight.data + ) + block.mlp.proj.linear.bias.data = torch.randn_like( + block.mlp.proj.linear.bias.data + ) + block.norm_1.weight.data = torch.randn_like(block.norm_1.weight.data) + block.norm_2.weight.data = torch.randn_like(block.norm_2.weight.data) + + gpt.reset_super_network() + input = torch.randint(0, 64, (1, 64)) + out_large = gpt(input) + assert out_large.shape == (1, 64, 128) + + lit_gpt = LitGPT(litconfig) + lit_gpt.lm_head.weight.data = gpt.lm_head.linear.weight.data + lit_gpt.lm_head.bias.data = gpt.lm_head.linear.bias.data + lit_gpt.transformer.wte.weight.data = gpt.transformer.wte.embedding.weight.data + lit_gpt.transformer.ln_f.weight.data = gpt.transformer.ln_f.weight.data + for i, block in enumerate(lit_gpt.transformer.h): + block_orig = gpt.transformer.h[i] + block.attn.attn.weight.data = block_orig.attn.attn.linear.linear.weight.data + block.attn.attn.bias.data = block_orig.attn.attn.linear.linear.bias.data + block.attn.proj.bias.data = block_orig.attn.proj.linear.bias.data + block.attn.proj.weight.data = block_orig.attn.proj.linear.weight.data + block.mlp.fc_1.weight.data = block_orig.mlp.fc_1.linear.weight.data + block.mlp.fc_1.bias.data = block_orig.mlp.fc_1.linear.bias.data + block.mlp.fc_2.weight.data = block_orig.mlp.fc_2.linear.weight.data + block.mlp.fc_2.bias.data = block_orig.mlp.fc_2.linear.bias.data + block.mlp.proj.weight.data = block_orig.mlp.proj.linear.weight.data + block.mlp.proj.bias.data = block_orig.mlp.proj.linear.bias.data + block.norm_1.weight.data = block_orig.norm_1.weight.data + block.norm_2.weight.data = block_orig.norm_2.weight.data + + out_lit_large = lit_gpt(input) + assert torch.allclose(out_lit_large, out_large, atol=1e-3) + gpt.set_sub_network( + sub_network_n_embd=32, + sub_network_intermediate_size=32 * 4, + sub_network_num_heads=4, + sub_network_n_layers=1, + sub_network_query_groups=2, + sub_network_head_size=config.head_size, + ) + out_small = gpt(input) + assert out_small.shape == (1, 64, 128) + litconfig.n_embd = 32 + litconfig.n_head = 2 + litconfig.n_query_groups = 2 + litconfig.intermediate_size = 32 * 4 + litconfig.n_layer = 1 + print(config) + lit_gpt_small = LitGPT(litconfig) + lit_gpt_small.lm_head.weight.data = gpt.lm_head.linear.weight.data[ + : gpt.lm_head.sub_network_out_features, : gpt.lm_head.sub_network_in_features + ] + lit_gpt_small.lm_head.bias.data = gpt.lm_head.linear.bias.data[:] + lit_gpt_small.transformer.wte.weight.data = gpt.transformer.wte.embedding.weight.data[ + :, : gpt.transformer.wte.sub_network_embedding_dim + ] + lit_gpt_small.transformer.ln_f.weight.data = gpt.transformer.ln_f.weight.data[ + : gpt.transformer.ln_f.sub_network_in_features + ] + + for i, block in enumerate(lit_gpt_small.transformer.h): + block_orig = gpt.transformer.h[i] + if block_orig.attn.qkv_indices is not None: + block.attn.attn.weight.data = block_orig.attn.attn.linear.linear.weight.data[ + block_orig.attn.qkv_indices, + : block_orig.attn.attn.sub_network_in_features, + ] + block.attn.attn.bias.data = 
block_orig.attn.attn.linear.linear.bias.data[ + block_orig.attn.qkv_indices + ] + print(torch.tensor(block_orig.attn.qkv_indices).shape) + else: + block.attn.attn.weight.data = block_orig.attn.attn.linear.linear.weight.data[ + : block_orig.attn.attn.sub_network_out_features, + : block_orig.attn.attn.sub_network_in_features, + ] + block.attn.attn.bias.data = block_orig.attn.attn.linear.linear.bias.data[ + : block_orig.attn.attn.sub_network_out_features + ] + if block_orig.attn.proj_indices is not None: + block.attn.proj.weight.data = block_orig.attn.proj.linear.weight.data[ + : block_orig.attn.proj.sub_network_out_features, + block_orig.attn.proj_indices, + ] + block.attn.proj.bias.data = block_orig.attn.proj.linear.bias.data[ + : block_orig.attn.proj.sub_network_out_features + ] + else: + block.attn.proj.bias.data = block_orig.attn.proj.linear.bias.data[ + : block_orig.attn.proj.sub_network_out_features + ] + block.attn.proj.weight.data = block_orig.attn.proj.linear.weight.data[ + : block_orig.attn.proj.sub_network_out_features, + : block_orig.attn.proj.sub_network_in_features, + ] + block.mlp.fc_1.weight.data = block_orig.mlp.fc_1.linear.weight.data[ + : block_orig.mlp.fc_1.sub_network_out_features, + : block_orig.mlp.fc_1.sub_network_in_features, + ] + block.mlp.fc_1.bias.data = block_orig.mlp.fc_1.linear.bias.data[ + : block_orig.mlp.fc_1.sub_network_out_features + ] + block.mlp.fc_2.weight.data = block_orig.mlp.fc_2.linear.weight.data[ + : block_orig.mlp.fc_2.sub_network_out_features, + : block_orig.mlp.fc_2.sub_network_in_features, + ] + block.mlp.fc_2.bias.data = block_orig.mlp.fc_2.linear.bias.data[ + : block_orig.mlp.fc_2.sub_network_out_features + ] + block.mlp.proj.weight.data = block_orig.mlp.proj.linear.weight.data[ + : block_orig.mlp.proj.sub_network_out_features, + : block_orig.mlp.proj.sub_network_in_features, + ] + block.mlp.proj.bias.data = block_orig.mlp.proj.linear.bias.data[ + : block_orig.mlp.proj.sub_network_out_features + ] + block.norm_1.weight.data = block_orig.norm_1.weight.data[ + : block_orig.norm_1.sub_network_in_features + ] + block.norm_2.weight.data = block_orig.norm_2.weight.data[ + : block_orig.norm_2.sub_network_in_features + ] + out_lit_small = lit_gpt_small(input) + assert torch.allclose(out_lit_small, out_small, atol=1e-3) + + +def copy_weights(model_source, model_target): + for name, param_source in model_source.named_parameters(): + if "linear." in name: + target_name = name.replace("linear.", "") + elif "embedding." 
in name: + target_name = name.replace("embedding.", "") + else: + target_name = name + print(name, target_name) + if target_name in model_target.state_dict(): + param_target = model_target.state_dict()[target_name] + if param_source.shape == param_target.shape: + param_source.data.copy_(param_target.data) + print(f"Copying {name} to {target_name}") + + +def test_llama_3_1(): + config_llama = Config.from_name( + "Llama-3-8B", + n_layer=2, + n_embd=32, + intermediate_size=86, + padded_vocab_size=10000, + ) + config_llama.fix_head_size = True + config_llama_lit = LitConfig.from_name( + "Llama-3-8B", + n_layer=2, + n_embd=32, + intermediate_size=86, + padded_vocab_size=10000, + ) + lit_model = LitGPT(config_llama_lit) + whittle_model = GPT(config_llama) + copy_weights(whittle_model, lit_model) + x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32) + whittle_out = whittle_model(x) + lit_out = lit_model(x) + assert torch.allclose(whittle_out, lit_out, atol=1e-3) + + +def test_llama_3_2(): + config_llama = Config.from_name( + "Llama-3.2-1B", + n_layer=2, + n_embd=32, + intermediate_size=86, + padded_vocab_size=10000, + ) + config_llama.fix_head_size = True + config_llama_lit = LitConfig.from_name( + "Llama-3.2-1B", + n_layer=2, + n_embd=32, + intermediate_size=86, + padded_vocab_size=10000, + ) + lit_model = LitGPT(config_llama_lit) + whittle_model = GPT(config_llama) + copy_weights(whittle_model, lit_model) + x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32) + whittle_out = whittle_model(x) + lit_out = lit_model(x) + assert torch.allclose(whittle_out, lit_out, atol=1e-3) + + +def test_gemma_2(): + config_gemma = Config.from_name( + "gemma-2-9b", + block_size=6, + sliding_window_size=3, + n_layer=2, + n_embd=32, + intermediate_size=86, + ) + config_gemma.head_size = 256 + config_gemma.fix_head_size = True + config_gemma_lit = LitConfig.from_name( + "gemma-2-9b", + block_size=6, + sliding_window_size=3, + n_layer=2, + n_embd=32, + intermediate_size=86, + ) + print(config_gemma) + lit_model = LitGPT(config_gemma_lit) + whittle_model = GPT(config_gemma) + copy_weights(whittle_model, lit_model) + x = torch.tensor([[9856, 23, 491, 1536, 304, 1234]], dtype=torch.int32) + whittle_out = whittle_model(x) + lit_out = lit_model(x) + assert torch.allclose(whittle_out, lit_out, atol=1e-3) diff --git a/test/test_lora_linear.py b/test/test_lora_linear.py new file mode 100644 index 0000000..12c567d --- /dev/null +++ b/test/test_lora_linear.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import torch + +from whittle.lora.lora_linear import LoRALinear as Linear + + +def test_linear(): + input_features = torch.rand(8, 64) + linear = Linear(64, 32, bias=True) + linear.reset_super_network() + out = linear(input_features) + assert out.shape == (8, 32) + linear.set_sub_network(sub_network_in_features=64, sub_network_out_features=16) + out = linear(input_features) + assert out.shape == (8, 16) + linear.set_sub_network(sub_network_in_features=64, sub_network_out_features=32) + out = linear(input_features) + assert out.shape == (8, 32) + + input_small = torch.rand(8, 16) + linear.linear.weight.data = torch.randn_like(linear.linear.weight.data) + linear.linear.bias.data = torch.randn_like(linear.linear.bias.data) + linear.set_sub_network(sub_network_in_features=64, sub_network_out_features=16) + out_small = linear(input_features) + linear.set_sub_network(sub_network_in_features=64, sub_network_out_features=32) + out_large = linear(input_features) + 
linear.set_sub_network(sub_network_in_features=16, sub_network_out_features=32) + out_small_large = linear(input_small) + + small_layer = torch.nn.Linear(64, 16, bias=True) + + small_layer.weight.data = linear.linear.weight.data[:16, :] + small_layer.bias.data = linear.linear.bias.data[:16] + out_small_layer = small_layer(input_features) + + large_layer = torch.nn.Linear(64, 32) + large_layer.weight.data = linear.linear.weight.data + large_layer.bias.data = linear.linear.bias.data + out_large_layer = large_layer(input_features) + + small_large_layer = torch.nn.Linear(16, 32) + small_large_layer.weight.data = linear.linear.weight.data[:32, :16] + small_large_layer.bias.data = linear.linear.bias.data[:32] + out_small_large_layer = small_large_layer(input_small) + + assert torch.all(out_small == out_small_layer) + assert torch.all(out_large == out_large_layer) + assert torch.all(out_small_large == out_small_large_layer) diff --git a/test/test_lora_mlp.py b/test/test_lora_mlp.py new file mode 100644 index 0000000..227bfc8 --- /dev/null +++ b/test/test_lora_mlp.py @@ -0,0 +1,186 @@ +from __future__ import annotations + +import torch +from litgpt.model import ( + GemmaMLP as LitGemmaMLP, + GptNeoxMLP as LitGptNeoxMLP, + LLaMAMLP as LitLLaMAMLP, +) + +from whittle.lora.config import LoRAConfig as Config +from whittle.lora.lora_mlps import ( + LoRAGemmaMLP as GemmaMLP, + LoRAGptNeoxMLP as GptNeoxMLP, + LoRALLaMAMLP as LLaMAMLP, +) + + +def test_GptNeoxMLP(): + config = Config() + input = torch.rand(8, 64) + # update config + config.n_embd = 64 + config.intermediate_size = 64 * 4 + gpt_neox_mlp = GptNeoxMLP(config) + # init weights and biases + gpt_neox_mlp.fc.linear.weight.data = torch.randn_like( + gpt_neox_mlp.fc.linear.weight.data + ) + gpt_neox_mlp.fc.linear.bias.data = torch.randn_like(gpt_neox_mlp.fc.linear.bias.data) + gpt_neox_mlp.proj.linear.weight.data = torch.randn_like( + gpt_neox_mlp.proj.linear.weight.data + ) + gpt_neox_mlp.proj.linear.bias.data = torch.randn_like( + gpt_neox_mlp.proj.linear.bias.data + ) + gpt_neox_mlp.reset_super_network() + out_large = gpt_neox_mlp(input) + assert out_large.shape == (8, 64) + gpt_neox_mlp.set_sub_network( + sub_network_n_embd=32, sub_network_intermediate_size=32 * 4 + ) + out_small = gpt_neox_mlp(input[:8, :32]) + assert out_small.shape == (8, 32) + + litgpt_neox_mlp_large = LitGptNeoxMLP(config) + litgpt_neox_mlp_large.fc.weight.data = gpt_neox_mlp.fc.linear.weight.data + litgpt_neox_mlp_large.fc.bias.data = gpt_neox_mlp.fc.linear.bias.data + litgpt_neox_mlp_large.proj.weight.data = gpt_neox_mlp.proj.linear.weight.data + litgpt_neox_mlp_large.proj.bias.data = gpt_neox_mlp.proj.linear.bias.data + out_large_lit = litgpt_neox_mlp_large(input) + config.n_embd = 32 + config.intermediate_size = 32 * 4 + litgpt_neox_mlp_small = LitGptNeoxMLP(config) + litgpt_neox_mlp_small.fc.weight.data = gpt_neox_mlp.fc.linear.weight.data[ + : config.intermediate_size, : config.n_embd + ] + litgpt_neox_mlp_small.fc.bias.data = gpt_neox_mlp.fc.linear.bias.data[ + : config.intermediate_size + ] + litgpt_neox_mlp_small.proj.weight.data = gpt_neox_mlp.proj.linear.weight.data[ + : config.n_embd, : config.intermediate_size + ] + litgpt_neox_mlp_small.proj.bias.data = gpt_neox_mlp.proj.linear.bias.data[ + : config.n_embd + ] + out_small_lit = litgpt_neox_mlp_small(input[:8, :32]) + assert torch.all(out_small == out_small_lit) + assert torch.all(out_large == out_large_lit) + + +def test_LLaMAMLP(): + config = Config() + input = torch.rand(8, 64) + # update config + 
config.n_embd = 64 + config.intermediate_size = 64 * 4 + llama_mlp = LLaMAMLP(config) + # init weights and biases + llama_mlp.fc_1.linear.weight.data = torch.randn_like( + llama_mlp.fc_1.linear.weight.data + ) + llama_mlp.fc_1.linear.bias.data = torch.randn_like(llama_mlp.fc_1.linear.bias.data) + llama_mlp.fc_2.linear.weight.data = torch.randn_like( + llama_mlp.fc_2.linear.weight.data + ) + llama_mlp.fc_2.linear.bias.data = torch.randn_like(llama_mlp.fc_2.linear.bias.data) + llama_mlp.proj.linear.weight.data = torch.randn_like( + llama_mlp.proj.linear.weight.data + ) + llama_mlp.proj.linear.bias.data = torch.randn_like(llama_mlp.proj.linear.bias.data) + llama_mlp.reset_super_network() + out_large = llama_mlp(input) + assert out_large.shape == (8, 64) + llama_mlp.set_sub_network(sub_network_n_embd=32, sub_network_intermediate_size=32 * 4) + out_small = llama_mlp(input[:8, :32]) + assert out_small.shape == (8, 32) + + litllama_mlp_large = LitLLaMAMLP(config) + litllama_mlp_large.fc_1.weight.data = llama_mlp.fc_1.linear.weight.data + litllama_mlp_large.fc_1.bias.data = llama_mlp.fc_1.linear.bias.data + litllama_mlp_large.fc_2.weight.data = llama_mlp.fc_2.linear.weight.data + litllama_mlp_large.fc_2.bias.data = llama_mlp.fc_2.linear.bias.data + litllama_mlp_large.proj.weight.data = llama_mlp.proj.linear.weight.data + litllama_mlp_large.proj.bias.data = llama_mlp.proj.linear.bias.data + out_large_lit = litllama_mlp_large(input) + config.n_embd = 32 + config.intermediate_size = 32 * 4 + litllama_mlp_small = LitLLaMAMLP(config) + litllama_mlp_small.fc_1.weight.data = llama_mlp.fc_1.linear.weight.data[ + : config.intermediate_size, : config.n_embd + ] + litllama_mlp_small.fc_1.bias.data = llama_mlp.fc_1.linear.bias.data[ + : config.intermediate_size + ] + litllama_mlp_small.fc_2.weight.data = llama_mlp.fc_2.linear.weight.data[ + : config.intermediate_size, : config.n_embd + ] + litllama_mlp_small.fc_2.bias.data = llama_mlp.fc_2.linear.bias.data[ + : config.intermediate_size + ] + litllama_mlp_small.proj.weight.data = llama_mlp.proj.linear.weight.data[ + : config.n_embd, : config.intermediate_size + ] + litllama_mlp_small.proj.bias.data = llama_mlp.proj.linear.bias.data[: config.n_embd] + out_small_lit = litllama_mlp_small(input[:8, :32]) + assert torch.all(out_small == out_small_lit) + assert torch.all(out_large == out_large_lit) + + +def test_GemmaMLP(): + config = Config() + input = torch.rand(8, 64) + # update config + config.n_embd = 64 + config.intermediate_size = 64 * 4 + gemma_mlp = GemmaMLP(config) + # init weights and biases + gemma_mlp.fc_1.linear.weight.data = torch.randn_like( + gemma_mlp.fc_1.linear.weight.data + ) + gemma_mlp.fc_1.linear.bias.data = torch.randn_like(gemma_mlp.fc_1.linear.bias.data) + gemma_mlp.fc_2.linear.weight.data = torch.randn_like( + gemma_mlp.fc_2.linear.weight.data + ) + gemma_mlp.fc_2.linear.bias.data = torch.randn_like(gemma_mlp.fc_2.linear.bias.data) + gemma_mlp.proj.linear.weight.data = torch.randn_like( + gemma_mlp.proj.linear.weight.data + ) + gemma_mlp.proj.linear.bias.data = torch.randn_like(gemma_mlp.proj.linear.bias.data) + gemma_mlp.reset_super_network() + out_large = gemma_mlp(input) + assert out_large.shape == (8, 64) + gemma_mlp.set_sub_network(sub_network_n_embd=32, sub_network_intermediate_size=32 * 4) + out_small = gemma_mlp(input[:8, :32]) + assert out_small.shape == (8, 32) + + litgemma_mlp_large = LitGemmaMLP(config) + litgemma_mlp_large.fc_1.weight.data = gemma_mlp.fc_1.linear.weight.data + litgemma_mlp_large.fc_1.bias.data = 
gemma_mlp.fc_1.linear.bias.data + litgemma_mlp_large.fc_2.weight.data = gemma_mlp.fc_2.linear.weight.data + litgemma_mlp_large.fc_2.bias.data = gemma_mlp.fc_2.linear.bias.data + litgemma_mlp_large.proj.weight.data = gemma_mlp.proj.linear.weight.data + litgemma_mlp_large.proj.bias.data = gemma_mlp.proj.linear.bias.data + out_large_lit = litgemma_mlp_large(input) + config.n_embd = 32 + config.intermediate_size = 32 * 4 + litgemma_mlp_small = LitGemmaMLP(config) + litgemma_mlp_small.fc_1.weight.data = gemma_mlp.fc_1.linear.weight.data[ + : config.intermediate_size, : config.n_embd + ] + litgemma_mlp_small.fc_1.bias.data = gemma_mlp.fc_1.linear.bias.data[ + : config.intermediate_size + ] + litgemma_mlp_small.fc_2.weight.data = gemma_mlp.fc_2.linear.weight.data[ + : config.intermediate_size, : config.n_embd + ] + litgemma_mlp_small.fc_2.bias.data = gemma_mlp.fc_2.linear.bias.data[ + : config.intermediate_size + ] + litgemma_mlp_small.proj.weight.data = gemma_mlp.proj.linear.weight.data[ + : config.n_embd, : config.intermediate_size + ] + litgemma_mlp_small.proj.bias.data = gemma_mlp.proj.linear.bias.data[: config.n_embd] + out_small_lit = litgemma_mlp_small(input[:8, :32]) + assert torch.all(out_small == out_small_lit) + assert torch.all(out_large == out_large_lit) diff --git a/test/test_model.py b/test/test_model.py index 8ce25d1..c2a0c06 100644 --- a/test/test_model.py +++ b/test/test_model.py @@ -77,14 +77,16 @@ def test_gpt(): sub_network_num_heads=4, sub_network_n_layers=1, sub_network_query_groups=2, + sub_network_head_size=config.head_size, ) out_small = gpt(input) assert out_small.shape == (1, 64, 128) config.n_embd = 32 - config.n_head = 4 + config.n_head = 2 config.n_query_groups = 2 config.intermediate_size = 32 * 4 config.n_layer = 1 + print(config) lit_gpt_small = LitGPT(config) lit_gpt_small.lm_head.weight.data = gpt.lm_head.weight.data[ : gpt.lm_head.sub_network_out_features, : gpt.lm_head.sub_network_in_features @@ -99,20 +101,39 @@ def test_gpt(): for i, block in enumerate(lit_gpt_small.transformer.h): block_orig = gpt.transformer.h[i] - block.attn.attn.weight.data = block_orig.attn.attn.weight.data[ - : block_orig.attn.attn.sub_network_out_features, - : block_orig.attn.attn.sub_network_in_features, - ] - block.attn.attn.bias.data = block_orig.attn.attn.bias.data[ - : block_orig.attn.attn.sub_network_out_features - ] - block.attn.proj.bias.data = block_orig.attn.proj.bias.data[ - : block_orig.attn.proj.sub_network_out_features - ] - block.attn.proj.weight.data = block_orig.attn.proj.weight.data[ - : block_orig.attn.proj.sub_network_out_features, - : block_orig.attn.proj.sub_network_in_features, - ] + if block_orig.attn.qkv_indices is not None: + block.attn.attn.weight.data = block_orig.attn.attn.weight.data[ + block_orig.attn.qkv_indices, + : block_orig.attn.attn.sub_network_in_features, + ] + block.attn.attn.bias.data = block_orig.attn.attn.bias.data[ + block_orig.attn.qkv_indices + ] + print(torch.tensor(block_orig.attn.qkv_indices).shape) + else: + block.attn.attn.weight.data = block_orig.attn.attn.weight.data[ + : block_orig.attn.attn.sub_network_out_features, + : block_orig.attn.attn.sub_network_in_features, + ] + block.attn.attn.bias.data = block_orig.attn.attn.bias.data[ + : block_orig.attn.attn.sub_network_out_features + ] + if block_orig.attn.proj_indices is not None: + block.attn.proj.weight.data = block_orig.attn.proj.weight.data[ + : block_orig.attn.proj.sub_network_out_features, + block_orig.attn.proj_indices, + ] + block.attn.proj.bias.data = 
block_orig.attn.proj.bias.data[ + : block_orig.attn.proj.sub_network_out_features + ] + else: + block.attn.proj.bias.data = block_orig.attn.proj.bias.data[ + : block_orig.attn.proj.sub_network_out_features + ] + block.attn.proj.weight.data = block_orig.attn.proj.weight.data[ + : block_orig.attn.proj.sub_network_out_features, + : block_orig.attn.proj.sub_network_in_features, + ] block.mlp.fc_1.weight.data = block_orig.mlp.fc_1.weight.data[ : block_orig.mlp.fc_1.sub_network_out_features, : block_orig.mlp.fc_1.sub_network_in_features, @@ -197,7 +218,6 @@ def test_gemma_2(): intermediate_size=86, ) config_gemma.fix_head_size = True - print(config_gemma) lit_model = LitGPT(config_gemma) whittle_model = GPT(config_gemma) copy_weights(lit_model, whittle_model) diff --git a/test/test_parameters.py b/test/test_parameters.py index 600a737..cbc3933 100644 --- a/test/test_parameters.py +++ b/test/test_parameters.py @@ -27,6 +27,7 @@ def test_compute_parameters_sub_network(mlp_type, norm_type): config.n_embd = 64 config.intermediate_size = 64 * 4 config.n_head = 8 + config.bias = False config.n_query_groups = 8 config.head_size = 8 config.n_layer = 2 @@ -51,11 +52,13 @@ def test_compute_parameters_sub_network(mlp_type, norm_type): sub_network_intermediate_size=config.intermediate_size, sub_network_num_heads=config.n_head - 1, sub_network_n_layers=config.n_layer, + sub_network_head_size=config.head_size, + sub_network_query_groups=config.n_query_groups - 1, ) params_sub_network = compute_parameters(gpt) params_single_head = ( - (config.n_embd * config.head_size + config.head_size) * 3 * (config.n_layer) - ) + config.n_embd * config.head_size + config.head_size * config.n_embd * 3 + ) * (config.n_layer) assert params_sub_network == params_super_network - params_single_head # remove larger part of the network diff --git a/test/test_pruners.py b/test/test_pruners.py index 51b5e48..71dfb77 100644 --- a/test/test_pruners.py +++ b/test/test_pruners.py @@ -1,6 +1,7 @@ from __future__ import annotations import pytest +import torch from whittle.models.gpt import GPT, Config from whittle.pruning.pruners.magnitude import MagnitudePruner @@ -27,6 +28,7 @@ ], ) def test_model_pruning(model_info, mock_tokenizer): + torch.manual_seed(0) config = Config.from_name( model_info["config_name"], block_size=6, diff --git a/whittle/lora/__init__.py b/whittle/lora/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/whittle/lora/config.py b/whittle/lora/config.py new file mode 100644 index 0000000..36ee00a --- /dev/null +++ b/whittle/lora/config.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from litgpt.config import Config as BaseConfig + +from whittle.lora.lora_mlps import ( + LoRAGemmaMLP as GemmaMLP, + LoRAGptNeoxMLP as GptNeoxMLP, + LoRALLaMAMLP as LLaMAMLP, +) + + +@dataclass +class LoRAConfig(BaseConfig): + """ + Args: + lora_r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of + the weights of the model. 
The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2) + lora_alpha: alpha is needed for scaling updates as alpha/r + "This scaling helps to reduce the need to retune hyperparameters when we vary r" + https://arxiv.org/pdf/2106.09685.pdf (section 4.1) + lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A) + lora_query: whether to apply LoRA to the query + lora_key: whether to apply LoRA to the key + lora_value: whether to apply LoRA to the value + lora_projection: whether to apply LoRA to the projection + lora_mlp: whether to apply LoRA to the MLP + lora_head: whether to apply LoRA to the head + lora_emb: whether to apply LoRA to the embedding + """ + + lora_r: int = 4 + lora_alpha: int = 1 + lora_dropout: float = 0.0 + lora_query: bool = False + lora_key: bool = False + lora_value: bool = False + lora_projection: bool = False + lora_mlp: bool = False + lora_head: bool = False + lora_emb: bool = False + + @property + def mlp_class(self) -> type: + if self.mlp_class_name == "GptNeoxMLP": + return GptNeoxMLP + elif self.mlp_class_name == "LLaMAMLP": + return LLaMAMLP + elif self.mlp_class_name == "GemmaMLP": + return GemmaMLP + else: + raise ValueError(f"Unknown MLP class: {self.mlp_class_name}") diff --git a/whittle/lora/lora_attention.py b/whittle/lora/lora_attention.py new file mode 100644 index 0000000..a9823cb --- /dev/null +++ b/whittle/lora/lora_attention.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +from typing import Any + +from litgpt.model import KVCache +from litgpt.utils import map_old_state_dict_weights + +from whittle.lora.config import LoRAConfig as Config +from whittle.lora.lora_linear import LoRALinearProj +from whittle.lora.lora_qkv_linear import LoRAQKVLinear +from whittle.models.gpt.blocks.causal_self_attention import ( + CausalSelfAttention as BaseCausalSelfAttention, +) + + +class CausalSelfAttention(BaseCausalSelfAttention): + def __init__(self, config: Config, block_idx: int) -> None: + # Skip the parent class __init__ altogether and replace it to avoid + # useless allocations + super().__init__(config=config, block_idx=block_idx) + shape = (config.n_head + 2 * config.n_query_groups) * config.head_size + # key, query, value projections for all heads, but in a batch + self.attn = LoRAQKVLinear( + config=config, + in_features=config.n_embd, + out_features=shape, + r=config.lora_r, + lora_alpha=config.lora_alpha, + lora_dropout=config.lora_dropout, + enable_lora=(config.lora_query, config.lora_key, config.lora_value), + bias=config.bias, + # for MQA/GQA support + head_size=config.head_size, + n_head=config.n_head, + n_query_groups=config.n_query_groups, + fix_head_size=config.fix_head_size, + ) + # output projection + # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head` + self.proj = LoRALinearProj( + config.head_size * config.n_head, + config.n_embd, + bias=config.bias, + r=(config.lora_r if config.lora_projection else 0), + lora_alpha=config.lora_alpha, + lora_dropout=config.lora_dropout, + ) + # disabled by default + self.kv_cache: KVCache | None = None + + self.config = config + self.apply_sliding_window_attention = ( + config.sliding_window_size is not None + and block_idx % config.sliding_window_layer_placing == 0 + ) + + # Set current sub-network to super-network + self.sub_network_n_embd = self.config.n_embd + self.sub_network_n_head = self.config.n_head + self.sub_network_head_size = self.config.head_size + 
self.sub_network_qkv_shape = ( + self.config.n_head + 2 * self.config.n_query_groups + ) * self.config.head_size + self.sub_network_query_groups = self.config.n_query_groups + self.sub_network_q_per_kv = ( + self.sub_network_n_head // self.sub_network_query_groups + ) + self.sub_attention_scaler = self.config.attention_scores_scalar + self.q_per_kv = self.config.n_head // self.config.n_query_groups + self.qkv_indices = None + self.proj_indices = None + + def set_sub_network( + self, + sub_network_n_embd: int, + sub_network_n_head: int, + sub_network_query_groups: int, + sub_network_head_size: int, + ): + """ + Sets the CausalSelfAttention block to the specified sub-network dimensionality. + + Args: + sub_network_n_embd: Embedding dimension of the sub-network + sub_network_n_head: Number of attention heads in the sub-network + sub_network_query_groups: Number of query groups for grouped-query attention (GQA). + sub_network_head_size: Size of each attention head in the sub-network. + """ + self.sub_network_n_embd = ( + sub_network_n_embd if sub_network_n_embd else self.config.n_embd + ) + self.sub_network_n_head = ( + sub_network_n_head if sub_network_n_head else self.config.n_head + ) + self.sub_network_query_groups = ( + sub_network_query_groups + if sub_network_query_groups + else self.config.n_query_groups + ) + self.sub_network_head_size = ( + sub_network_head_size if sub_network_head_size else self.config.head_size + ) + if self.config.n_query_groups == 1: + q_per_kv = self.sub_network_n_head + self.sub_network_query_groups = 1 + elif ( + self.config.n_head != self.config.n_query_groups + and self.config.n_query_groups != 1 + ): + self.sub_network_query_groups = ( + sub_network_query_groups + if sub_network_query_groups + else self.config.n_query_groups + ) + q_per_kv = self.sub_network_n_head // self.config.n_query_groups + elif self.config.n_head == self.config.n_query_groups: + q_per_kv = 1 + self.sub_network_query_groups = self.sub_network_n_head + self.sub_network_qkv_shape = ( + (q_per_kv + 2) * self.sub_network_head_size * self.sub_network_query_groups + ) + self.sub_network_q_per_kv = int(q_per_kv) + self.qkv_indices = self.get_qkv_indices() + self.attn.set_sub_network( + self.sub_network_n_embd, + self.sub_network_qkv_shape, + qkv_indices=self.qkv_indices, + sub_network_n_head=self.sub_network_n_head, + sub_network_query_groups=self.sub_network_query_groups, + sub_network_head_size=self.sub_network_head_size, + sub_network_q_per_kv=self.q_per_kv, + ) + self.proj_indices = self.get_proj_indices() + self.proj.set_sub_network( + self.sub_network_head_size + * self.sub_network_query_groups + * self.sub_network_q_per_kv, + self.sub_network_n_embd, + self.proj_indices, + ) + if self.config.attention_scores_scalar: + self.sub_attention_scaler = self.sub_network_n_embd // self.sub_network_n_head + else: + self.sub_attention_scaler = self.config.attention_scores_scalar + + def _load_from_state_dict( + self, state_dict: dict, prefix: str, *args: Any, **kwargs: Any + ) -> None: + """For compatibility with base checkpoints.""" + mapping = { + "attn.weight": "attn.linear.weight", + "attn.bias": "attn.linear.bias", + "proj.weight": "proj.linear.weight", + "proj.bias": "proj.linear.bias", + } + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) diff --git a/whittle/lora/lora_block.py b/whittle/lora/lora_block.py new file mode 100644 index 0000000..17a9c6e --- /dev/null +++ b/whittle/lora/lora_block.py @@ 
-0,0 +1,59 @@ +from __future__ import annotations + +import torch.nn as nn + +from whittle.lora.config import LoRAConfig as Config +from whittle.lora.lora_attention import CausalSelfAttention +from whittle.lora.lora_mlps import ( + LoRAGemmaMLP as GemmaMLP, + LoRAGptNeoxMLP as GptNeoxMLP, + LoRALLaMAMLP as LLaMAMLP, +) +from whittle.models.gpt.blocks import Block as BaseBlock +from whittle.modules.layernorm import LayerNorm +from whittle.modules.rmsnorm import RMSNorm + + +class LoRABlock(BaseBlock): + def __init__(self, config: Config, block_idx: int) -> None: + super().__init__(config, block_idx) + self.config = config + if not config.parallel_residual and config.shared_attention_norm: + raise NotImplementedError( + "No checkpoint amongst the ones we support uses this configuration" + " (non-parallel residual and shared attention norm)." + ) + + self.norm_1 = self.norm_class()(config.n_embd, eps=config.norm_eps) + self.post_attention_norm = ( + self.norm_class()(config.n_embd, eps=config.norm_eps) + if config.post_attention_norm + else nn.Identity() + ) + self.attn = CausalSelfAttention(config, block_idx) + self.norm_2: LayerNorm | RMSNorm | None = ( + None + if config.shared_attention_norm + else self.norm_class()(config.n_embd, eps=config.norm_eps) + ) + self.mlp = self.mlp_class()(config) + self.post_mlp_norm = ( + self.norm_class()(config.n_embd, eps=config.norm_eps) + if config.post_mlp_norm + else nn.Identity() + ) + # Set current sub-network to super-network + self.sub_network_n_embd = self.config.n_embd + self.sub_network_intermediate_size = self.config.intermediate_size + self.sub_network_num_heads = self.config.n_head + + def mlp_class(self): + # `self._mlp_class` cannot be the type to keep the config json serializable + if self.config.mlp_class_name == "LLaMAMLP": + return LLaMAMLP + elif self.config.mlp_class_name == "GemmaMLP": + return GemmaMLP + elif self.config.mlp_class_name == "GptNeoxMLP": + return GptNeoxMLP + else: + raise ValueError(f"Unknown MLP class: {self.config._mlp_class}") diff --git a/whittle/lora/lora_embedding.py b/whittle/lora/lora_embedding.py new file mode 100644 index 0000000..867359d --- /dev/null +++ b/whittle/lora/lora_embedding.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +import math +from typing import Any + +import torch +import torch.nn as nn +from litgpt.lora import LoRALayer +from torch.nn import functional as F + +from whittle.modules.embedding import Embedding + + +class LoRAEmbedding(LoRALayer): + # LoRA implemented in a dense layer + def __init__( + self, + # ↓ this part is for pretrained weights + num_embeddings: int, + embedding_dim: int, + # ↓ the remaining part is for LoRA + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + **kwargs: Any, + ): + """LoRA wrapper around linear class. + + This class has three weight matrices: + 1. Pretrained weights are stored as `self.linear.weight` + 2. LoRA A matrix as `self.lora_A` + 3. LoRA B matrix as `self.lora_B` + Only LoRA's A and B matrices are updated, pretrained weights stay frozen. + + Args: + num_embeddings: Number of embeddings in the vocabulary. + embedding_dim: Dimension of the embedding vectors. + r: Rank of the weight update matrices. + lora_alpha: Alpha is needed for scaling updates as alpha/r. + lora_dropout: Dropout that is applied on the input in the LoRA branch (before multiplying by matrix A). + **kwargs: Additional arguments to be passed to the `torch.nn.Embedding` constructor. 
+ """ + super().__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout) + self.embedding = Embedding(num_embeddings, embedding_dim, **kwargs) + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.sub_network_embedding_dim = embedding_dim + self.merged: bool = False + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Parameter(torch.empty((r, num_embeddings))) + self.lora_B = nn.Parameter(torch.empty((embedding_dim, r))) + self.scaling = self.lora_alpha / self.r + self.reset_parameters() + + def set_sub_network(self, sub_network_embedding_dim: int): + self.sub_network_embedding_dim = sub_network_embedding_dim + self.embedding.set_sub_network(sub_network_embedding_dim) + self.sub_network_embedding_dim = sub_network_embedding_dim + + def reset_super_network(self): + self.sub_network_embedding_dim = self.embedding_dim + self.embedding.set_sub_network(self.sub_network_embedding_dim) + + def reset_parameters(self) -> None: + """Reset all the weights, even including pretrained ones.""" + if hasattr(self, "lora_A"): + # initialize A the same way as the default for nn.Linear and B to zero + # Wondering why 'a' is equal to math.sqrt(5)?: https://github.com/pytorch/pytorch/issues/15314 + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + def get_lora_AB(self) -> torch.Tensor: + """Return merged lora_A and lora_B matrices with the same shape as the pretrained weights.""" + return ( + self.lora_B[: self.sub_network_embedding_dim, :] @ self.lora_A + ) * self.scaling + + def merge(self) -> None: + """Merges the LoRA weights into the full-rank weights (W = W + delta_W).""" + if self.r > 0 and not self.merged: + pretrained_dtype = self.embedding.weight.data.dtype + lora_data = self.get_lora_AB().transpose(0, 1) + # if only the pretrained are in quantized form - dequantize, sum with LoRA and quantize the result + if pretrained_dtype == torch.uint8: + import bitsandbytes as bnb + + weight = self.embedding.weight + # dequantize the pretrained weights + weight_data = bnb.functional.dequantize_4bit( + weight.data, weight.quant_state + ).to(lora_data.dtype) + # add pretrained and LoRA weights + weight_data += lora_data + # assign updated weights and quantize by moving to CUDA device + self.embedding.weight = bnb.nn.Params4bit( + weight_data, requires_grad=False, **weight.__dict__ + ) + self.embedding.weight.cuda(weight.device) + else: + # self.linear might be on CPU and lora_data on CUDA + # the inplace add will preserve the dtype of linear.weight + self.embedding.weight.data += lora_data.to( + device=self.embedding.weight.data.device + ) + self.merged = True + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # if weights are merged or rank is less or equal to zero (LoRA is disabled) - it's only a regular nn.Linear forward pass; + # otherwise in addition do the forward pass with LoRA weights and add it's output to the output from pretrained weights + pretrained = self.embedding(x) + if self.r == 0 or self.merged: + return pretrained + x = F.embedding( + x, + self.lora_A.transpose(0, 1), + self.embedding.padding_idx, + self.embedding.max_norm, + self.embedding.norm_type, + self.embedding.scale_grad_by_freq, + self.embedding.sparse, + ) + lora = ( + x @ self.lora_B[: self.sub_network_embedding_dim, :].transpose(0, 1) + ) * self.scaling + return pretrained + lora diff --git a/whittle/lora/lora_gpt.py b/whittle/lora/lora_gpt.py new file mode 100644 index 0000000..e8c2303 --- /dev/null +++ b/whittle/lora/lora_gpt.py @@ 
-0,0 +1,119 @@ +from __future__ import annotations + +from typing import Any +from typing_extensions import Self + +import torch +import torch.nn as nn +from litgpt.utils import map_old_state_dict_weights + +from whittle.lora.config import LoRAConfig as Config +from whittle.lora.lora_block import LoRABlock as Block +from whittle.lora.lora_embedding import LoRAEmbedding + +# from whittle.models.gpt.blocks import Block +from whittle.lora.lora_linear import LoRALinear +from whittle.models.gpt.model import GPT as BaseModel + + +class GPT(BaseModel): + def __init__(self, config: Config) -> None: + super().__init__(config) + assert config.padded_vocab_size is not None + self.config = config + self.lm_head = LoRALinear( + config.n_embd, + config.padded_vocab_size, + bias=config.lm_head_bias, + r=(config.lora_r if config.lora_head else 0), + lora_alpha=config.lora_alpha, + lora_dropout=config.lora_dropout, + ) + self.transformer = nn.ModuleDict( + dict( + wte=LoRAEmbedding( + config.padded_vocab_size, + config.n_embd, + r=(config.lora_r if config.lora_emb else 0), + lora_alpha=config.lora_alpha, + lora_dropout=config.lora_dropout, + ), + h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)), + ln_f=self.norm_class(config.n_embd, eps=config.norm_eps), + ) + ) + self.max_layer = config.n_layer + self.max_seq_length = self.config.block_size + self.mask_cache: torch.Tensor | None = None + + # Set current sub-network to super-network + self.sub_network_n_embd = self.config.n_embd + self.sub_network_intermediate_size = self.config.intermediate_size + self.sub_network_num_heads = self.config.n_head + self.sub_network_n_layers = self.config.n_layer + self.sub_network_head_size: int | None = self.config.head_size + self.sub_network_query_groups: int | None = self.config.n_query_groups + self.sub_network_rope_n_elem = self.config.rope_n_elem + self.cos: torch.Tensor + self.sin: torch.Tensor + self.config.is_encoder_decoder = False + self.main_input_name = "input_pos" + self._supports_cache_class = True + self.sub_network_head_size = None + + def forward( + self, + idx: torch.Tensor, + input_pos: torch.Tensor | None = None, + lm_head_chunk_size: int = 0, + ) -> torch.Tensor | list[torch.Tensor]: + T = idx.size(1) + if self.max_seq_length < T: + raise ValueError( + f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}." + ) + + x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + if self.config.scale_embeddings: + x = x * torch.tensor(self.sub_network_n_embd**0.5, dtype=x.dtype) + for i in range(self.sub_network_n_layers): + block = self.transformer.h[i] + + cos, sin = self.cos.to(idx.device), self.sin.to(idx.device) + cos, sin, mask = self.process_rope_cache(cos, sin, input_pos, T) + x = block(x, cos, sin, mask, input_pos) + + x = self.transformer.ln_f(x) + if lm_head_chunk_size > 0: + # chunk the lm head logits to reduce the peak memory used by autograd + return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)] + x = self.lm_head(x) # (b, t, vocab_size) + if self.config.final_logit_softcapping is not None: + x = ( + torch.tanh(x / self.config.final_logit_softcapping) + * self.config.final_logit_softcapping + ) + return x + # return self.lm_head(x) # (b, t, vocab_size) + + @classmethod + def from_name(cls, name: str, **kwargs: Any) -> Self: + return cls(Config.from_name(name, **kwargs)) + + def _init_weights(self, module: nn.Module) -> None: + """Meant to be used with `gpt.apply(gpt._init_weights)`. 
Unused method left for completeness.""" + super()._init_weights(module) + if isinstance(module, LoRALinear): + module.reset_parameters() + + def _load_from_state_dict( + self, state_dict: dict, prefix: str, *args: Any, **kwargs: Any + ) -> None: + """For compatibility with base checkpoints.""" + mapping = { + "lm_head.weight": "lm_head.linear.weight", + "lm_head.bias": "lm_head.linear.bias", + "transformer.wte.weight": "transformer.wte.embedding.weight", + } + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) diff --git a/whittle/lora/lora_linear.py b/whittle/lora/lora_linear.py new file mode 100644 index 0000000..b0ffe73 --- /dev/null +++ b/whittle/lora/lora_linear.py @@ -0,0 +1,350 @@ +from __future__ import annotations + +import math +from typing import Any + +import torch +import torch.nn as nn +from litgpt.lora import LoRALayer + +from whittle.modules.linear import Linear, LinearProj, LinearQKV + + +class LoRALinear(LoRALayer): + # LoRA implemented in a dense layer + def __init__( + self, + # ↓ this part is for pretrained weights + in_features: int, + out_features: int, + # ↓ the remaining part is for LoRA + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + **kwargs: Any, + ): + """LoRA wrapper around linear class. + + This class has three weight matrices: + 1. Pretrained weights are stored as `self.linear.weight` + 2. LoRA A matrix as `self.lora_A` + 3. LoRA B matrix as `self.lora_B` + Only LoRA's A and B matrices are updated, pretrained weights stay frozen. + + Args: + in_features: number of input features of the pretrained weights + out_features: number of output features of the pretrained weights + r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of + the weights of the model. 
The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2) + lora_alpha: alpha is needed for scaling updates as alpha/r + "This scaling helps to reduce the need to retune hyperparameters when we vary r" + https://arxiv.org/pdf/2106.09685.pdf (section 4.1) + lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A) + """ + # call super for both + super().__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout) + + self.linear = Linear(in_features, out_features, **kwargs) + + self.use_bias = self.linear.use_bias + self.in_features = in_features + self.out_features = out_features + self.sub_network_in_features = in_features + self.sub_network_out_features = out_features + self.merged: bool = False + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Parameter(torch.empty((r, in_features))) + self.lora_B = nn.Parameter(torch.empty((out_features, r))) + self.scaling = self.lora_alpha / self.r + self.reset_parameters() + + def set_sub_network( + self, + sub_network_in_features: int, + sub_network_out_features: int, + ): + self.sub_network_in_features = sub_network_in_features + self.sub_network_out_features = sub_network_out_features + self.linear.set_sub_network(sub_network_in_features, sub_network_out_features) + + def reset_super_network(self): + self.sub_network_in_features = self.in_features + self.sub_network_out_features = self.out_features + self.linear.set_sub_network( + self.sub_network_in_features, self.sub_network_out_features + ) + + def reset_parameters(self) -> None: + """Reset all the weights, even including pretrained ones.""" + if hasattr(self, "lora_A"): + # initialize A the same way as the default for nn.Linear and B to zero + # Wondering why 'a' is equal to math.sqrt(5)?: https://github.com/pytorch/pytorch/issues/15314 + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + def get_lora_AB(self) -> torch.Tensor: + """Return merged lora_A and lora_B matrices with the same shape as the pretrained weights.""" + return ( + self.lora_B[self.sub_network_out_features, :] + @ self.lora_A[:, : self.sub_network_in_features] + ) * self.scaling + + def merge(self) -> None: + """Merges the LoRA weights into the full-rank weights (W = W + delta_W).""" + if self.r > 0 and not self.merged: + pretrained_dtype = self.linear.weight.data.dtype + lora_data = self.get_lora_AB() + # if only the pretrained are in quantized form - dequantize, sum with LoRA and quantize the result + if pretrained_dtype == torch.uint8: + import bitsandbytes as bnb + + weight = self.linear.weight + # dequantize the pretrained weights + weight_data = bnb.functional.dequantize_4bit( + weight.data, weight.quant_state + ).to(lora_data.dtype) + # add pretrained and LoRA weights + weight_data += lora_data + # assign updated weights and quantize by moving to CUDA device + self.linear.weight = bnb.nn.Params4bit( + weight_data, requires_grad=False, **weight.__dict__ + ) + self.linear.weight.cuda(weight.device) + else: + # self.linear might be on CPU and lora_data on CUDA + # the inplace add will preserve the dtype of linear.weight + self.linear.weight.data += lora_data.to( + device=self.linear.weight.data.device + ) + self.merged = True + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # if weights are merged or rank is less or equal to zero (LoRA is disabled) - it's only a regular nn.Linear forward pass; + # otherwise in addition do the forward pass with LoRA weights and add it's output to the 
output from pretrained weights + pretrained = self.linear(x) + if self.r == 0 or self.merged: + return pretrained + lora = ( + self.lora_dropout(x) + @ self.lora_A[:, : self.sub_network_in_features].transpose(0, 1) + @ self.lora_B[: self.sub_network_out_features, :].transpose(0, 1) + ) * self.scaling + return pretrained + lora + + +class LoRALinearQKV(LoRALayer): + def __init__( + self, + in_features: int, + out_features: int, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + **kwargs: Any, + ): + # Call LoRALayer's constructor with all necessary arguments + super().__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout) + + # Call LinearQKV's constructor with the required arguments + self.linear = LinearQKV(in_features, out_features, **kwargs) + self.in_features = in_features + self.out_features = out_features + self.sub_network_in_features = in_features + self.sub_network_out_features = out_features + self.use_bias = self.linear.use_bias + self.qkv_indices = None + self.merged = False + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Parameter(torch.empty((r, in_features))) + self.lora_B = nn.Parameter(torch.empty((out_features, r))) + self.scaling = self.lora_alpha / self.r + self.reset_parameters() + + def set_sub_network( + self, sub_network_in_features, sub_network_out_features, qkv_indices=None + ): + self.sub_network_in_features = sub_network_in_features + self.sub_network_out_features = sub_network_out_features + self.linear.set_sub_network( + sub_network_in_features, sub_network_out_features, qkv_indices + ) + self.qkv_indices = qkv_indices + + def reset_super_network(self): + self.sub_network_in_features = self.in_features + self.sub_network_out_features = self.out_features + self.linear.reset_super_network() + self.qkv_indices = None + + def reset_parameters(self) -> None: + """Reset all the weights, even including pretrained ones.""" + if hasattr(self, "lora_A"): + # initialize A the same way as the default for nn.Linear and B to zero + # Wondering why 'a' is equal to math.sqrt(5)?: https://github.com/pytorch/pytorch/issues/15314 + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + def get_lora_AB(self) -> torch.Tensor: + """Return merged lora_A and lora_B matrices with the same shape as the pretrained weights.""" + return ( + self.lora_B[self.sub_network_out_features, :] + @ self.lora_A[:, : self.sub_network_in_features] + ) * self.scaling + + def merge(self) -> None: + """Merges the LoRA weights into the full-rank weights (W = W + delta_W).""" + if self.r > 0 and not self.merged: + pretrained_dtype = self.linear.weight.data.dtype + lora_data = self.get_lora_AB() + # if only the pretrained are in quantized form - dequantize, sum with LoRA and quantize the result + if pretrained_dtype == torch.uint8: + import bitsandbytes as bnb + + weight = self.linear.weight + # dequantize the pretrained weights + weight_data = bnb.functional.dequantize_4bit( + weight.data, weight.quant_state + ).to(lora_data.dtype) + # add pretrained and LoRA weights + weight_data += lora_data + # assign updated weights and quantize by moving to CUDA device + self.linear.weight = bnb.nn.Params4bit( + weight_data, requires_grad=False, **weight.__dict__ + ) + self.linear.weight.cuda(weight.device) + else: + # self.linear might be on CPU and lora_data on CUDA + # the inplace add will preserve the dtype of linear.weight + self.linear.weight.data += lora_data.to( + device=self.linear.weight.data.device + ) + self.merged = True + + 
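# Illustrative sketch (not part of the diff): the low-rank update applied in
# LoRALinear.forward above, with the sub-network slices written out explicitly and
# dropout omitted. Shapes follow the class (lora_A: (r, in_features),
# lora_B: (out_features, r)); the concrete sizes are invented. Note the leading
# colon in `lora_B[:sub_out]`: the slice keeps the first sub_out rows.
import torch

torch.manual_seed(0)
in_features, out_features, r, lora_alpha = 16, 12, 4, 8
sub_in, sub_out = 8, 6  # a smaller sub-network slice
weight = torch.randn(out_features, in_features)
lora_A = torch.randn(r, in_features)
lora_B = torch.randn(out_features, r)
scaling = lora_alpha / r
x = torch.randn(2, 5, sub_in)

pretrained = x @ weight[:sub_out, :sub_in].T
lora = (x @ lora_A[:, :sub_in].T @ lora_B[:sub_out].T) * scaling

# The same result via the merged weight W' = W + scaling * B_sub @ A_sub.
merged = x @ (weight[:sub_out, :sub_in] + scaling * lora_B[:sub_out] @ lora_A[:, :sub_in]).T
assert torch.allclose(pretrained + lora, merged, atol=1e-5)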
def forward(self, x: torch.Tensor) -> torch.Tensor: + # if weights are merged or rank is less or equal to zero (LoRA is disabled) - it's only a regular nn.Linear forward pass; + # otherwise in addition do the forward pass with LoRA weights and add it's output to the output from pretrained weights + pretrained = self.linear(x) + if self.r == 0 or self.merged: + return pretrained + if self.qkv_indices is not None: + lora = ( + self.lora_dropout(x) + @ self.lora_A[:, : self.sub_network_in_features].transpose(0, 1) + @ self.lora_B[self.qkv_indices, :].transpose(0, 1) + ) * self.scaling + else: + lora = ( + self.lora_dropout(x) + @ self.lora_A[:, : self.sub_network_in_features].transpose(0, 1) + @ self.lora_B[: self.sub_network_out_features, :].transpose(0, 1) + ) * self.scaling + return pretrained + lora + + +class LoRALinearProj(LoRALayer): + def __init__( + self, + in_features: int, + out_features: int, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + **kwargs: Any, + ): + super().__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout) + self.linear = LinearProj(in_features, out_features, **kwargs) + self.use_bias = self.linear.use_bias + self.in_features = in_features + self.out_features = out_features + self.sub_network_in_features = in_features + self.sub_network_out_features = out_features + self.proj_indices = None + self.merged = False + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Parameter(torch.empty((r, in_features))) + self.lora_B = nn.Parameter(torch.empty((out_features, r))) + self.scaling = self.lora_alpha / self.r + self.reset_parameters() + + def set_sub_network( + self, sub_network_in_features, sub_network_out_features, proj_indices=None + ): + self.sub_network_in_features = sub_network_in_features + self.sub_network_out_features = sub_network_out_features + self.linear.set_sub_network( + sub_network_in_features, sub_network_out_features, proj_indices + ) + self.proj_indices = proj_indices + + def reset_super_network(self): + self.sub_network_in_features = self.in_features + self.sub_network_out_features = self.out_features + self.linear.reset_super_network() + self.proj_indices = None + + def reset_parameters(self) -> None: + """Reset all the weights, even including pretrained ones.""" + if hasattr(self, "lora_A"): + # initialize A the same way as the default for nn.Linear and B to zero + # Wondering why 'a' is equal to math.sqrt(5)?: https://github.com/pytorch/pytorch/issues/15314 + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + def get_lora_AB(self) -> torch.Tensor: + """Return merged lora_A and lora_B matrices with the same shape as the pretrained weights.""" + return ( + self.lora_B[self.sub_network_out_features, :] + @ self.lora_A[:, : self.sub_network_in_features] + ) * self.scaling + + def merge(self) -> None: + """Merges the LoRA weights into the full-rank weights (W = W + delta_W).""" + if self.r > 0 and not self.merged: + pretrained_dtype = self.linear.weight.data.dtype + lora_data = self.get_lora_AB() + # if only the pretrained are in quantized form - dequantize, sum with LoRA and quantize the result + if pretrained_dtype == torch.uint8: + import bitsandbytes as bnb + + weight = self.linear.weight + # dequantize the pretrained weights + weight_data = bnb.functional.dequantize_4bit( + weight.data, weight.quant_state + ).to(lora_data.dtype) + # add pretrained and LoRA weights + weight_data += lora_data + # assign updated weights and quantize by moving to CUDA device + 
self.linear.weight = bnb.nn.Params4bit( + weight_data, requires_grad=False, **weight.__dict__ + ) + self.linear.weight.cuda(weight.device) + else: + # self.linear might be on CPU and lora_data on CUDA + # the inplace add will preserve the dtype of linear.weight + self.linear.weight.data += lora_data.to( + device=self.linear.weight.data.device + ) + self.merged = True + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # if weights are merged or rank is less or equal to zero (LoRA is disabled) - it's only a regular nn.Linear forward pass; + # otherwise in addition do the forward pass with LoRA weights and add it's output to the output from pretrained weights + pretrained = self.linear(x) + if self.r == 0 or self.merged: + return pretrained + if self.proj_indices is not None: + lora = ( + self.lora_dropout(x) + @ self.lora_A[:, self.proj_indices].transpose(0, 1) + @ self.lora_B[: self.sub_network_out_features, :].transpose(0, 1) + ) * self.scaling + else: + lora = ( + self.lora_dropout(x) + @ self.lora_A[:, : self.sub_network_in_features].transpose(0, 1) + @ self.lora_B[: self.sub_network_out_features, :].transpose(0, 1) + ) * self.scaling + return pretrained + lora diff --git a/whittle/lora/lora_mlps.py b/whittle/lora/lora_mlps.py new file mode 100644 index 0000000..3958a6c --- /dev/null +++ b/whittle/lora/lora_mlps.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +from typing import Any + +import torch +from litgpt.config import Config +from litgpt.utils import map_old_state_dict_weights + +from whittle.lora.lora_linear import LoRALinear +from whittle.models.gpt.blocks.mlp import ( + GptNeoxMLP as GptNeoxMLPBase, + LLaMAMLP as LLaMAMLPBase, +) + + +class LoRAGptNeoxMLP(GptNeoxMLPBase): + def __init__(self, config: Config) -> None: + super().__init__(config) + self.in_features = config.n_embd + self.intermediate_size = config.intermediate_size + self.fc = LoRALinear( + config.n_embd, + config.intermediate_size, + bias=config.bias, + r=(config.lora_r if config.lora_mlp else 0), + lora_alpha=config.lora_alpha, + lora_dropout=config.lora_dropout, + ) + self.proj = LoRALinear( + config.intermediate_size, + config.n_embd, + bias=config.bias, + r=(config.lora_r if config.lora_mlp else 0), + lora_alpha=config.lora_alpha, + lora_dropout=config.lora_dropout, + ) + + self.config = config + + def _load_from_state_dict( + self, state_dict: dict, prefix: str, *args: Any, **kwargs: Any + ) -> None: + """For compatibility with base checkpoints.""" + mapping = { + "fc.weight": "fc.linear.weight", + "fc.bias": "fc.linear.bias", + "proj.weight": "proj.linear.weight", + "proj.bias": "proj.linear.bias", + } + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +class LoRALLaMAMLP(LLaMAMLPBase): + def __init__(self, config: Config) -> None: + super().__init__(config) + self.in_features = config.n_embd + self.intermediate_size = config.intermediate_size + self.fc_1 = LoRALinear( + config.n_embd, + config.intermediate_size, + bias=config.bias, + r=(config.lora_r if config.lora_mlp else 0), + lora_alpha=config.lora_alpha, + lora_dropout=config.lora_dropout, + ) + self.fc_2 = LoRALinear( + config.n_embd, + config.intermediate_size, + bias=config.bias, + r=(config.lora_r if config.lora_mlp else 0), + lora_alpha=config.lora_alpha, + lora_dropout=config.lora_dropout, + ) + self.proj = LoRALinear( + config.intermediate_size, + config.n_embd, + bias=config.bias, + r=(config.lora_r if config.lora_mlp else 0), + 
lora_alpha=config.lora_alpha, + lora_dropout=config.lora_dropout, + ) + + self.config = config + + def _load_from_state_dict( + self, state_dict: dict, prefix: str, *args: Any, **kwargs: Any + ) -> None: + """For compatibility with base checkpoints.""" + mapping = { + "fc_1.weight": "fc_1.linear.weight", + "fc_1.bias": "fc_1.linear.bias", + "fc_2.weight": "fc_2.linear.weight", + "fc_2.bias": "fc_2.linear.bias", + "proj.weight": "proj.linear.weight", + "proj.bias": "proj.linear.bias", + } + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +class LoRAGemmaMLP(LoRALLaMAMLP): + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_fc_1 = self.fc_1(x) + x_fc_2 = self.fc_2(x) + x = ( + torch.nn.functional.gelu(x_fc_1, approximate=self.config.gelu_approximate) + * x_fc_2 + ) + return self.proj(x) diff --git a/whittle/lora/lora_qkv_linear.py b/whittle/lora/lora_qkv_linear.py new file mode 100644 index 0000000..693d4e7 --- /dev/null +++ b/whittle/lora/lora_qkv_linear.py @@ -0,0 +1,384 @@ +from __future__ import annotations + +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from whittle.lora.lora_linear import LoRALayer, LoRALinearQKV + + +class LoRAQKVLinear(LoRALayer): + # LoRA implemented in a dense layer + def __init__( + self, + # ↓ this part is for pretrained weights + config, + in_features: int, + out_features: int, + # ↓ the remaining part is for LoRA + head_size: int, + n_head: int, + n_query_groups: int, + fix_head_size: bool, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + enable_lora: bool | tuple[bool, bool, bool] = False, + **kwargs: Any, + ): + super().__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout) + self.config = config + assert out_features == (n_head + 2 * n_query_groups) * head_size + self.linear = LoRALinearQKV(in_features, out_features, **kwargs) + self.use_bias = self.linear.use_bias + self.head_size = head_size + self.fix_head_size = fix_head_size + self.n_head = n_head + self.in_features = in_features + self.out_features = out_features + self.n_query_groups = n_query_groups + if isinstance(enable_lora, bool): + enable_lora = (enable_lora, enable_lora, enable_lora) + assert len(enable_lora) == 3 + self.enable_lora = enable_lora + self.sub_network_in_features = in_features + self.sub_network_out_features = out_features + self.sub_network_head_size = head_size + self.sub_network_n_head = n_head + self.sub_network_query_groups = n_query_groups + self.q_per_kv = n_head // n_query_groups + self.sub_network_q_per_kv = self.q_per_kv + self.qkv_indices = None + # Actual trainable parameters + # To better understand initialization let's imagine that we have such parameters: + # ⚬ in_features: 128 (embeddings_size) + # ⚬ out_features: 384 (3 * embedding_size) + # ⚬ r: 2 + # ⚬ enable_lora: [True, False, True] + if r > 0 and any(enable_lora): + self.lora_A = nn.Parameter( + torch.empty((r * sum(enable_lora), in_features)) + ) # (4, 128) + self.enable_q, self.enable_k, self.enable_v = enable_lora + # qkv_shapes will be used to split a tensor with weights correctly + qkv_shapes = ( + # if `head_size` is explicitly specified in the config, `n_embd` (or `in_features`) + # might not be equal to `head_size * n_head`, thus we use it directly here + head_size * n_head * self.enable_q, + head_size * n_query_groups * self.enable_k, + head_size * n_query_groups * self.enable_v, + ) + self.qkv_shapes = [s for s in 
qkv_shapes if s] + self.lora_B = nn.Parameter(torch.empty(sum(self.qkv_shapes), r)) # (256, 2)) + # Notes about shapes above + # - self.lora_A has shape (4, 128): 4 because rank is 2 and LoRA is applied only to two matrices; + # 128 is the input size of the x (embedding size). (4, 128) and not (128, 4) because later on in + # F.linear function weights are automatically transposed. In addition conv1d requires channels to + # be before seq length + # - self.lora_B has shape (256, 2): 256 because LoRA is applied only to two matrices, so the output is + # 128*2; 2 tells to have two channels per group for group convolution + + # Scaling: + # This balances the pretrained model`s knowledge and the new task-specific adaptation + # https://lightning.ai/pages/community/tutorial/lora-llm/ + # So, set alpha to 1.0 to fully add LoRA. If the LoRA seems to have too much effect (i.e., overfitted), set + # alpha to lower value. If the LoRA seems to have too little effect, set alpha to higher than 1.0. You can + # tune these values to your needs. This value can be even slightly greater than 1.0! + # https://github.com/cloneofsimo/lora + self.scaling = self.lora_alpha / self.r + + self.reset_parameters() + + def set_sub_network( + self, + sub_network_in_features: int, + sub_network_out_features: int, + qkv_indices=None, + sub_network_n_head=None, + sub_network_query_groups=None, + sub_network_head_size=None, + sub_network_q_per_kv=None, + ): + self.sub_network_in_features = sub_network_in_features + self.sub_network_out_features = sub_network_out_features + self.sub_network_n_head = sub_network_n_head + self.sub_network_query_groups = sub_network_query_groups + self.sub_network_head_size = sub_network_head_size + self.sub_network_q_per_kv = sub_network_q_per_kv + self.linear.set_sub_network( + sub_network_in_features, sub_network_out_features, qkv_indices + ) + self.qkv_indices = qkv_indices + + def reset_super_network(self): + """Resets the dimensionality of the current sub-network to the super-network dimensionality.""" + self.sub_network_in_features = self.in_features + self.sub_network_out_features = self.out_features + self.sub_network_n_embd = self.config.n_embd + self.sub_network_n_head = self.config.n_head + self.q_per_kv = self.config.n_head // self.config.n_query_groups + self.sub_network_head_size = self.config.head_size + self.sub_network_qkv_shape = ( + self.config.n_head + 2 * self.config.n_query_groups + ) * self.config.head_size + self.sub_network_query_groups = self.config.n_query_groups + self.sub_network_q_per_kv = self.q_per_kv + + self.linear.reset_super_network() + self.sub_attention_scaler = self.config.attention_scores_scalar + self.qkv_indices = None + + @property + def lora_ind(self) -> torch.Tensor: + """Lazy creation of a buffer with LoRA indices to overcome the limitation when FSDP with meta device is used.""" + # Indices are needed to properly pad weight updates with zeros. 
+ if not hasattr(self, "_lora_ind"): + enable_q, enable_k, enable_v = self.enable_lora + qkv_group_size = self.sub_network_q_per_kv + 2 + candidate_indices = ( + range(self.sub_network_out_features) + if self.qkv_indices is None + else self.qkv_indices + ) + lora_ind = [] + if enable_q: + q_ind = [ + x + for x in candidate_indices + if (x // self.sub_network_head_size) % qkv_group_size + < qkv_group_size - 2 + ] + lora_ind.extend(q_ind) + if enable_k: + k_ind = [ + x + for x in candidate_indices + if (x // self.sub_network_head_size) % qkv_group_size + == qkv_group_size - 2 + ] + lora_ind.extend(k_ind) + if enable_v: + v_ind = [ + x + for x in candidate_indices + if (x // self.sub_network_head_size) % qkv_group_size + == qkv_group_size - 1 + ] + lora_ind.extend(v_ind) + self.register_buffer( + "_lora_ind", + torch.tensor(lora_ind, device=self.linear.weight.device), + persistent=False, + ) + + return self._lora_ind + + def zero_pad(self, x: torch.Tensor) -> torch.Tensor: + """Properly pad the last dimension of weight updates with zeros. + + If, based on `self.enable_lora`, we want to fine-tune queries and values, but not keys, + then the weights update should be: + + [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,], + [....................................], + [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]] + ↑ ↑ ↑ + ________________________________________ + | query | key | value | + ---------------------------------------- + For Llama2's GQA support, Q, K, and V weights are interleaved, so that weights for grouped + queries are adjacent to their associated key and value weights. + For example, suppose we have n_head = 12 with 3 query groups. + Then along the embedding dimension the interleaved weights would look like + + [Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V], + + where each Q, K, and V has size head_size. + + In this case, the previously-described weight update applies separately to each + individual block, so the update will take the form + + [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ...], + [.............................................................................], + [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ...]] + ↑ ↑ ↑ ↑ ↑ ↑ + ________________________________________________________________________________ + | q block 1 | k block 1 | v block 1 | q block 2 | k block 2 | v block 2 | ... + -------------------------------------------------------------------------------- + Note that in the above diagram, the size of each q block will equal q_per_kv + times the size of each k and v block. 
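# Illustrative sketch (not part of the diff): zero-padding a Q/V-only LoRA update into
# the interleaved fused QKV layout described above, using index_copy_ the way zero_pad
# does. The head layout (2 query groups, q_per_kv = 2, head_size = 2) and
# enable_lora = (True, False, True) are invented for the example.
import torch

head_size, q_per_kv, n_groups = 2, 2, 2
group = q_per_kv + 2                     # q, ..., q, k, v slots per query group
out_features = group * head_size * n_groups

# Row positions inside the interleaved [Q, Q, K, V, Q, Q, K, V] layout;
# K (slot index q_per_kv in each group) receives no update.
q_ind = [i for i in range(out_features) if (i // head_size) % group < q_per_kv]
k_ind = [i for i in range(out_features) if (i // head_size) % group == q_per_kv]
v_ind = [i for i in range(out_features) if (i // head_size) % group == q_per_kv + 1]
lora_ind = torch.tensor(q_ind + v_ind)

x = torch.randn(1, 3, len(lora_ind))     # update covers only the Q and V channels
padded = x.new_zeros(1, 3, out_features)
padded.index_copy_(-1, lora_ind, x)

assert torch.equal(padded[..., lora_ind], x)
assert torch.all(padded[..., k_ind] == 0)  # K channels stay zero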
+ + Args: + x: tensor with the weight update that will be padded with zeros if necessary + + Returns: + A tensor with weight updates and zeros for deselected q, k or v + """ + # we need to do zero padding only if LoRA is disabled for one of QKV matrices + if all(self.enable_lora): + return x + + # Let's imagine that: + # ⚬ input x has shape (64, 64, 256): (batch_size, sequence_length, embeddings_size) + # ⚬ embeddings_size: 128 + # ⚬ self.linear.out_features: 384 (3 * embeddings_size) + # ⚬ enable_lora: [True, False, True] + # Then x has embeddings_size of 256 (2 * 128 as enable_lora only for query and value, not keys) and the expected + # embeddings_size is 384 (self.linear.out_features), so we need to pad from 256 to 384 with zeros, but + # only for key updates (this is where self.lora_ind comes in handy) + result = x.new_zeros( + *x.shape[:-1], self.sub_network_out_features + ) # (64, 64, 384) + return result.index_copy_(dim=-1, index=self.lora_ind, source=x) # (64, 64, 384) + + def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + """An extension of the `torch.nn.functional.conv1d` function with logic specific to grouped queries. + + If the number of heads is equal to the number of query groups, grouped queries are disabled + (see scheme in `litgpt/config.py:Config`). In this case the combined QKV matrix consists of equally sized + query, key and value parts, which means we can utilize the `groups` argument of `conv1d`: with this argument the + input and weight matrices will be split into equally sized parts and applied separately (like having multiple + conv layers side by side). + + Otherwise the QKV matrix consists of unequally sized parts and thus we have to split input and weight matrices manually, + apply each part of the weight matrix to the corresponding input's part and concatenate the result. + + Args: + input: input matrix of shape (B, C, T) + weight: weight matrix of shape (C_output, rank, 1). + "C_output" is defined as a sum of embedding sizes for each enabled LoRA layer (see init method of the class).
+ + Returns: + A tensor with a shape (B, C_output, T) + + """ + if self.config.n_head == self.config.n_query_groups: + return F.conv1d( + input, weight, groups=sum(self.enable_lora) + ) # (B, C_output, T) + + # Notation: + # ⚬ N: number of enabled LoRA layers (self.enable_lora) + # ⚬ C_output': embeddings size for each LoRA layer (not equal in size) + # ⚬ r: rank of all LoRA layers (equal in size) + + input_splitted = input.chunk(sum(self.enable_lora), dim=1) # N * (B, C // N, T) + qkv_shapes = [ + # if `head_size` is explicitly specified in the config, `n_embd` (or `in_features`) + # might not be equal to `head_size * n_head`, thus we use it directly here + self.sub_network_head_size + * self.sub_network_query_groups + * self.sub_network_q_per_kv + * self.enable_q, + self.sub_network_head_size * self.sub_network_query_groups * self.enable_k, + self.sub_network_head_size * self.sub_network_query_groups * self.enable_v, + ] + qkv_shapes = [s for s in qkv_shapes if s] + weight_splitted = weight.split(qkv_shapes) # N * (C_output', r, 1) + return torch.cat( + [F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)], + dim=1, # (B, C_output', T) + ) # (B, C_output, T) + + def get_lora_AB(self) -> torch.Tensor: + """Return merged lora_A and lora_B matrices with the same shape as the pretrained weights.""" + # Let's assume that: + # ⚬ self.linear.weight.data: (384, 128) or (3 * embedding_size, embedding_size) + # ⚬ self.lora_A.data: (4, 128) + # ⚬ self.lora_B.data: (256, 2) + qkv_shapes = [ + # if `head_size` is explicitly specified in the config, `n_embd` (or `in_features`) + # might not be equal to `head_size * n_head`, thus we use it directly here + self.sub_network_head_size * self.sub_network_n_head * self.enable_q, + self.sub_network_head_size * self.sub_network_query_groups * self.enable_k, + self.sub_network_head_size * self.sub_network_query_groups * self.enable_v, + ] + qkv_shapes = [s for s in qkv_shapes if s] + if self.qkv_indices is not None: + lora = self.conv1d( + self.lora_A[:, : self.sub_network_in_features].data.unsqueeze( + 0 + ), # (4, 128) -> (1, 4, 128) + self.lora_B[self.qkv_indices, :].data.unsqueeze( + -1 + ), # (256, 2) -> (256, 2, 1) + ).squeeze(0) + else: + lora = self.conv1d( + self.lora_A[:, : self.sub_network_in_features].data.unsqueeze( + 0 + ), # (4, 128) -> (1, 4, 128) + self.lora_B[: sum(qkv_shapes), :].data.unsqueeze( + -1 + ), # (256, 2) -> (256, 2, 1) + ).squeeze(0) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128) + return self.zero_pad( + lora.T * self.scaling + ).T # (256, 128) after zero_pad (384, 128) + + def merge(self) -> None: + """Merges the LoRA weights into the full-rank weights (W = W + delta_W).""" + if self.r > 0 and any(self.enable_lora) and not self.merged: + super().merge() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Do the forward pass. + + If LoRA's weights are merged with pretrained ones then it's a simple matrix multiplication. + If not, then multiply pretrained weights with input, apply LoRA on input and do summation. 
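# Illustrative sketch (not part of the diff): the grouped conv1d used above is a
# block-wise matmul of lora_B blocks against the per-matrix rank-r slices of after_A.
# The example mirrors the n_head == n_query_groups branch (equally sized q/k/v
# chunks); all sizes are invented.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, r, n_chunks, chunk_out = 2, 5, 3, 3, 6
after_A = torch.randn(batch, n_chunks * r, seq_len)   # one rank-r slice per enabled matrix
lora_B = torch.randn(n_chunks * chunk_out, r)         # stacked B blocks for q, k and v

# Grouped 1x1 convolution: group g applies lora_B rows [g*chunk_out, (g+1)*chunk_out)
# to after_A channels [g*r, (g+1)*r).
grouped = F.conv1d(after_A, lora_B.unsqueeze(-1), groups=n_chunks)

# The same computation written as explicit per-chunk matrix products.
manual = torch.cat(
    [b @ a for a, b in zip(after_A.chunk(n_chunks, dim=1), lora_B.split(chunk_out))],
    dim=1,
)
assert torch.allclose(grouped, manual, atol=1e-5)
# When the q/k/v chunks differ in size (GQA), the method splits input and weight
# manually with `qkv_shapes` instead of relying on `groups`.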
+ + Args: + x: input tensor of shape (batch_size, context_length, embedding_size) + + Returns: + Output tensor of shape (batch_size, context_length, 3 * embedding_size) + """ + + # Let's assume that: + # ⚬ x: (64, 64, 128) or (batch_size, context_length, embedding_size) + # ⚬ self.linear.weight: (384, 128) or (3 * embedding_size, embedding_size) + # ⚬ self.lora_A.data: (4, 128) + # ⚬ self.lora_B.data: (256, 2) + + # if weights are merged or LoRA is disabled (r <= 0 or all `enable_lora` are False) - it's only a regular nn.Linear forward pass; + # otherwise in addition do the forward pass with LoRA weights and add its output to the output from pretrained weights + pretrained = self.linear(x) + if self.r == 0 or not any(self.enable_lora) or self.merged: + return pretrained + after_A = F.linear( + self.lora_dropout(x), self.lora_A[:, : self.sub_network_in_features] + ) # (64, 64, 128) @ (4, 128) -> (64, 64, 4) + # For F.conv1d: + # ⚬ input: input tensor of shape (mini-batch, in_channels, iW) + # ⚬ weight: filters of shape (out_channels, in_channels/groups, kW) + qkv_shapes = [ + # if `head_size` is explicitly specified in the config, `n_embd` (or `in_features`) + # might not be equal to `head_size * n_head`, thus we use it directly here + self.sub_network_head_size + * self.sub_network_query_groups + * self.sub_network_q_per_kv + * self.enable_q, + self.sub_network_head_size * self.sub_network_query_groups * self.enable_k, + self.sub_network_head_size * self.sub_network_query_groups * self.enable_v, + ] + qkv_shapes = [s for s in qkv_shapes if s] + if self.qkv_indices is not None: + after_B = self.conv1d( + after_A.transpose(-2, -1), # (64, 64, 4) -> (64, 4, 64) + self.lora_B[self.qkv_indices, :].unsqueeze(-1), # (256, 2) -> (256, 2, 1) + ).transpose( + -2, -1 + ) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256) + else: + after_B = self.conv1d( + after_A.transpose(-2, -1), # (64, 64, 4) -> (64, 4, 64) + self.lora_B[: sum(qkv_shapes), :].unsqueeze( + -1 + ), # (256, 2) -> (256, 2, 1) + ).transpose(-2, -1) + lora = ( + self.zero_pad(after_B) * self.scaling + ) # (64, 64, 256) after zero_pad (64, 64, 384) + return pretrained + lora diff --git a/whittle/metrics/mag.py b/whittle/metrics/mag.py index d6e9554..d4944cf 100644 --- a/whittle/metrics/mag.py +++ b/whittle/metrics/mag.py @@ -85,6 +85,29 @@ def compute_weight_magnitude_embedding(layer): def compute_weight_magnitude_attention(layer): - mag = compute_weight_magnitude_linear_layer(layer.attn) - mag += compute_weight_magnitude_linear_layer(layer.proj) + mag = 0 + if getattr(layer, "qkv_indices", None) is not None: + mag = mag + torch.sum( + torch.abs( + layer.attn.weight.data[layer.qkv_indices, :][ + :, 0 : layer.sub_network_n_embd + ] + ) + ) + if layer.attn.bias is not None: + mag = mag + torch.sum(torch.abs(layer.attn.bias.data[layer.qkv_indices])) + else: + mag = mag + compute_weight_magnitude_linear_layer(layer.attn) + if getattr(layer, "proj_indices", None) is not None: + mag += torch.sum( + torch.abs( + layer.proj.weight.data[0 : layer.sub_network_n_embd][ + :, layer.proj_indices + ] + ) + ) + if layer.proj.bias is not None: + mag += torch.sum(torch.abs(layer.proj.bias[: layer.sub_network_n_embd])) + else: + mag += compute_weight_magnitude_linear_layer(layer.proj) return float(mag) diff --git a/whittle/metrics/parameters.py b/whittle/metrics/parameters.py index 7c2f6b7..288aafa 100644 --- a/whittle/metrics/parameters.py +++ b/whittle/metrics/parameters.py @@ -44,13 +44,17 @@ def
params_attention_layer(attention: CausalSelfAttention): dmodel = attention.sub_network_n_embd dhead = attention.sub_network_head_size - num_heads = attention.sub_network_n_head - num_query_groups = attention.sub_network_query_groups - qkv_dim = (num_heads + 2 * num_query_groups) * dhead + if attention.config.n_query_groups != attention.config.n_head: + q_per_kv = attention.sub_network_n_head // attention.config.n_query_groups + num_query_groups = attention.sub_network_query_groups + else: + q_per_kv = 1 + num_query_groups = attention.sub_network_n_head + qkv_dim = (q_per_kv + 2) * dhead * num_query_groups n_attention = dmodel * qkv_dim if attention.attn.use_bias: n_attention += qkv_dim - n_attention += dmodel * dmodel # output + n_attention += dmodel * dhead * num_query_groups * q_per_kv if attention.proj.use_bias: n_attention += dmodel diff --git a/whittle/models/gpt/blocks/causal_self_attention.py b/whittle/models/gpt/blocks/causal_self_attention.py index 2931e4c..bd2e7d9 100644 --- a/whittle/models/gpt/blocks/causal_self_attention.py +++ b/whittle/models/gpt/blocks/causal_self_attention.py @@ -7,7 +7,7 @@ from litgpt import Config from litgpt.model import KVCache, apply_rope -from whittle.modules import Linear +from whittle.modules import LinearProj, LinearQKV class CausalSelfAttention(nn.Module): @@ -17,10 +17,10 @@ def __init__(self, config: Config, block_idx: int) -> None: super().__init__() shape = (config.n_head + 2 * config.n_query_groups) * config.head_size # key, query, value projections for all heads, but in a batch - self.attn = Linear(config.n_embd, shape, bias=config.bias) + self.attn = LinearQKV(config.n_embd, shape, bias=config.bias) # output projection # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head` - self.proj = Linear( + self.proj = LinearProj( config.head_size * config.n_head, config.n_embd, bias=config.bias ) # disabled by default @@ -31,18 +31,101 @@ def __init__(self, config: Config, block_idx: int) -> None: ) self.config = config # Set current sub-network to super-network + self.q_per_kv = config.n_head // config.n_query_groups self.sub_network_n_embd = self.config.n_embd self.sub_network_n_head = self.config.n_head self.sub_network_head_size = self.config.head_size self.sub_network_qkv_shape = ( - self.config.n_head + 2 * self.config.n_query_groups - ) * self.config.head_size + (self.q_per_kv + 2) * self.config.head_size * self.config.n_query_groups + ) self.sub_network_query_groups = self.config.n_query_groups - self.sub_network_q_per_kv = ( + self.sub_network_q_per_kv = int( self.sub_network_n_head // self.sub_network_query_groups ) self.sub_attention_scaler = self.config.attention_scores_scalar + def get_qkv_indices(self): + qkv_indices = [] + heads_per_group = self.config.n_head // self.config.n_query_groups + if self.config.n_head == self.config.n_query_groups: + for h in range(self.sub_network_n_head): + # append q + start_q = 3 * h * self.config.head_size + end_q = start_q + self.sub_network_head_size + qkv_indices.extend([i for i in range(start_q, end_q)]) + # append k + start_k = (3 * h + 1) * self.config.head_size + end_k = start_k + self.sub_network_head_size + qkv_indices.extend([i for i in range(start_k, end_k)]) + # append v + start_v = (3 * h + 2) * self.config.head_size + end_v = start_v + self.sub_network_head_size + qkv_indices.extend([i for i in range(start_v, end_v)]) + elif self.config.n_query_groups == 1: + for h in range(self.sub_network_n_head): + start_q = h * self.config.head_size 
+ end_q = start_q + self.sub_network_head_size + qkv_indices.extend([i for i in range(start_q, end_q)]) + end_queries = self.config.n_head * self.config.head_size + qkv_indices.extend( + [i for i in range(end_queries, end_queries + self.sub_network_head_size)] + ) + end_keys = end_queries + self.config.head_size + qkv_indices.extend( + [i for i in range(end_keys, end_keys + self.sub_network_head_size)] + ) + else: + for g in range(self.sub_network_query_groups): + start_q = g * (heads_per_group + 2) * self.config.head_size + for h in range(self.sub_network_q_per_kv): + qkv_indices.extend( + [ + i + for i in range( + start_q + h * self.config.head_size, + start_q + + h * self.config.head_size + + self.sub_network_head_size, + ) + ] + ) + start_k = start_q + heads_per_group * self.config.head_size + qkv_indices.extend( + [i for i in range(start_k, start_k + self.sub_network_head_size)] + ) + start_v = start_k + self.config.head_size + qkv_indices.extend( + [i for i in range(start_v, start_v + self.sub_network_head_size)] + ) + return qkv_indices + + def get_proj_indices(self): + n_head = self.config.n_head + n_query_groups = self.config.n_query_groups + sub_network_n_head = self.sub_network_n_head + heads_per_group = self.config.n_head // self.config.n_query_groups + sub_network_query_groups = self.sub_network_query_groups + sub_network_head_size = self.sub_network_head_size + head_size = self.config.head_size + proj_indices = [] + if n_head == n_query_groups: + for i in range(sub_network_n_head): + proj_indices.extend( + i for i in range(i * head_size, i * head_size + sub_network_head_size) + ) + else: + for g in range(sub_network_query_groups): + start = g * heads_per_group * head_size + for h in range(self.sub_network_q_per_kv): + proj_indices.extend( + i + for i in range( + start + h * head_size, + start + h * head_size + sub_network_head_size, + ) + ) + return proj_indices + def set_sub_network( self, sub_network_n_embd: int, @@ -59,22 +142,51 @@ def set_sub_network( sub_network_query_groups: Number of query groups for grouped-query attention (GQA). sub_network_head_size: Size of each attention head in the sub-network. 
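# Illustrative sketch (not part of the diff): what get_qkv_indices selects in the
# MHA case (n_head == n_query_groups), where the super-network's fused weight rows
# are laid out as [q0, k0, v0, q1, k1, v1, ...]. Sizes are invented: 4 heads of
# size 4 in the super-network, a sub-network keeping 2 heads with head size 2.
# (Input columns are truncated separately, via sub_network_n_embd, in LinearQKV.)
import torch

head_size, sub_head_size, n_head, sub_n_head, n_embd = 4, 2, 4, 2, 16
super_qkv = torch.randn(3 * n_head * head_size, n_embd)

qkv_indices = []
for h in range(sub_n_head):
    for slot in range(3):                          # q, k, v blocks of head h
        start = (3 * h + slot) * head_size
        qkv_indices.extend(range(start, start + sub_head_size))

sub_qkv = super_qkv[qkv_indices, :]
assert sub_qkv.shape == (3 * sub_n_head * sub_head_size, n_embd)
# The first sub_head_size rows are the top of head 0's query block.
assert torch.equal(sub_qkv[:sub_head_size], super_qkv[:sub_head_size])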
""" - self.sub_network_n_embd = sub_network_n_embd - self.sub_network_n_head = sub_network_n_head - self.sub_network_query_groups = sub_network_query_groups - self.sub_network_head_size = sub_network_head_size - + self.sub_network_n_embd = ( + sub_network_n_embd if sub_network_n_embd else self.config.n_embd + ) + self.sub_network_n_head = ( + sub_network_n_head if sub_network_n_head else self.config.n_head + ) + self.sub_network_query_groups = ( + sub_network_query_groups + if sub_network_query_groups + else self.config.n_query_groups + ) + self.sub_network_head_size = ( + sub_network_head_size if sub_network_head_size else self.config.head_size + ) + if self.config.n_query_groups == 1: + q_per_kv = self.sub_network_n_head + self.sub_network_query_groups = 1 + elif ( + self.config.n_head != self.config.n_query_groups + and self.config.n_query_groups != 1 + ): + self.sub_network_query_groups = ( + sub_network_query_groups + if sub_network_query_groups + else self.config.n_query_groups + ) + q_per_kv = self.sub_network_n_head // self.config.n_query_groups + elif self.config.n_head == self.config.n_query_groups: + q_per_kv = 1 + self.sub_network_query_groups = self.sub_network_n_head self.sub_network_qkv_shape = ( - self.sub_network_n_head + 2 * self.sub_network_query_groups - ) * self.sub_network_head_size - - self.attn.set_sub_network(self.sub_network_n_embd, self.sub_network_qkv_shape) + (q_per_kv + 2) * self.sub_network_head_size * self.sub_network_query_groups + ) + self.sub_network_q_per_kv = int(q_per_kv) + self.qkv_indices = self.get_qkv_indices() + self.attn.set_sub_network( + self.sub_network_n_embd, self.sub_network_qkv_shape, self.qkv_indices + ) + self.proj_indices = self.get_proj_indices() self.proj.set_sub_network( - self.sub_network_head_size * self.sub_network_n_head, + self.sub_network_head_size + * self.sub_network_query_groups + * self.sub_network_q_per_kv, self.sub_network_n_embd, - ) - self.sub_network_q_per_kv = self.sub_network_n_head // float( - self.sub_network_query_groups + self.proj_indices, ) if self.config.attention_scores_scalar: self.sub_attention_scaler = self.sub_network_n_embd // self.sub_network_n_head @@ -90,7 +202,7 @@ def reset_super_network(self): self.config.n_head + 2 * self.config.n_query_groups ) * self.config.head_size self.sub_network_query_groups = self.config.n_query_groups - self.sub_network_q_per_kv = ( + self.sub_network_q_per_kv = int( self.sub_network_n_head // self.sub_network_query_groups ) self.attn.reset_super_network() @@ -115,8 +227,9 @@ def forward( ) = x.size() # batch size, sequence length, embedding dimensionality (n_embd) qkv = self.attn(x) # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`) - q_per_kv = self.sub_network_n_head // self.sub_network_query_groups - total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value + total_qkv = ( + self.sub_network_q_per_kv + 2 + ) # each group has 1+ queries, 1 key, and 1 value qkv = qkv.view( B, @@ -129,25 +242,25 @@ def forward( qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) # split batched computation into three - q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) + q, k, v = qkv.split((self.sub_network_q_per_kv, 1, 1), dim=2) # maybe repeat k and v if for the non multi-head attention cases # training: flash attention requires it # inference: multi-query would require a full kv cache so avoid it to limit its memory usage - if self.sub_network_query_groups != self.sub_network_n_head and ( - 
input_pos is None or self.config.n_query_groups != 1 - ): + if self.sub_network_query_groups != ( + self.sub_network_query_groups * self.sub_network_q_per_kv + ) and (input_pos is None or self.sub_network_query_groups != 1): k = k.expand( B, self.sub_network_query_groups, - q_per_kv, + self.sub_network_q_per_kv, T, self.sub_network_head_size, ) v = v.expand( B, self.sub_network_query_groups, - q_per_kv, + self.sub_network_q_per_kv, T, self.sub_network_head_size, ) @@ -186,7 +299,11 @@ def forward( mask += sliding_window_bias y = self.scaled_dot_product_attention(q, k, v, mask) y = y.reshape( - B, T, self.sub_network_head_size * self.sub_network_n_head + B, + T, + self.sub_network_head_size + * self.sub_network_q_per_kv + * self.sub_network_query_groups, ) # re-assemble all head outputs side by side return self.proj(y) diff --git a/whittle/models/gpt/extract.py b/whittle/models/gpt/extract.py index 996883c..e84ce8d 100644 --- a/whittle/models/gpt/extract.py +++ b/whittle/models/gpt/extract.py @@ -28,7 +28,12 @@ def extract_current_sub_network(model: GPT) -> GPT: subnet_config.n_embd = model.sub_network_n_embd subnet_config.intermediate_size = model.sub_network_intermediate_size - subnet_config.n_head = model.sub_network_num_heads + if model.config.n_head != model.config.n_query_groups: + subnet_config.n_head = ( + int(model.sub_network_num_heads) // model.config.n_query_groups + ) * model.sub_network_query_groups + else: + subnet_config.n_head = model.sub_network_num_heads subnet_config.n_layer = model.sub_network_n_layers subnet_config.head_size = model.sub_network_head_size subnet_config.n_query_groups = model.sub_network_query_groups @@ -66,10 +71,7 @@ def extract_sub_network(model: GPT, sub_network_config: Config) -> GPT: sub_network_block = sub_network.transformer.h[i] # Attention - state_dict = extract_linear(block.attn.attn) - sub_network_block.attn.attn.load_state_dict(state_dict) - state_dict = extract_linear(block.attn.proj) - sub_network_block.attn.proj.load_state_dict(state_dict) + extract_attention(block.attn, sub_network_block.attn) # MLP extract_mlp(block.mlp, sub_network_block.mlp) @@ -83,6 +85,31 @@ def extract_sub_network(model: GPT, sub_network_config: Config) -> GPT: return sub_network +def extract_attention(super_network_attention, sub_network_attention): + if super_network_attention.qkv_indices is not None: + sub_network_attention.attn.weight.data = super_network_attention.attn.weight.data[ + super_network_attention.qkv_indices, : + ][:, 0 : sub_network_attention.sub_network_n_embd] + if sub_network_attention.attn.bias is not None: + sub_network_attention.attn.bias.data = super_network_attention.attn.bias.data[ + super_network_attention.qkv_indices + ] + else: + state_dict = extract_linear(super_network_attention.attn) + sub_network_attention.attn.load_state_dict(state_dict) + if super_network_attention.proj_indices is not None: + sub_network_attention.proj.weight.data = super_network_attention.proj.weight.data[ + 0 : sub_network_attention.sub_network_n_embd + ][:, super_network_attention.proj_indices] + if sub_network_attention.proj.bias is not None: + sub_network_attention.proj.bias.data = super_network_attention.proj.bias.data[ + 0 : sub_network_attention.sub_network_n_embd + ] + else: + state_dict = extract_linear(super_network_attention.proj) + sub_network_attention.proj.load_state_dict(state_dict) + + def extract_mlp(mlp, sub_mlp): if isinstance(mlp, GptNeoxMLP): state_dict = extract_linear(mlp.fc) @@ -141,7 +168,6 @@ def extract_linear(super_network_linear): 
super_network_state = super_network_linear.state_dict() in_feat_sub = super_network_linear.sub_network_in_features out_feat_sub = super_network_linear.sub_network_out_features - new_state_dict = OrderedDict() new_state_dict["weight"] = super_network_state["weight"][:out_feat_sub, :in_feat_sub] @@ -155,7 +181,6 @@ def extract_embedding(super_network_embedding): super_network_state = super_network_embedding.state_dict() new_state_dict = OrderedDict() sub_network_embedding_dim = super_network_embedding.sub_network_embedding_dim - new_state_dict["weight"] = super_network_state["weight"][ :, :sub_network_embedding_dim ] diff --git a/whittle/models/gpt/model.py b/whittle/models/gpt/model.py index 942f9d2..be8e654 100644 --- a/whittle/models/gpt/model.py +++ b/whittle/models/gpt/model.py @@ -198,26 +198,42 @@ def set_sub_network( self.sub_network_n_layers = sub_network_n_layers self.transformer.wte.set_sub_network(self.sub_network_n_embd) self.transformer.ln_f.set_sub_network(self.sub_network_n_embd) - if sub_network_query_groups is None: - if self.config.n_query_groups == 1: - self.sub_network_query_groups = 1 - elif self.sub_network_num_heads % self.config.n_query_groups == 0: - self.sub_network_query_groups = self.config.n_query_groups - else: - self.sub_network_query_groups = self.sub_network_num_heads // ( - self.config.n_head // self.config.n_query_groups - ) + if self.config.n_query_groups == 1: + self.sub_network_query_groups = 1 + self.sub_network_num_heads = ( + sub_network_num_heads + if sub_network_num_heads is not None + else self.config.n_head + ) + elif self.config.n_head != self.config.n_query_groups: + self.sub_network_num_heads = ( + sub_network_num_heads + if sub_network_num_heads is not None + else self.config.n_head + ) + self.sub_network_query_groups = ( + sub_network_query_groups + if sub_network_query_groups is not None + else self.config.n_query_groups + ) else: - self.sub_network_query_groups = sub_network_query_groups + self.sub_network_query_groups = ( + sub_network_query_groups + if sub_network_query_groups is not None + else self.config.n_head + ) if self.config.fix_head_size: if sub_network_head_size is None: self.sub_network_head_size = self.config.head_size else: self.sub_network_head_size = sub_network_head_size else: - self.sub_network_head_size = ( - self.sub_network_n_embd // self.sub_network_num_heads - ) + if sub_network_head_size is not None: + self.sub_network_head_size = sub_network_head_size + else: + self.sub_network_head_size = ( + self.sub_network_n_embd // self.sub_network_num_heads + ) for i in range(self.sub_network_n_layers): block = self.transformer.h[i] block.set_sub_network( @@ -240,6 +256,9 @@ def set_sub_network( n_elem=self.sub_network_rope_n_elem, device=self.cos.device, ) + print( + f"Set sub-network to: {self.sub_network_n_embd} embd, {self.sub_network_intermediate_size} intermediate size, {self.sub_network_num_heads} heads, {self.sub_network_n_layers} layers, {self.sub_network_query_groups} query groups, {self.sub_network_head_size} head size" + ) def select_sub_network(self, config: dict[str, Any]) -> None: """ diff --git a/whittle/modules/__init__.py b/whittle/modules/__init__.py index 3ef0ffb..51bb258 100644 --- a/whittle/modules/__init__.py +++ b/whittle/modules/__init__.py @@ -1,5 +1,5 @@ from __future__ import annotations -from .linear import Linear +from .linear import Linear, LinearProj, LinearQKV -__all__ = ["Linear"] +__all__ = ["Linear", "LinearQKV", "LinearProj"] diff --git a/whittle/modules/linear.py 
diff --git a/whittle/modules/linear.py b/whittle/modules/linear.py
index 873ccbb..6ca4571 100644
--- a/whittle/modules/linear.py
+++ b/whittle/modules/linear.py
@@ -51,3 +51,133 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 : self.sub_network_out_features, : self.sub_network_in_features
             ],
         )
+
+
+class LinearQKV(nn.Linear):
+    """An extension of Linear to support QKV Indexing"""
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        device=None,
+        dtype=None,
+    ):
+        super().__init__(in_features, out_features, bias, device, dtype)
+
+        # Set the current sub-network dimensions equal to super-network
+        self.sub_network_in_features = in_features
+        self.sub_network_out_features = out_features
+        self.use_bias = bias
+        self.qkv_indices = None
+
+    def set_sub_network(
+        self,
+        sub_network_in_features: int,
+        sub_network_out_features: int,
+        qkv_indices=None,
+        sub_network_n_head=None,
+        sub_network_query_groups=None,
+        sub_network_head_size=None,
+        sub_network_q_per_kv=None,
+    ):
+        """Set the linear transformation dimensions of the current sub-network."""
+        self.sub_network_in_features = sub_network_in_features
+        self.sub_network_out_features = sub_network_out_features
+        self.qkv_indices = qkv_indices
+
+    def reset_super_network(self):
+        """Reset the linear transformation dimensions of the current sub-network to the super-network dimensionality."""
+        self.sub_network_in_features = self.in_features
+        self.sub_network_out_features = self.out_features
+        self.qkv_indices = None
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.use_bias:
+            if self.qkv_indices is not None:
+                return F.linear(
+                    x,
+                    self.weight[self.qkv_indices, : self.sub_network_in_features],
+                    self.bias[self.qkv_indices],
+                )
+            else:
+                return F.linear(
+                    x,
+                    self.weight[:, : self.sub_network_in_features],
+                    self.bias,
+                )
+        else:
+            if self.qkv_indices is not None:
+                return F.linear(
+                    x,
+                    self.weight[self.qkv_indices, : self.sub_network_in_features],
+                )
+            else:
+                return F.linear(
+                    x,
+                    self.weight[:, : self.sub_network_in_features],
+                )
+
+
+class LinearProj(nn.Linear):
+    """An extension of Linear to support Projection Indexing"""
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        device=None,
+        dtype=None,
+    ):
+        super().__init__(in_features, out_features, bias, device, dtype)
+
+        # Set the current sub-network dimensions equal to super-network
+        self.sub_network_in_features = in_features
+        self.sub_network_out_features = out_features
+        self.use_bias = bias
+        self.proj_indices = None
+
+    def set_sub_network(
+        self,
+        sub_network_in_features: int,
+        sub_network_out_features: int,
+        proj_indices: torch.Tensor,
+    ):
+        """Set the linear transformation dimensions of the current sub-network."""
+        self.sub_network_in_features = sub_network_in_features
+        self.sub_network_out_features = sub_network_out_features
+        self.proj_indices = proj_indices
+
+    def reset_super_network(self):
+        """Reset the linear transformation dimensions of the current sub-network to the super-network dimensionality."""
+        self.sub_network_in_features = self.in_features
+        self.sub_network_out_features = self.out_features
+        self.proj_indices = None
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.use_bias:
+            if self.proj_indices is not None:
+                return F.linear(
+                    x,
+                    self.weight[: self.sub_network_out_features, self.proj_indices],
+                    self.bias[: self.sub_network_out_features],
+                )
+            else:
+                return F.linear(
+                    x,
+                    self.weight[: self.sub_network_out_features, :],
+                    self.bias[: self.sub_network_out_features],
+                )
+        else:
+            if self.proj_indices is not None:
+                return F.linear(
+                    x,
+                    self.weight[: self.sub_network_out_features, self.proj_indices],
+                )
+            else:
+                return F.linear(
+                    x,
+                    self.weight[: self.sub_network_out_features, :],
+                )
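
LinearQKV selects rows of the fused qkv weight (the output units of the retained heads), while LinearProj selects columns (the input channels coming from those heads). A small standalone check of why row and column indexing around F.linear match slicing the full output (sizes and indices below are made up):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
weight = torch.randn(12, 8)   # out_features x in_features
bias = torch.randn(12)
x = torch.randn(4, 8)

qkv_indices = torch.tensor([0, 1, 4, 5, 8, 9])  # output units of the retained heads

# Row selection (LinearQKV): keep only the outputs of the retained heads.
y_sub = F.linear(x, weight[qkv_indices, :], bias[qkv_indices])
y_full = F.linear(x, weight, bias)
assert torch.allclose(y_sub, y_full[:, qkv_indices])

# Column selection (LinearProj): read only the retained input channels.
proj_indices = qkv_indices
y_proj = F.linear(x[:, proj_indices], weight[:, proj_indices], bias)
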
diff --git a/whittle/pruning/pruners/base_pruner.py b/whittle/pruning/pruners/base_pruner.py
index 53a2fa8..a0fafe9 100644
--- a/whittle/pruning/pruners/base_pruner.py
+++ b/whittle/pruning/pruners/base_pruner.py
@@ -8,7 +8,7 @@
 
 from whittle.models.gpt import GPT
 from whittle.modules.embedding import Embedding
-from whittle.modules.linear import Linear
+from whittle.modules.linear import Linear, LinearProj, LinearQKV
 from whittle.pruning.utils.catcher import Catcher
 
 
@@ -41,7 +41,7 @@ def __call__(
     def _find_layers(
         self,
         module: nn.Module,
-        layers: list[type[nn.Module]] = [Linear, Embedding],
+        layers: list[type[nn.Module]] = [Linear, LinearQKV, LinearProj, Embedding],
         name: str = "",
     ) -> dict[str, nn.Module]:
         """
diff --git a/whittle/pruning/pruners/magnitude.py b/whittle/pruning/pruners/magnitude.py
index 90e5c8b..85b79ca 100644
--- a/whittle/pruning/pruners/magnitude.py
+++ b/whittle/pruning/pruners/magnitude.py
@@ -41,7 +41,7 @@ def _prune(
 
         """
 
-        layers = [model.transformer, model.lm_head]
+        layers = [model.transformer.h]
         for i in range(len(layers)):
             layer = layers[i]
             subset = self._find_layers(layer)
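
Registering LinearQKV and LinearProj here lets the pruner discover the index-aware attention projections alongside plain Linear and Embedding modules. Roughly what such a recursive, type-based layer search does (an illustrative sketch, not the exact _find_layers implementation):

import torch.nn as nn

def find_layers(module: nn.Module, layer_types: tuple[type, ...], prefix: str = ""):
    # Collect every submodule whose type is in layer_types, keyed by its dotted name.
    found = {}
    if isinstance(module, layer_types):
        found[prefix] = module
    for child_name, child in module.named_children():
        child_prefix = f"{prefix}.{child_name}" if prefix else child_name
        found.update(find_layers(child, layer_types, child_prefix))
    return found

# With LinearQKV and LinearProj included, the attention qkv and proj layers of each
# transformer block are picked up for pruning as well.
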