Add Gemma checkpoint support #941

Merged: 37 commits, merged on Feb 23, 2024

Changes from 1 commit

Commits (37)
034afac
gemma
rasbt Feb 21, 2024
d26ab02
add docs
rasbt Feb 21, 2024
5c1b029
update query head config
rasbt Feb 21, 2024
d77bd4a
apply keras geglu workaround
rasbt Feb 21, 2024
5ae2ad6
Carlos
carmocca Feb 21, 2024
e04be8a
An unfinished, but working 2b variant.
Andrei-Aksionov Feb 21, 2024
e61c449
Gemma-7b now works.
Andrei-Aksionov Feb 22, 2024
9a1f23c
add instruction-finetuned version
rasbt Feb 22, 2024
3ec329f
A test for config to check head_size
Andrei-Aksionov Feb 22, 2024
29a00c2
Update Gemma config
Andrei-Aksionov Feb 22, 2024
bf9d711
Adapter_v2 and LoRA: attn.proj size is head_size * num_heads
Andrei-Aksionov Feb 22, 2024
072c9f6
Adapter_v2 and LoRA: GemmaMLP class
Andrei-Aksionov Feb 22, 2024
bed71f1
RMSNorm: unit offset is configurable
Andrei-Aksionov Feb 22, 2024
ec7d01e
Configurable wte output scaling
Andrei-Aksionov Feb 22, 2024
bd0864c
Update tests to support changes in Config class
Andrei-Aksionov Feb 22, 2024
bf8c9b5
Test for Gemma
Andrei-Aksionov Feb 22, 2024
c78bd6e
convert_hf: reuse llama copy function
Andrei-Aksionov Feb 22, 2024
50ad509
Test Gemma model: use llama weights copying
Andrei-Aksionov Feb 22, 2024
d98df3b
Update convert_lit + test
Andrei-Aksionov Feb 22, 2024
159df75
Merge branch 'main' into gemma
Andrei-Aksionov Feb 22, 2024
628c7bc
Restore accidentally deleted comment line
Andrei-Aksionov Feb 22, 2024
1a2f9f8
Prompt for Gemma it (instruct models)
Andrei-Aksionov Feb 22, 2024
d002695
RMSNorm: reduce computations when self.add_unit_offset is False
Andrei-Aksionov Feb 23, 2024
c915d57
Auto markdown formatting
Andrei-Aksionov Feb 23, 2024
32260e6
Drop `tie_weights` in convert_hf
Andrei-Aksionov Feb 23, 2024
78ad643
Comment explaining why head_size*num_head instead of n_embd
Andrei-Aksionov Feb 23, 2024
6f154ab
scale_wte_output --> scale_embeddings
Andrei-Aksionov Feb 23, 2024
cfe68bb
Config: drop `self.rmsnorm_add_unit_offset`
Andrei-Aksionov Feb 23, 2024
57db710
Comment on why a unit offset is needed in RMSNorm
Andrei-Aksionov Feb 23, 2024
4c44085
Bump up min version of transformers in github CI
Andrei-Aksionov Feb 23, 2024
c1dc9d2
Merge branch 'main' into gemma
Andrei-Aksionov Feb 23, 2024
0fe9b3f
Merge branch 'main' into gemma
Andrei-Aksionov Feb 23, 2024
e9b0c5a
Update convert_hf test
Andrei-Aksionov Feb 23, 2024
d17bb34
Update lit_gpt/adapter_v2.py
Andrei-Aksionov Feb 23, 2024
86b5b7a
Update lit_gpt/lora.py
Andrei-Aksionov Feb 23, 2024
8854d14
Update lit_gpt/model.py
Andrei-Aksionov Feb 23, 2024
d219965
Bump up min transformers version in Azure workflow
Andrei-Aksionov Feb 23, 2024
43 changes: 43 additions & 0 deletions lit_gpt/config.py
@@ -781,6 +781,49 @@ def norm_class(self) -> Type:
        configs.append(copy)


###############
# Google Gemma
###############
gemma = [
    # https://huggingface.co/google/gemma-7b/blob/main/config.json
    dict(
        name="Gemma-7b-hf",
        hf_config=dict(org="google", name="gemma-7b"),
        vocab_size=256000,
        padding_multiple=64,
        n_embd=3072,
        n_layer=28,
        n_head=16,
        n_query_groups=1,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        _mlp_class="GemmaMLP",
        intermediate_size=24576,
    ),
    # https://huggingface.co/google/gemma-2b/blob/main/config.json
    dict(
        name="Gemma-2b-hf",
        hf_config=dict(org="google", name="gemma-2b"),
        vocab_size=256000,
        padding_multiple=64,
        n_embd=2048,
        n_layer=18,
        n_head=8,
        n_query_groups=1,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        _mlp_class="GemmaMLP",
        intermediate_size=16384,
    ),
]
configs.extend(gemma)



##########################
# Stability AI FreeWilly2
##########################
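For orientation, here is a minimal usage sketch of how one of the new entries could be looked up by name. It assumes the existing `Config.from_name` helper in `lit_gpt/config.py` and only restates values from the dicts above:

```python
# Minimal sketch, assuming Config.from_name from lit_gpt/config.py on main.
from lit_gpt.config import Config

config = Config.from_name("Gemma-2b-hf")

# Values restated from the "Gemma-2b-hf" dict above.
assert config.vocab_size == 256000
assert config.n_query_groups == 1       # multi-query attention: a single shared KV head group
assert config._mlp_class == "GemmaMLP"  # resolved to the new MLP class added in lit_gpt/model.py
```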
34 changes: 34 additions & 0 deletions lit_gpt/model.py
@@ -11,6 +11,7 @@

import torch
import torch.nn as nn
from torch import Tensor
from typing_extensions import Self

from lit_gpt.config import Config
@@ -290,6 +291,39 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


class GEGLU(nn.Module):
"""
Source: https://github.com/pfnet-research/deep-table/blob/237c8be8a405349ce6ab78075234c60d9bfe60b7/deep_table/nn/layers/activation.py#L22
License: MIT, https://github.com/pfnet-research/deep-table/blob/237c8be8a405349ce6ab78075234c60d9bfe60b7/LICENSE
References:
Shazeer et al., "GLU Variants Improve Transformer," 2020.
https://arxiv.org/abs/2002.05202
"""

def geglu(self, x: Tensor) -> Tensor:
assert x.shape[-1] % 2 == 0
a, b = x.chunk(2, dim=-1)
return a * torch.nn.functional.gelu(b)

def forward(self, x: Tensor) -> Tensor:
return self.geglu(x)


class GemmaMLP(nn.Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
        # GEGLU(); note that `approximate` takes the string "tanh" (or "none"), not a bool
        self.geglu = torch.nn.GELU(approximate="tanh")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_fc_1 = self.fc_1(x)
        x_fc_2 = self.fc_2(x)
        x = self.geglu(x_fc_1) * x_fc_2
        return self.proj(x)


class LLaMAMoE(nn.Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
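As a side note on the two formulations above: `GEGLU` gates one fused projection by chunking it in half, while `GemmaMLP.forward` applies GELU to the `fc_1` branch and multiplies by the `fc_2` branch. The toy check below (hypothetical shapes, not part of the PR) shows the two are equivalent when the fused projection stacks the two weight matrices and the same GELU approximation is used on both sides:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 8)    # (batch, n_embd), toy sizes
w1 = torch.randn(32, 8)  # fc_1 weight: n_embd -> intermediate_size
w2 = torch.randn(32, 8)  # fc_2 weight: n_embd -> intermediate_size

# GemmaMLP-style gating: GELU on the fc_1 branch, multiplied by the fc_2 branch.
out_split = F.gelu(x @ w1.T, approximate="tanh") * (x @ w2.T)

# GEGLU-style gating: one fused projection, chunked into (value, gate) halves.
# GEGLU applies GELU to the second chunk, so fc_2's weights go first in the fusion.
fused = x @ torch.cat([w2, w1], dim=0).T
a, b = fused.chunk(2, dim=-1)
out_fused = a * F.gelu(b, approximate="tanh")

torch.testing.assert_close(out_split, out_fused)
```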
9 changes: 7 additions & 2 deletions scripts/convert_hf_checkpoint.py
@@ -138,7 +138,7 @@ def copy_weights_hf_llama(
"model.layers.{}.block_sparse_moe.experts.{}.w3.weight": "transformer.h.{l}.mlp.experts.{e}.fc_2.weight",
"model.layers.{}.block_sparse_moe.experts.{}.w2.weight": "transformer.h.{l}.mlp.experts.{e}.proj.weight",
})
elif config._mlp_class == "LLaMAMLP":
elif config._mlp_class in ("LLaMAMLP", "GemmaMLP"):
weight_map.update({
"model.layers.{}.mlp.gate_proj.weight": "transformer.h.{l}.mlp.fc_1.weight",
"model.layers.{}.mlp.up_proj.weight": "transformer.h.{l}.mlp.fc_2.weight",
@@ -171,6 +171,11 @@ def copy_weights_hf_llama(
            param = saver.store_early(param)
        state_dict[to_name] = param

    # If model uses weight tying:
    if "lm_head.weight" not in state_dict.keys():
        state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]

    for i, (q, k, v) in list(qkv_weights.items()):
        if q is None or k is None or v is None:
            # split across different .bin files
@@ -299,7 +304,7 @@ def convert_hf_checkpoint(

if "falcon" in model_name:
copy_fn = partial(copy_weights_falcon, model_name)
elif config._mlp_class in ("LLaMAMLP", "LLaMAMoE"):
elif config._mlp_class in ("LLaMAMLP", "LLaMAMoE", "GemmaMLP"):
# holder to reconstitute the split q, k, v
qkv_weights = {}
copy_fn = partial(copy_weights_hf_llama, config, qkv_weights)
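A toy illustration of the weight-tying fallback added to `copy_weights_hf_llama` (hypothetical tensors and shapes; the real code operates on the converted HF state dict): if the checkpoint ships no separate `lm_head.weight`, the embedding matrix is reused for the output head.

```python
import torch

# Hypothetical converted state dict that only contains the embedding matrix.
state_dict = {"transformer.wte.weight": torch.randn(256000, 2048)}

# Same fallback as in the diff above.
if "lm_head.weight" not in state_dict:
    state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]

# The head now shares storage with the embeddings, i.e. the weights are tied.
assert state_dict["lm_head.weight"] is state_dict["transformer.wte.weight"]
```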