Commit

Merge branch 'main' into litgpt-style-cli

timurcarstensen authored Feb 12, 2025
2 parents fd7d28e + 3b800f3 commit 8337116
Showing 30 changed files with 2,743 additions and 135 deletions.
74 changes: 43 additions & 31 deletions test/test_api.py
@@ -28,35 +28,47 @@ def copy_subnetwork_weights(sub_network, super_network):
"""

# Ensure both networks are in evaluation mode to avoid any changes in weights during copying.
sub_network.eval()
super_network.eval()

# Get the state dictionaries of both networks
sub_state_dict = sub_network.state_dict()
super_state_dict = super_network.state_dict()

# Iterate over the subnetwork's state dictionary and copy weights to the supernetwork
for layer_name, sub_weight in sub_state_dict.items():
if layer_name in super_state_dict:
super_weight = super_state_dict[layer_name]

# Ensure the subnetwork's weight can fit into the supernetwork's weight tensor
if sub_weight.size() == super_weight.size():
super_state_dict[layer_name] = sub_weight
else:
# Copy the sub_weight values into the corresponding part of super_weight
if len(sub_weight.shape) == 1:
super_weight[0 : sub_weight.shape[0]] = sub_weight
elif len(sub_weight.shape) == 2:
super_weight[0 : sub_weight.shape[0], 0 : sub_weight.shape[1]] = (
sub_weight
)
super_state_dict[layer_name] = super_weight
else:
raise KeyError(f"Layer {layer_name} not found in super-network.")

# Load the modified state dictionary back into the supernetwork
super_network.load_state_dict(super_state_dict)
embd = super_network.sub_network_n_embd
intermediate = super_network.sub_network_intermediate_size
super_network.lm_head.weight.data[:, :embd] = sub_network.lm_head.weight.data
if super_network.lm_head.bias is not None:
super_network.lm_head.bias.data = sub_network.lm_head.bias.data
super_network.transformer.wte.weight.data[:, :embd] = (
sub_network.transformer.wte.weight.data
)
super_network.transformer.ln_f.weight.data[:embd] = (
sub_network.transformer.ln_f.weight.data
)
if super_network.transformer.ln_f.bias is not None:
super_network.transformer.ln_f.bias.data[:embd] = (
sub_network.transformer.ln_f.bias.data
)
for i, block_orig in enumerate(sub_network.transformer.h):
block = super_network.transformer.h[i]
block.attn.attn.weight.data[block.attn.qkv_indices, :][:, :embd] = (
block_orig.attn.attn.weight.data
)
if block.attn.attn.bias is not None:
block.attn.attn.bias.data[block.attn.qkv_indices] = (
block_orig.attn.attn.bias.data
)
block.attn.proj.weight.data[:, block.attn.proj_indices][:embd, :] = (
block_orig.attn.proj.weight.data
)
if block.attn.proj.bias is not None:
block.attn.proj.bias.data[:embd] = block_orig.attn.proj.bias.data
block.mlp.fc.weight.data[:intermediate, :embd] = block_orig.mlp.fc.weight.data
if block.mlp.fc.bias is not None:
block.mlp.fc.bias.data[:intermediate] = block_orig.mlp.fc.bias.data
block.mlp.proj.weight.data[:embd, :intermediate] = block_orig.mlp.proj.weight.data
if block.mlp.proj.bias is not None:
block.mlp.proj.bias.data[:embd] = block_orig.mlp.proj.bias.data
block.norm_1.weight.data[:embd] = block_orig.norm_1.weight.data
block.norm_2.weight.data[:embd] = block_orig.norm_2.weight.data
if block.norm_1.bias is not None:
block.norm_1.bias.data[:embd] = block_orig.norm_1.bias.data
if block.norm_2.bias is not None:
block.norm_2.bias.data[:embd] = block_orig.norm_2.bias.data
return super_network


@@ -344,8 +356,6 @@ def test_compare_litgpt(self, checkpoint_dir, checkpoint_dir_14m, out_dir):
gpt_14m = GPT(config_14m).to(device)
gpt_14m.name_or_path = "EleutherAI/pythia-14m"
gpt_14m.load_state_dict(torch.load(str(checkpoint_dir_14m / "lit_model.pth")))
gpt = copy_subnetwork_weights(gpt_14m, gpt)
gpt.max_seq_length = config_14m.block_size
gpt.set_sub_network(
sub_network_n_embd=config_14m.n_embd,
sub_network_intermediate_size=config_14m.intermediate_size,
@@ -354,6 +364,8 @@ def test_compare_litgpt(self, checkpoint_dir, checkpoint_dir_14m, out_dir):
sub_network_query_groups=config_14m.n_query_groups,
sub_network_head_size=config_14m.head_size,
)
gpt = copy_subnetwork_weights(gpt_14m, gpt)
gpt.max_seq_length = config_14m.block_size
convert_and_evaluate(
model=gpt,
out_dir=out_dir,
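For context, the copy convention copy_subnetwork_weights relies on can be illustrated in isolation. The snippet below is a self-contained sketch with made-up tensor shapes; it shows the top-left-corner slicing the helper uses for the MLP and norm weights (attention weights are additionally routed through qkv_indices/proj_indices, as in the diff above).

import torch

# Hypothetical shapes, chosen only to illustrate the convention.
super_weight = torch.zeros(8, 8)               # e.g. an mlp.fc weight in the super-network
sub_weight = torch.arange(12.0).reshape(3, 4)  # the matching weight in a smaller sub-network

# The sub-network weight occupies the leading slice (top-left corner)
# of the corresponding super-network weight.
super_weight[: sub_weight.shape[0], : sub_weight.shape[1]] = sub_weight
assert torch.equal(super_weight[:3, :4], sub_weight)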
71 changes: 49 additions & 22 deletions test/test_attention.py
@@ -70,17 +70,30 @@ def init_lit_attention(config):
return attention


def init_lit_small_attention(config, base_attention):
def init_lit_small_attention(config, base_attention, attention_super):
attention = LitCausalSelfAttention(config, 2)
torch.manual_seed(0)
slices = tuple(slice(0, s) for s in attention.attn.weight.data.size())
attention.attn.weight.data = base_attention.attn.weight.data[slices]
slices = tuple(slice(0, s) for s in attention.attn.bias.data.size())
attention.attn.bias.data = base_attention.attn.bias.data[slices]
slices = tuple(slice(0, s) for s in attention.attn.weight.data.size())[1]
qkv_indices = (
attention_super.qkv_indices
if attention_super.qkv_indices is not None
else slice(0, attention.attn.weight.data.size()[0])
)
attention.attn.weight.data = base_attention.attn.weight.data[qkv_indices, :][
:, 0 : attention.attn.weight.data.size()[1]
]
attention.attn.bias.data = base_attention.attn.bias.data[qkv_indices]
proj_indices = (
attention_super.proj_indices
if attention_super.proj_indices is not None
else slice(0, attention.proj.weight.data.size()[-1])
)
slices = tuple(slice(0, s) for s in attention.proj.bias.data.size())
attention.proj.bias.data = base_attention.proj.bias.data[slices]
slices = tuple(slice(0, s) for s in attention.proj.weight.data.size())
attention.proj.weight.data = base_attention.proj.weight.data[slices]

attention.proj.weight.data = base_attention.proj.weight.data[
0 : attention.proj.weight.data.size()[0], :
][:, proj_indices]
return attention


@@ -89,7 +102,7 @@ def test_attention(attention_config):
config = attention_configs[attention_config]["config"]
config.fix_head_size = attention_configs[attention_config]["fix_head_size"]
if not config.fix_head_size:
config.head_size = config.n_embd // config.n_head
config.head_size = 32
config.max_seq_len = 512
config.rope_n_elem = int(config.rotary_percentage * config.head_size)

@@ -108,20 +121,21 @@ def test_attention(attention_config):
lit_attention = init_lit_attention(config)
out_lit_large = lit_attention(input, mask=mask, cos=cos, sin=sin)
if not config.fix_head_size:
sub_network_head_size = config.n_embd // (2 * config.n_head // 4)
sub_network_head_size = config.head_size // 2
else:
sub_network_head_size = config.head_size
if config.n_query_groups == 1:
sub_network_query_groups = 1
elif (config.n_head // 4) % config.n_query_groups == 0:
sub_network_query_groups = config.n_query_groups
sub_network_n_head = config.n_head // 4
elif config.n_query_groups == config.n_head:
sub_network_n_head = config.n_head // 4
sub_network_query_groups = sub_network_n_head
else:
sub_network_query_groups = (config.n_head // 4) // (
config.n_head // config.n_query_groups
)
sub_network_query_groups = config.n_query_groups // 2
sub_network_n_head = config.n_head // 2
attention.set_sub_network(
sub_network_n_embd=config.n_embd // 2,
sub_network_n_head=config.n_head // 4,
sub_network_n_head=sub_network_n_head,
sub_network_query_groups=sub_network_query_groups,
sub_network_head_size=sub_network_head_size,
)
@@ -135,18 +149,31 @@ def test_attention(attention_config):

# check that our custom model produces the same output as LitGPT
assert torch.all(out_lit_large == out_large)

config.n_embd = attention.sub_network_n_embd
config.n_head = attention.sub_network_n_head
config.n_query_groups = sub_network_query_groups
config.head_size = sub_network_head_size
if config.n_query_groups == config.n_head:
config.n_head = attention.sub_network_n_head
else:
config.n_head = (
attention.sub_network_n_head // config.n_query_groups
) * attention.sub_network_query_groups
config.n_query_groups = attention.sub_network_query_groups
config.head_size = attention.sub_network_head_size
config.rope_n_elem = int(config.rotary_percentage * config.head_size)

lit_attention_small = init_lit_small_attention(config, lit_attention)
lit_attention_small = init_lit_small_attention(config, lit_attention, attention)
if attention.qkv_indices is not None:
print(lit_attention_small.attn.weight.data.size())
print(
attention.attn.weight.data[attention.qkv_indices, :][
:, 0 : config.n_embd
].size()
)
assert torch.all(
lit_attention_small.attn.weight.data
== attention.attn.weight.data[attention.qkv_indices, :][:, 0 : config.n_embd]
)

out_lit_small = lit_attention_small(
input[:, :, : config.n_embd], mask=mask, cos=cos, sin=sin
)

# check that our sub-network produces the same output as an equally sized LitGPT attention layer
assert torch.all(out_lit_small == out_small)
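The index-based extraction exercised by init_lit_small_attention can likewise be shown on its own. The sizes and index list below are invented for the example; only the indexing pattern mirrors the test.

import torch

# Fused qkv weight of a hypothetical super-network attention layer.
super_attn_weight = torch.randn(24, 16)
# Rows belonging to the q/k/v channels kept in the sub-network (made-up indices).
qkv_indices = torch.tensor([0, 1, 2, 8, 9, 10, 16, 17, 18])
sub_n_embd = 8  # embedding width of the sub-network

# Select the kept rows, then restrict to the sub-network's embedding columns,
# mirroring base_attention.attn.weight.data[qkv_indices, :][:, 0:n_embd] above.
sub_attn_weight = super_attn_weight[qkv_indices, :][:, :sub_n_embd]
assert sub_attn_weight.shape == (9, 8)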
9 changes: 5 additions & 4 deletions test/test_checkpoint_loading.py
@@ -8,7 +8,7 @@
from litgpt.model import GPT as LitGPT
from litgpt.scripts.download import download_from_hub

from whittle.models.gpt.model import GPT as whittleGPT
from whittle.models.gpt.model import GPT as WhittleGPT


@pytest.fixture(scope="session")
@@ -34,17 +34,18 @@ def test_checkpoint_loading(checkpoint_dir):
# litgpt download --repo_id stabilityai/stablelm-base-alpha-3b
config = Config.from_file(str(checkpoint_dir / "model_config.yaml"))
config.fix_head_size = False
model = whittleGPT(config) # .cuda()
model = WhittleGPT(config) # .cuda()
model.load_state_dict(torch.load(str(checkpoint_dir / "lit_model.pth")))
# test output
model.eval()
sample_intermediate_size = 4 * config.n_embd
sample_intermediate_size = config.intermediate_size
model.set_sub_network(
config.n_embd,
sample_intermediate_size,
config.n_head,
config.n_layer,
config.n_query_groups,
config.head_size,
)

output_whittle = model(input_ids)
assert torch.allclose(output_lit, output_whittle)
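Finally, the loading pattern covered by this test, sketched end to end. The checkpoint path is a placeholder, the input shape is arbitrary, and the snippet assumes a LitGPT checkpoint (model_config.yaml plus lit_model.pth) is already on disk; it mirrors the test above rather than documenting a fixed API.

import torch
from litgpt.config import Config

from whittle.models.gpt.model import GPT as WhittleGPT

checkpoint_dir = "checkpoints/EleutherAI/pythia-14m"  # placeholder path
config = Config.from_file(f"{checkpoint_dir}/model_config.yaml")
config.fix_head_size = False

model = WhittleGPT(config)
model.load_state_dict(torch.load(f"{checkpoint_dir}/lit_model.pth"))
model.eval()

# Selecting the full dimensions as the sub-network should leave the model
# equivalent to the original LitGPT checkpoint.
model.set_sub_network(
    config.n_embd,
    config.intermediate_size,
    config.n_head,
    config.n_layer,
    config.n_query_groups,
    config.head_size,
)

input_ids = torch.randint(0, config.padded_vocab_size, (1, 8))
output = model(input_ids)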