Add Gemma checkpoint support #941

Merged
37 commits merged on Feb 23, 2024
Changes from 24 commits
37 commits
034afac
gemma
rasbt Feb 21, 2024
d26ab02
add docs
rasbt Feb 21, 2024
5c1b029
update query head config
rasbt Feb 21, 2024
d77bd4a
apply keras geglu workaround
rasbt Feb 21, 2024
5ae2ad6
Carlos
carmocca Feb 21, 2024
e04be8a
An unfinished, but working 2b variant.
Andrei-Aksionov Feb 21, 2024
e61c449
Gemma-7b now works.
Andrei-Aksionov Feb 22, 2024
9a1f23c
add instruction-finetuned version
rasbt Feb 22, 2024
3ec329f
A test for config to check head_size
Andrei-Aksionov Feb 22, 2024
29a00c2
Update Gemma config
Andrei-Aksionov Feb 22, 2024
bf9d711
Adapter_v2 and LoRA: attn.proj size is head_size * num_heads
Andrei-Aksionov Feb 22, 2024
072c9f6
Adapter_v2 and LoRA: gemmamlp class
Andrei-Aksionov Feb 22, 2024
bed71f1
RMSNorm: unit offset is configurable
Andrei-Aksionov Feb 22, 2024
ec7d01e
Configurable wte output scaling
Andrei-Aksionov Feb 22, 2024
bd0864c
Update tests to support changes in Config class
Andrei-Aksionov Feb 22, 2024
bf8c9b5
Test for Gemma
Andrei-Aksionov Feb 22, 2024
c78bd6e
convert_hf: reuse llama copy function
Andrei-Aksionov Feb 22, 2024
50ad509
Test Gemma model: use llama weights copying
Andrei-Aksionov Feb 22, 2024
d98df3b
Update convert_lit + test
Andrei-Aksionov Feb 22, 2024
159df75
Merge branch 'main' into gemma
Andrei-Aksionov Feb 22, 2024
628c7bc
Restore accidentally deleted comment line
Andrei-Aksionov Feb 22, 2024
1a2f9f8
Prompt for Gemma it (instruct models)
Andrei-Aksionov Feb 22, 2024
d002695
RMSNorm: reduce computations when self.add_unit_offset is False
Andrei-Aksionov Feb 23, 2024
c915d57
Auto markdown formatting
Andrei-Aksionov Feb 23, 2024
32260e6
Drop `tie_weights` in convert_hf
Andrei-Aksionov Feb 23, 2024
78ad643
Comment explaining why head_size*num_head instead of n_embd
Andrei-Aksionov Feb 23, 2024
6f154ab
scale_wte_output --> scale_embeddings
Andrei-Aksionov Feb 23, 2024
cfe68bb
Config: drop `self.rmsnorm_add_unit_offset`
Andrei-Aksionov Feb 23, 2024
57db710
Comment on why we need a unit offset in RMSNorm
Andrei-Aksionov Feb 23, 2024
4c44085
Bump up min version of transformers in github CI
Andrei-Aksionov Feb 23, 2024
c1dc9d2
Merge branch 'main' into gemma
Andrei-Aksionov Feb 23, 2024
0fe9b3f
Merge branch 'main' into gemma
Andrei-Aksionov Feb 23, 2024
e9b0c5a
Update convert_hf test
Andrei-Aksionov Feb 23, 2024
d17bb34
Update lit_gpt/adapter_v2.py
Andrei-Aksionov Feb 23, 2024
86b5b7a
Update lit_gpt/lora.py
Andrei-Aksionov Feb 23, 2024
8854d14
Update lit_gpt/model.py
Andrei-Aksionov Feb 23, 2024
d219965
Bump up min transformers version in Azure workflow
Andrei-Aksionov Feb 23, 2024
1 change: 1 addition & 0 deletions README.md
@@ -33,6 +33,7 @@ Supports the following popular model checkpoints:
| [Falcon](tutorials/download_falcon.md) by TII UAE | 7B, 40B, 180B | [TII 2023](https://falconllm.tii.ae) |
| [FreeWilly2](tutorials/download_freewilly_2.md) (Stable Beluga 2) by Stability AI | 70B | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) |
| [Function Calling Llama 2](tutorials/download_function_calling_llama_2.md) by Trelis | 7B | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) |
| [Gemma](tutorials/download_gemma.md) by Google | 2B, 7B | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) |
| [Llama 2](tutorials/download_llama_2.md) by Meta AI | 7B, 13B, 70B | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) |
| [LongChat](tutorials/download_longchat.md) by LMSYS | 7B, 13B | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) |
| [Mistral and Mixtral](tutorials/download_mistral.md) by Mistral AI | 7B | [Mistral website](https://mistral.ai/) |
5 changes: 5 additions & 0 deletions chat/base.py
@@ -361,6 +361,11 @@ def prompt_config(checkpoint_dir: Path, tokenizer: Tokenizer) -> Tuple[str, Tupl
stop_tokens = ([tokenizer.eos_id],)
return system_prompt, stop_tokens

if re.search(r"gemma.*-it", checkpoint_name):
system_prompt = "<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
stop_tokens = ([tokenizer.eos_id],)
return system_prompt, stop_tokens

# default format
return "{prompt}", ([tokenizer.eos_id],)

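The new branch above matches instruction-tuned Gemma checkpoints by name and wraps the user prompt in Gemma's turn markers. A minimal sketch of how the template behaves, using an assumed checkpoint name and prompt that are not part of the diff:

```python
import re

# Sketch only: reproduce the template added above for a hypothetical "-it" checkpoint.
checkpoint_name = "gemma-2b-it"  # assumed example; any name matching "gemma.*-it" works
if re.search(r"gemma.*-it", checkpoint_name):
    system_prompt = "<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
    print(system_prompt.format(prompt="What is the capital of France?"))
    # <start_of_turn>user
    # What is the capital of France?<end_of_turn>
    # <start_of_turn>model
```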
10 changes: 9 additions & 1 deletion lit_gpt/adapter_v2.py
@@ -123,7 +123,7 @@ def __init__(self, config: Config, block_idx: int) -> None:
# key, query, value projections for all heads, but in a batch
self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias)
# output projection
self.proj = AdapterV2Linear(config.n_embd, config.n_embd, bias=config.bias)
self.proj = AdapterV2Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias)
# disabled by default
self.kv_cache: Optional[KVCache] = None

@@ -194,6 +194,14 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


class GemmaMLP(LLaMAMLP):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_fc_1 = self.fc_1(x)
x_fc_2 = self.fc_2(x)
x = torch.nn.functional.gelu(x_fc_1) * x_fc_2
return self.proj(x)


class LLaMAMoE(lit_gpt.model.LLaMAMoE):
def __init__(self, config: Config) -> None:
nn.Module.__init__(self)
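The output projection's input width changes from `n_embd` to `config.head_size * config.n_head` (here and in the matching classes in `lora.py` and `model.py`) because Gemma-7b's concatenated heads are wider than its embedding; for every previously supported config the two values coincide, so nothing else changes. A quick numeric check, using the Gemma-7b values from the config diff below:

```python
# Gemma-7b (per its HF config): 16 heads of size 256 concatenate to 4096 features,
# while the embedding size is only 3072, so proj can no longer assume n_embd inputs.
n_embd, n_head, head_size = 3072, 16, 256
assert head_size * n_head == 4096
assert head_size * n_head != n_embd

# For models where head_size is derived (e.g. a 4096/32 Llama layout) nothing changes.
n_embd, n_head = 4096, 32
assert (n_embd // n_head) * n_head == n_embd
```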
65 changes: 61 additions & 4 deletions lit_gpt/config.py
@@ -17,12 +17,14 @@
class Config:
name: str = ""
hf_config: dict = field(default_factory=dict)
scale_wte_output: bool = False
block_size: int = 4096
vocab_size: int = 50254
padding_multiple: int = 512
padded_vocab_size: Optional[int] = None
n_layer: int = 16
n_head: int = 32
head_size: Optional[int] = None
n_embd: int = 4096
rotary_percentage: float = 0.25
parallel_residual: bool = True
@@ -51,8 +53,9 @@ class Config:
n_query_groups: Optional[int] = None
shared_attention_norm: bool = False
_norm_class: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
rmsnorm_add_unit_offset: bool = False
norm_eps: float = 1e-5
_mlp_class: Literal["GptNeoxMLP", "LLaMAMLP"] = "GptNeoxMLP"
_mlp_class: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP"
gelu_approximate: str = "none"
intermediate_size: Optional[int] = None
rope_condense_ratio: int = 1
@@ -64,8 +67,9 @@ def __post_init__(self):
if not self.name:
self.name = self.hf_config.get("name", self.name)

assert self.n_embd % self.n_head == 0
self.head_size = self.n_embd // self.n_head
if self.head_size is None:
assert self.n_embd % self.n_head == 0
self.head_size = self.n_embd // self.n_head

# vocab size should be a power of 2 to be optimal on hardware. compute the closest value
if self.padded_vocab_size is None:
@@ -138,9 +142,11 @@ def mlp_class(self) -> Type:
def norm_class(self) -> Type:
# `self._norm_class` cannot be the type to keep the config json serializable
if self._norm_class == "RMSNorm":
from functools import partial

from lit_gpt.rmsnorm import RMSNorm

return RMSNorm
return partial(RMSNorm, add_unit_offset=self.rmsnorm_add_unit_offset)
return getattr(torch.nn, self._norm_class)


@@ -781,6 +787,57 @@ def norm_class(self) -> Type:
configs.append(copy)


###############
# Google Gemma
###############
gemma = [
# https://huggingface.co/google/gemma-2b/blob/main/config.json
dict(
name="Gemma-2b",
hf_config=dict(org="google", name="gemma-2b"),
scale_wte_output=True,
vocab_size=256000,
padding_multiple=64,
n_embd=2048,
n_layer=18,
n_head=8,
n_query_groups=1,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
rmsnorm_add_unit_offset=True,
_mlp_class="GemmaMLP",
intermediate_size=16384,
),
# https://huggingface.co/google/gemma-7b/blob/main/config.json
dict(
name="Gemma-7b",
hf_config=dict(org="google", name="gemma-7b"),
scale_wte_output=True,
vocab_size=256000,
padding_multiple=64,
n_embd=3072,
n_layer=28,
n_head=16,
head_size=256,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
rmsnorm_add_unit_offset=True,
_mlp_class="GemmaMLP",
intermediate_size=24576,
),
]
configs.extend(gemma)
for c in gemma:
copy = deepcopy(c)
copy["name"] = f"{c['name']}-it"
copy["hf_config"]["name"] = f"{c['hf_config']['name']}-it"
configs.append(copy)


##########################
# Stability AI FreeWilly2
##########################
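Two details in this file are easy to miss: `head_size` is now only derived from `n_embd // n_head` when it is not set explicitly (Gemma-7b sets 256, where the derived value would be 192), and each base Gemma config is duplicated with an `-it` suffix for the instruction-tuned checkpoints. A minimal sketch of the fallback logic, mirroring the `__post_init__` change above:

```python
from typing import Optional

def resolve_head_size(n_embd: int, n_head: int, head_size: Optional[int] = None) -> int:
    # Mirrors the __post_init__ fallback: an explicit head_size wins, otherwise derive it.
    if head_size is None:
        assert n_embd % n_head == 0
        return n_embd // n_head
    return head_size

assert resolve_head_size(2048, 8) == 256        # Gemma-2b: derived from n_embd // n_head
assert resolve_head_size(3072, 16, 256) == 256  # Gemma-7b: explicit (derived would be 192)
```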
10 changes: 9 additions & 1 deletion lit_gpt/lora.py
@@ -599,7 +599,7 @@ def __init__(self, config: Config) -> None:
)
# output projection
self.proj = LoRALinear(
config.n_embd,
config.head_size * config.n_head,
config.n_embd,
bias=config.bias,
r=(config.r if config.to_projection else 0),
@@ -699,6 +699,14 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


class GemmaMLP(LLaMAMLP):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_fc_1 = self.fc_1(x)
x_fc_2 = self.fc_2(x)
x = torch.nn.functional.gelu(x_fc_1) * x_fc_2
return self.proj(x)


class LLaMAMoE(lit_gpt.model.LLaMAMoE):
def __init__(self, config: Config) -> None:
nn.Module.__init__(self)
15 changes: 13 additions & 2 deletions lit_gpt/model.py
@@ -87,6 +87,9 @@ def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None) -
mask = None

x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
if self.config.scale_wte_output:
x = x * (self.config.n_embd**0.5)

for block in self.transformer.h:
x = block(x, cos, sin, mask, input_pos)
x = self.transformer.ln_f(x)
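When `scale_wte_output` is set (it is for all Gemma configs), the token embeddings are multiplied by sqrt(n_embd) before entering the transformer blocks, which the Gemma checkpoints expect. A short sketch with an assumed embedding size:

```python
import torch

n_embd = 2048                  # assumed; matches the Gemma-2b config
x = torch.randn(1, 4, n_embd)  # token embeddings as returned by transformer.wte
x = x * (n_embd ** 0.5)        # scale embeddings by sqrt(n_embd)
```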
@@ -174,7 +177,7 @@ def __init__(self, config: Config) -> None:
# key, query, value projections for all heads, but in a batch
self.attn = nn.Linear(config.n_embd, shape, bias=config.bias)
# output projection
self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
self.proj = nn.Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias)
# disabled by default
self.kv_cache: Optional[KVCache] = None

@@ -224,7 +227,7 @@ def forward(

y = self.scaled_dot_product_attention(q, k, v, mask)

y = y.reshape(B, T, self.config.n_embd) # re-assemble all head outputs side by side
y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side

# output projection
return self.proj(y)
@@ -290,6 +293,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.proj(x)


class GemmaMLP(LLaMAMLP):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_fc_1 = self.fc_1(x)
x_fc_2 = self.fc_2(x)
x = torch.nn.functional.gelu(x_fc_1) * x_fc_2
return self.proj(x)


class LLaMAMoE(nn.Module):
def __init__(self, config: Config) -> None:
super().__init__()
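`GemmaMLP` reuses `LLaMAMLP`'s three linear layers but gates with GELU instead of SiLU (GeGLU rather than SwiGLU); the same small class is mirrored in `adapter_v2.py` and `lora.py`. A standalone sketch of the forward pass with illustrative sizes:

```python
import torch
import torch.nn as nn

# Sketch of the GeGLU feed-forward used by GemmaMLP; the sizes are illustrative only.
n_embd, intermediate_size = 2048, 16384
fc_1 = nn.Linear(n_embd, intermediate_size, bias=False)  # gate projection
fc_2 = nn.Linear(n_embd, intermediate_size, bias=False)  # up projection
proj = nn.Linear(intermediate_size, n_embd, bias=False)  # down projection

x = torch.randn(1, 4, n_embd)
y = proj(torch.nn.functional.gelu(fc_1(x)) * fc_2(x))  # gelu(gate) * up, then project down
assert y.shape == (1, 4, n_embd)
```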
8 changes: 6 additions & 2 deletions lit_gpt/rmsnorm.py
@@ -10,19 +10,23 @@ class RMSNorm(torch.nn.Module):
https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
"""

def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
def __init__(self, size: int, dim: int = -1, eps: float = 1e-6, add_unit_offset: bool = False) -> None:
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(size))
self.eps = eps
self.dim = dim
self.add_unit_offset = add_unit_offset

def forward(self, x: torch.Tensor) -> torch.Tensor:
dtype = x.dtype
x = x.float()
# NOTE: the original RMSNorm paper implementation is not equivalent
norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
x_normed = x * torch.rsqrt(norm_x + self.eps)
return (self.weight * x_normed).to(dtype=dtype)
x_normed = x_normed.to(dtype=dtype)
if self.add_unit_offset:
return x_normed * (1 + self.weight)
return x_normed * self.weight

def reset_parameters(self) -> None:
torch.nn.init.ones_(self.weight)
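Gemma's checkpoints keep the RMSNorm scale zero-centred and the reference implementation multiplies by `(1 + weight)` at runtime, which is why the unit offset is made configurable here rather than baked into the weights at conversion time. A small sketch showing that with a zero weight the offset variant reduces to plain RMSNorm with a unit scale (illustrative values only):

```python
import torch

size = 4
x = torch.randn(2, size)
weight = torch.zeros(size)  # how a Gemma checkpoint stores the scale

norm_x = torch.mean(x * x, dim=-1, keepdim=True)
x_normed = x * torch.rsqrt(norm_x + 1e-6)

y_offset = x_normed * (1 + weight)     # add_unit_offset=True (Gemma)
y_plain = x_normed * torch.ones(size)  # classic RMSNorm with weight == 1
assert torch.allclose(y_offset, y_plain)
```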
12 changes: 9 additions & 3 deletions scripts/convert_hf_checkpoint.py
@@ -117,6 +117,7 @@ def copy_weights_hf_llama(
qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
state_dict: Dict[str, torch.Tensor],
hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
tie_weights: bool = False,
saver: Optional[incremental_save] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
@@ -144,7 +145,7 @@
"model.layers.{}.block_sparse_moe.experts.{}.w2.weight": "transformer.h.{l}.mlp.experts.{e}.proj.weight",
}
)
elif config._mlp_class == "LLaMAMLP":
elif config._mlp_class in ("LLaMAMLP", "GemmaMLP"):
weight_map.update(
{
"model.layers.{}.mlp.gate_proj.weight": "transformer.h.{l}.mlp.fc_1.weight",
@@ -179,6 +180,10 @@ def copy_weights_hf_llama(
param = saver.store_early(param)
state_dict[to_name] = param

if tie_weights:
state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]

# convert separate q, k, v matrices into an interleaved qkv
for i, (q, k, v) in list(qkv_weights.items()):
if q is None or k is None or v is None:
# split across different .bin files
@@ -307,10 +312,11 @@ def convert_hf_checkpoint(

if "falcon" in model_name:
copy_fn = partial(copy_weights_falcon, model_name)
elif config._mlp_class in ("LLaMAMLP", "LLaMAMoE"):
elif config._mlp_class in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"):
# holder to reconstitute the split q, k, v
qkv_weights = {}
copy_fn = partial(copy_weights_hf_llama, config, qkv_weights)
tie_weights = "Gemma" in config.name
copy_fn = partial(copy_weights_hf_llama, config, qkv_weights, tie_weights=tie_weights)
elif "phi" in model_name:
# holder to reconstitute the split q, k, v
qkv_weights = {}
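Gemma's HF checkpoints tie the output head to the token embeddings, so there is no separate `lm_head.weight` tensor to copy; when `tie_weights` is set the converter points `lm_head.weight` at the already-copied `transformer.wte.weight`. A hedged sketch of the effect on a toy state dict (shapes are made up):

```python
import torch

# Toy stand-in for a partially converted Gemma checkpoint.
state_dict = {"transformer.wte.weight": torch.randn(256, 16)}

tie_weights = True  # set when "Gemma" is in the config name
if tie_weights:
    # The output head simply reuses the embedding matrix.
    state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]

assert state_dict["lm_head.weight"] is state_dict["transformer.wte.weight"]
```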
10 changes: 7 additions & 3 deletions scripts/convert_lit_checkpoint.py
@@ -106,6 +106,7 @@ def copy_weights_llama(
config: Config,
state_dict: Dict[str, torch.Tensor],
lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
untie_weights: bool = False,
saver: Optional[incremental_save] = None,
) -> None:
weight_map = {
@@ -128,7 +129,7 @@
"transformer.h.{}.mlp.experts.{}.proj.weight": "model.layers.{l}.block_sparse_moe.experts.{e}.w2.weight",
}
)
elif config._mlp_class == "LLaMAMLP":
elif config._mlp_class in ("LLaMAMLP", "GemmaMLP"):
weight_map.update(
{
"transformer.h.{}.mlp.fc_1.weight": "model.layers.{l}.mlp.gate_proj.weight",
@@ -140,6 +141,8 @@
raise NotImplementedError

for name, param in lit_weights.items():
if name == "lm_head.weight" and untie_weights:
continue
if name.endswith(".attn.attn.weight"):
from_name, l = layer_template(name, 2)
q = "model.layers.{}.self_attn.q_proj.weight".format(l)
@@ -246,8 +249,9 @@ def convert_lit_checkpoint(checkpoint_path: Path, output_path: Path, config_path

if "falcon" in config.name:
copy_fn = partial(copy_weights_falcon, config.name)
elif config._mlp_class in ("LLaMAMLP", "LLaMAMoE"):
copy_fn = partial(copy_weights_llama, config)
elif config._mlp_class in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"):
untie_weights = "Gemma" in config.name
copy_fn = partial(copy_weights_llama, config, untie_weights=untie_weights)
elif "phi" in config.name:
copy_fn = partial(copy_weights_phi, config)
else:
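The reverse conversion mirrors the tying: since HF Gemma recreates `lm_head.weight` from the embeddings, the exported checkpoint must not contain it, so the copy loop skips that tensor when `untie_weights` is set. A short sketch of the skip on toy weights:

```python
import torch

# Toy Lit-GPT weights in which lm_head was tied to the embeddings at load time.
wte = torch.randn(256, 16)
lit_weights = {"transformer.wte.weight": wte, "lm_head.weight": wte}

untie_weights = True  # set when "Gemma" is in the config name
exported = {
    name: param
    for name, param in lit_weights.items()
    if not (name == "lm_head.weight" and untie_weights)
}
assert "lm_head.weight" not in exported  # HF re-ties it to the embeddings on load
```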
9 changes: 9 additions & 0 deletions tests/test_config.py
@@ -127,3 +127,12 @@ def test_from_checkpoint(tmp_path):
assert config.name == "pythia-14m"
assert config.block_size == 24
assert config.n_layer == 2


@pytest.mark.parametrize("head_size", [None, 128])
def test_head_size(head_size):
from lit_gpt import Config

config = Config(head_size=head_size)

assert config.head_size == (head_size or config.n_embd // config.n_head)