Add Gemma checkpoint support #941

Merged: 37 commits, Feb 23, 2024
Changes from 6 commits
Commits (37)
034afac
gemma
rasbt Feb 21, 2024
d26ab02
add docs
rasbt Feb 21, 2024
5c1b029
update query head config
rasbt Feb 21, 2024
d77bd4a
apply keras geglu workaround
rasbt Feb 21, 2024
5ae2ad6
Carlos
carmocca Feb 21, 2024
e04be8a
An unfinished, but working 2b variant.
Andrei-Aksionov Feb 21, 2024
e61c449
Gemma-7b now works.
Andrei-Aksionov Feb 22, 2024
9a1f23c
add instruction-finetuned version
rasbt Feb 22, 2024
3ec329f
A test for config to check head_size
Andrei-Aksionov Feb 22, 2024
29a00c2
Update Gemma config
Andrei-Aksionov Feb 22, 2024
bf9d711
Adapter_v2 and LoRA: attn.proj size is head_size * num_heads
Andrei-Aksionov Feb 22, 2024
072c9f6
Adapter_v2 and LoRA: gemmamlp class
Andrei-Aksionov Feb 22, 2024
bed71f1
RMSNorm: unit offset is configurable
Andrei-Aksionov Feb 22, 2024
ec7d01e
Configurable wte output scaling
Andrei-Aksionov Feb 22, 2024
bd0864c
Update tests to support changes in Config class
Andrei-Aksionov Feb 22, 2024
bf8c9b5
Test for Gemma
Andrei-Aksionov Feb 22, 2024
c78bd6e
convert_hf: reuse llama copy function
Andrei-Aksionov Feb 22, 2024
50ad509
Test Gemma model: use llama weights copying
Andrei-Aksionov Feb 22, 2024
d98df3b
Update convert_lit + test
Andrei-Aksionov Feb 22, 2024
159df75
Merge branch 'main' into gemma
Andrei-Aksionov Feb 22, 2024
628c7bc
Restore accidentally deleted comment line
Andrei-Aksionov Feb 22, 2024
1a2f9f8
Prompt for Gemma it (instruct models)
Andrei-Aksionov Feb 22, 2024
d002695
RMSNorm: reduce computations when self.add_unit_offset is False
Andrei-Aksionov Feb 23, 2024
c915d57
Auto markdown formatting
Andrei-Aksionov Feb 23, 2024
32260e6
Drop `tie_weights` in convert_hf
Andrei-Aksionov Feb 23, 2024
78ad643
Comment explaining why head_size*num_head instead of n_embd
Andrei-Aksionov Feb 23, 2024
6f154ab
scale_wte_output --> scale_embeddings
Andrei-Aksionov Feb 23, 2024
cfe68bb
Config: drop `self.rmsnorm_add_unit_offset`
Andrei-Aksionov Feb 23, 2024
57db710
Comment why do we need a unit offset in RMSNorm
Andrei-Aksionov Feb 23, 2024
4c44085
Bump up min version of transformers in github CI
Andrei-Aksionov Feb 23, 2024
c1dc9d2
Merge branch 'main' into gemma
Andrei-Aksionov Feb 23, 2024
0fe9b3f
Merge branch 'main' into gemma
Andrei-Aksionov Feb 23, 2024
e9b0c5a
Update convert_hf test
Andrei-Aksionov Feb 23, 2024
d17bb34
Update lit_gpt/adapter_v2.py
Andrei-Aksionov Feb 23, 2024
86b5b7a
Update lit_gpt/lora.py
Andrei-Aksionov Feb 23, 2024
8854d14
Update lit_gpt/model.py
Andrei-Aksionov Feb 23, 2024
d219965
Bump up min transformers version in Azure workflow
Andrei-Aksionov Feb 23, 2024
3 changes: 2 additions & 1 deletion README.md
@@ -33,6 +33,7 @@ Supports the following popular model checkpoints:
| [Falcon](tutorials/download_falcon.md) by TII UAE | 7B, 40B, 180B | [TII 2023](https://falconllm.tii.ae) |
| [FreeWilly2](tutorials/download_freewilly_2.md) (Stable Beluga 2) by Stability AI | 70B | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) |
| [Function Calling Llama 2](tutorials/download_function_calling_llama_2.md) by Trelis | 7B | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) |
| [Gemma](tutorials/download_gemma.md) by Google | 2B, 7B | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) |
| [Llama 2](tutorials/download_llama_2.md) by Meta AI | 7B, 13B, 70B | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) |
| [LongChat](tutorials/download_longchat.md) by LMSYS | 7B, 13B | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) |
| [Mistral and Mixtral](tutorials/download_mistral.md) by Mistral AI | 7B | [Mistral website](https://mistral.ai/) |
@@ -46,7 +47,7 @@ Supports the following popular model checkpoints:
| [StableLM](tutorials/download_stablelm.md) by Stability AI | 3B, 7B | [Stability AI 2023](https://github.com/Stability-AI/StableLM) |
| [StableLM Zephyr](tutorials/download_stablelm.md) by Stability AI | 3B | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
| [TinyLlama](tutorials/download_tinyllama.md) by Zhang et al. | 1.1B | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama) |
| [Vicuna](tutorials/download_vicuna.md) by LMSYS | 7B, 13B, 33B | [Li et al. 2023](https://lmsys.org/blog/2023-03-30-vicuna/) |
| [Vicuna](tutorials/download_vicuna.md) by LMSYS | 7B, 13B, 33B | [Li et al. 2023](https://lmsys.org/blog/2023-03-30-vicuna/) |

This implementation extends on [Lit-LLaMA](https://github.com/lightning-AI/lit-llama) and [nanoGPT](https://github.com/karpathy/nanoGPT), and it's **powered by [Lightning Fabric](https://lightning.ai/docs/fabric/stable/) ⚡**.

41 changes: 41 additions & 0 deletions lit_gpt/config.py
@@ -781,6 +781,47 @@ def norm_class(self) -> Type:
configs.append(copy)


###############
# Google Gemma
###############
gemma = [
# https://huggingface.co/google/gemma-7b/blob/main/config.json
dict(
name="Gemma-7b-hf",
hf_config=dict(org="google", name="gemma-7b"),
vocab_size=256000,
padding_multiple=64,
n_embd=3072,
n_layer=28,
n_head=16,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
_mlp_class="GemmaMLP",
intermediate_size=24576,
),
# https://huggingface.co/google/gemma-2b/blob/main/config.json
dict(
name="Gemma-2b-hf",
hf_config=dict(org="google", name="gemma-2b"),
vocab_size=256000,
padding_multiple=64,
n_embd=2048,
n_layer=18,
n_head=8,
n_query_groups=1,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
_mlp_class="GemmaMLP",
intermediate_size=16384,
),
]
configs.extend(gemma)


##########################
# Stability AI FreeWilly2
##########################
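For orientation (not part of the diff): the two Gemma entries added to `lit_gpt/config.py` above can be looked up by name like any other config. A minimal sketch, assuming a checkout with this PR applied and the existing `Config.from_name` helper:

```python
# Illustrative sketch only: look up the new Gemma configs by name.
from lit_gpt.config import Config

config = Config.from_name("Gemma-2b-hf")
print(config.n_embd, config.n_layer, config.n_head)  # 2048 18 8
print(config._mlp_class, config.intermediate_size)   # GemmaMLP 16384
print(config.n_query_groups)                         # 1 (multi-query attention)
```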
14 changes: 14 additions & 0 deletions lit_gpt/model.py
@@ -87,6 +87,9 @@ def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None) -
mask = None

x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
# NOTE: Gemma scales the token embeddings by sqrt(n_embd)
x = x * (self.config.n_embd**0.5)

for block in self.transformer.h:
x = block(x, cos, sin, mask, input_pos)
x = self.transformer.ln_f(x)
@@ -290,6 +293,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.proj(x)


class GemmaMLP(LLaMAMLP):
def forward(self, x: torch.Tensor) -> torch.Tensor:
# the intermediate size for fc_{1,2} is halved when compared to LLaMAMLP, thus implementing GeGLU
x_fc_1 = self.fc_1(x)
x_fc_2 = self.fc_2(x)
# x = torch.nn.functional.gelu(x_fc_1, approximate="tanh") * x_fc_2
# NOTE: the HF implementation does not use the tanh approximation
x = torch.nn.functional.gelu(x_fc_1) * x_fc_2
return self.proj(x)


class LLaMAMoE(nn.Module):
def __init__(self, config: Config) -> None:
super().__init__()
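For reference, the `GemmaMLP` class above implements GeGLU: two parallel up-projections, one passed through a GELU and used to gate the other, followed by the usual down-projection. This complements the `sqrt(n_embd)` embedding scaling added in the first hunk of this file. A standalone sketch of the GeGLU idea (illustrative only, using plain `nn.Linear` layers rather than the repo's classes):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GeGLUSketch(nn.Module):
    """Illustrative GeGLU feed-forward block (not the repo's GemmaMLP)."""
    def __init__(self, n_embd: int, intermediate_size: int) -> None:
        super().__init__()
        self.fc_1 = nn.Linear(n_embd, intermediate_size, bias=False)  # gate branch
        self.fc_2 = nn.Linear(n_embd, intermediate_size, bias=False)  # value branch
        self.proj = nn.Linear(intermediate_size, n_embd, bias=False)  # down-projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # GeGLU: gelu(x @ W1) * (x @ W2), then project back to n_embd
        return self.proj(F.gelu(self.fc_1(x)) * self.fc_2(x))

# Small dims for illustration; Gemma-2b uses n_embd=2048, intermediate_size=16384.
x = torch.randn(1, 4, 64)
print(GeGLUSketch(64, 256)(x).shape)  # torch.Size([1, 4, 64])
```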
10 changes: 8 additions & 2 deletions lit_gpt/rmsnorm.py
@@ -10,19 +10,25 @@ class RMSNorm(torch.nn.Module):
https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
"""

def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
# TODO: make `add_unit_offset` configurable via the config
# def __init__(self, size: int, dim: int = -1, eps: float = 1e-5, add_unit_offset: bool = True) -> None:
def __init__(self, size: int, dim: int = -1, eps: float = 1e-6, add_unit_offset: bool = True) -> None:
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(size))
self.eps = eps
self.dim = dim
self.add_unit_offset = add_unit_offset

# NOTE: output now closer to the official gemma implementation
# https://github.com/google/gemma_pytorch/blob/ca890c7abaa41ce7ab0eeda9aa8a52c0796b3a16/gemma/model.py#L170-L179
def forward(self, x: torch.Tensor) -> torch.Tensor:
dtype = x.dtype
x = x.float()
# NOTE: the original RMSNorm paper implementation is not equivalent
norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
x_normed = x * torch.rsqrt(norm_x + self.eps)
return (self.weight * x_normed).to(dtype=dtype)
x_normed = x_normed.to(dtype=dtype)
return x_normed * (self.add_unit_offset + self.weight)

def reset_parameters(self) -> None:
torch.nn.init.ones_(self.weight)
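In other words, with `add_unit_offset=True` the layer scales by `(1 + weight)` rather than `weight`, matching the official Gemma implementation linked above, which stores the learned scale as an offset from one. A functional sketch of the behaviour (illustrative only, not the repo's class):

```python
import torch

def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6,
            add_unit_offset: bool = True) -> torch.Tensor:
    # Functional sketch of the RMSNorm variant above.
    dtype = x.dtype
    x = x.float()
    norm_x = torch.mean(x * x, dim=-1, keepdim=True)
    x_normed = (x * torch.rsqrt(norm_x + eps)).to(dtype=dtype)
    # Gemma checkpoints store the scale relative to 1, hence the unit offset.
    return x_normed * (int(add_unit_offset) + weight)

x = torch.randn(2, 8)
w = torch.zeros(8)  # a zero weight still yields an identity scale when the offset is applied
assert torch.allclose(rmsnorm(x, w), rmsnorm(x, torch.ones(8), add_unit_offset=False))
```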
76 changes: 75 additions & 1 deletion scripts/convert_hf_checkpoint.py
@@ -171,6 +171,78 @@ def copy_weights_hf_llama(
param = saver.store_early(param)
state_dict[to_name] = param

# convert separate q, k, v matrices into an interleaved qkv
for i, (q, k, v) in list(qkv_weights.items()):
if q is None or k is None or v is None:
# split across different .bin files
continue
q = load_param(q, f"layer {i} q", dtype)
k = load_param(k, f"layer {i} k", dtype)
v = load_param(v, f"layer {i} v", dtype)
q_per_kv = config.n_head // config.n_query_groups
qs = torch.split(q, config.head_size * q_per_kv)
ks = torch.split(k, config.head_size)
vs = torch.split(v, config.head_size)
cycled = [t for group in zip(qs, ks, vs) for t in group]
qkv = torch.cat(cycled)
state_dict[f"transformer.h.{i}.attn.attn.weight"] = qkv
del qkv_weights[i]


# TODO: we could probably just reuse the Llama weights copy function
def copy_weights_hf_gemma(
config: Config,
qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
state_dict: Dict[str, torch.Tensor],
hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
saver: Optional[incremental_save] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
weight_map = {
"model.embed_tokens.weight": "transformer.wte.weight",
"model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight",
"model.layers.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias",
"model.layers.{}.self_attn.q_proj.weight": None,
"model.layers.{}.self_attn.k_proj.weight": None,
"model.layers.{}.self_attn.v_proj.weight": None,
"model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight",
"model.layers.{}.self_attn.rotary_emb.inv_freq": None,
"model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight",
"model.layers.{}.post_attention_layernorm.bias": "transformer.h.{}.norm_2.bias",
"model.norm.weight": "transformer.ln_f.weight",
"model.norm.bias": "transformer.ln_f.bias",
"model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight",
"model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight",
"model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight",
}
if config.bias:
raise NotImplementedError("bias halving not implemented")

for name, param in hf_weights.items():
if "model.layers" in name:
from_name, l = layer_template(name, 2)
qkv = qkv_weights.setdefault(l, [None, None, None])
if "q_proj" in name:
qkv[0] = param
elif "k_proj" in name:
qkv[1] = param
elif "v_proj" in name:
qkv[2] = param
to_name = weight_map[from_name]
if to_name is None:
continue
to_name = to_name.format(l)
else:
to_name = weight_map[name]
param = load_param(param, name, dtype)
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param

# weight tying
state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]

# convert separate q, k, v matrices into an interleaved qkv
for i, (q, k, v) in list(qkv_weights.items()):
if q is None or k is None or v is None:
# split across different .bin files
@@ -304,9 +376,11 @@ def convert_hf_checkpoint(
qkv_weights = {}
copy_fn = partial(copy_weights_hf_llama, config, qkv_weights)
elif "phi" in model_name:
# holder to reconstitute the split q, k, v
qkv_weights = {}
copy_fn = partial(copy_weights_phi, config, qkv_weights)
elif config._mlp_class == "GemmaMLP":
qkv_weights = {}
copy_fn = partial(copy_weights_hf_gemma, config, qkv_weights)
else:
copy_fn = copy_weights_gpt_neox

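Both copy functions above reassemble the per-layer HF `q_proj`/`k_proj`/`v_proj` matrices into the single interleaved `attn.attn.weight` tensor that lit-gpt expects. A shape-only sketch of that interleaving for the 2B config (`n_head=8`, `n_query_groups=1`, i.e. one k/v head shared by all query heads); the `head_size` value here is an assumption chosen for illustration:

```python
import torch

n_embd, n_head, n_query_groups = 2048, 8, 1  # Gemma-2b config from this PR
head_size = n_embd // n_head                 # assumed 256 for this sketch

q = torch.randn(n_head * head_size, n_embd)          # HF q_proj.weight
k = torch.randn(n_query_groups * head_size, n_embd)  # HF k_proj.weight (single KV head)
v = torch.randn(n_query_groups * head_size, n_embd)  # HF v_proj.weight

q_per_kv = n_head // n_query_groups
qs = torch.split(q, head_size * q_per_kv)  # one block of query heads per KV group
ks = torch.split(k, head_size)
vs = torch.split(v, head_size)
qkv = torch.cat([t for group in zip(qs, ks, vs) for t in group])
print(qkv.shape)  # torch.Size([2560, 2048]) -> transformer.h.{i}.attn.attn.weight
```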
41 changes: 41 additions & 0 deletions tutorials/download_gemma.md
@@ -0,0 +1,41 @@
## Download [Gemma](https://blog.google/technology/developers/gemma-open-models/) weights

Google developed and publicly released the Gemma large language models (LLMs), a collection of pretrained models in 2B and 7B parameter sizes that are based on the Gemini architecture.

For more information, please see the [technical report](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf).


To see all the available checkpoints, run:

```bash
python scripts/download.py | grep gemma
```

which will print

```text
google/gemma-2b
google/gemma-7b
```

In order to use a specific checkpoint, for instance [gemma-2b](https://huggingface.co/google/gemma-2b), download the weights and convert the checkpoint to the lit-gpt format.

This requires that you've been granted access to the weights on the HuggingFace hub, which you can request by following the steps at <https://huggingface.co/google/gemma-2b>.
Once access is granted, you can find your HF hub token at <https://huggingface.co/settings/tokens>.

```bash
pip install 'huggingface_hub[hf_transfer] @ git+https://github.com/huggingface/huggingface_hub'

python scripts/download.py --repo_id google/gemma-2b --access_token your_hf_token --from_safetensors true

python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/google/gemma-2b
```

By default, the `convert_hf_checkpoint` step will use the data type of the HF checkpoint's parameters. In cases where RAM
or disk size is constrained, it might be useful to pass `--dtype bfloat16` to convert all parameters into this smaller precision before continuing.
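
For example, the conversion step from above might then be invoked as follows:

```bash
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/google/gemma-2b --dtype bfloat16
```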

You're done! To execute the model just run:

```bash
python chat/base.py --checkpoint_dir checkpoints/google/gemma-2b
```