From 034afacae63c402007efc04d65ffe072dba48b05 Mon Sep 17 00:00:00 2001 From: rasbt Date: Wed, 21 Feb 2024 15:53:12 +0000 Subject: [PATCH 01/34] gemma --- lit_gpt/config.py | 43 ++++++++++++++++++++++++++++++++ lit_gpt/model.py | 34 +++++++++++++++++++++++++ scripts/convert_hf_checkpoint.py | 9 +++++-- 3 files changed, 84 insertions(+), 2 deletions(-) diff --git a/lit_gpt/config.py b/lit_gpt/config.py index 439bca71ca..399b24c280 100644 --- a/lit_gpt/config.py +++ b/lit_gpt/config.py @@ -781,6 +781,49 @@ def norm_class(self) -> Type: configs.append(copy) +############### +# Google Gemma +############### +gemma = [ + # https://huggingface.co/google/gemma-7b/blob/main/config.json + dict( + name="Gemma-7b-hf", + hf_config=dict(org="google", name="gemma-7b"), + vocab_size=256000, + padding_multiple=64, + n_embd=3072, + n_layer=28, + n_head=16, + n_query_groups=1, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="GemmaMLP", + intermediate_size=24576, + ), + # https://huggingface.co/google/gemma-2b/blob/main/config.json + dict( + name="Gemma-2b-hf", + hf_config=dict(org="google", name="gemma-2b"), + vocab_size=256000, + padding_multiple=64, + n_embd=2048, + n_layer=18, + n_head=8, + n_query_groups=1, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="GemmaMLP", + intermediate_size=16384, + ), +] +configs.extend(gemma) + + + ########################## # Stability AI FreeWilly2 ########################## diff --git a/lit_gpt/model.py b/lit_gpt/model.py index 227251a863..402b2c0d72 100644 --- a/lit_gpt/model.py +++ b/lit_gpt/model.py @@ -11,6 +11,7 @@ import torch import torch.nn as nn +from torch import Tensor from typing_extensions import Self from lit_gpt.config import Config @@ -290,6 +291,39 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.proj(x) +class GEGLU(nn.Module): + """ + Source: https://github.com/pfnet-research/deep-table/blob/237c8be8a405349ce6ab78075234c60d9bfe60b7/deep_table/nn/layers/activation.py#L22 + License: MIT, https://github.com/pfnet-research/deep-table/blob/237c8be8a405349ce6ab78075234c60d9bfe60b7/LICENSE + References: + Shazeer et al., "GLU Variants Improve Transformer," 2020. 
+ https://arxiv.org/abs/2002.05202 + """ + + def geglu(self, x: Tensor) -> Tensor: + assert x.shape[-1] % 2 == 0 + a, b = x.chunk(2, dim=-1) + return a * torch.nn.functional.gelu(b) + + def forward(self, x: Tensor) -> Tensor: + return self.geglu(x) + + +class GemmaMLP(nn.Module): + def __init__(self, config: Config) -> None: + super().__init__() + self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) + self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) + self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias) + self.geglu = torch.nn.GELU(approximate=True) # GEGLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_fc_1 = self.fc_1(x) + x_fc_2 = self.fc_2(x) + x = self.geglu(x_fc_1) * x_fc_2 + return self.proj(x) + + class LLaMAMoE(nn.Module): def __init__(self, config: Config) -> None: super().__init__() diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py index 898e585a0b..0332858932 100644 --- a/scripts/convert_hf_checkpoint.py +++ b/scripts/convert_hf_checkpoint.py @@ -138,7 +138,7 @@ def copy_weights_hf_llama( "model.layers.{}.block_sparse_moe.experts.{}.w3.weight": "transformer.h.{l}.mlp.experts.{e}.fc_2.weight", "model.layers.{}.block_sparse_moe.experts.{}.w2.weight": "transformer.h.{l}.mlp.experts.{e}.proj.weight", }) - elif config._mlp_class == "LLaMAMLP": + elif config._mlp_class in ("LLaMAMLP", "GemmaMLP"): weight_map.update({ "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{l}.mlp.fc_1.weight", "model.layers.{}.mlp.up_proj.weight": "transformer.h.{l}.mlp.fc_2.weight", @@ -171,6 +171,11 @@ def copy_weights_hf_llama( param = saver.store_early(param) state_dict[to_name] = param + # If model uses weight tying: + if "lm_head.weight" not in state_dict.keys(): + state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"] + + for i, (q, k, v) in list(qkv_weights.items()): if q is None or k is None or v is None: # split across different .bin files @@ -299,7 +304,7 @@ def convert_hf_checkpoint( if "falcon" in model_name: copy_fn = partial(copy_weights_falcon, model_name) - elif config._mlp_class in ("LLaMAMLP", "LLaMAMoE"): + elif config._mlp_class in ("LLaMAMLP", "LLaMAMoE", "GemmaMLP"): # holder to reconstitute the split q, k, v qkv_weights = {} copy_fn = partial(copy_weights_hf_llama, config, qkv_weights) From d26ab02e8a2d2f6706ca4ca47e002a87c1d575f0 Mon Sep 17 00:00:00 2001 From: rasbt Date: Wed, 21 Feb 2024 16:08:26 +0000 Subject: [PATCH 02/34] add docs --- README.md | 1 + tutorials/download_gemma.md | 41 +++++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) create mode 100644 tutorials/download_gemma.md diff --git a/README.md b/README.md index e7ff96f133..206b01ca8d 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,7 @@ Supports the following popular model checkpoints: | [Falcon](tutorials/download_falcon.md) by TII UAE | 7B, 40B, 180B | [TII 2023](https://falconllm.tii.ae) | | [FreeWilly2](tutorials/download_freewilly_2.md) (Stable Beluga 2) by Stability AI | 70B | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) | | [Function Calling Llama 2](tutorials/download_function_calling_llama_2.md) by Trelis | 7B | [Trelis et al. 
2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) | +| [Gemma](tutorials/download_gemma.md) by Google | 2B, 7B | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) | | [Llama 2](tutorials/download_llama_2.md) by Meta AI | 7B, 13B, 70B | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) | | [LongChat](tutorials/download_longchat.md) by LMSYS | 7B, 13B | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) | | [Mistral and Mixtral](tutorials/download_mistral.md) by Mistral AI | 7B | [Mistral website](https://mistral.ai/) | diff --git a/tutorials/download_gemma.md b/tutorials/download_gemma.md new file mode 100644 index 0000000000..a344051809 --- /dev/null +++ b/tutorials/download_gemma.md @@ -0,0 +1,41 @@ +## Download [Gemma](https://blog.google/technology/developers/gemma-open-models/) weights + +Google developed and publicly released the Gemma large language models (LLMs), a collection of pretrained models in 2B and 7B parameter size that are based on the Gemini architecture. + +For more information, please see the [technical report](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf). + + +To see all the available checkpoints, run: + +```bash +python scripts/download.py | grep gemma +``` + +which will print + +```text +google/gemma-2b +google/gemma-7b +``` + +In order to use a specific checkpoint, for instance [gemma-2b](https://huggingface.co/google/gemma-2b), download the weights and convert the checkpoint to the lit-gpt format. + +This requires that you've been granted access to the weights on the HuggingFace hub. You can do so by following the steps at . +After access is granted, you can find your HF hub token in . + +```bash +pip install 'huggingface_hub[hf_transfer] @ git+https://github.com/huggingface/huggingface_hub' + +python scripts/download.py --repo_id checkpoints/google/gemma-2b --access_token your_hf_token --use_safetensors true + +python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/google/gemma-2b +``` + +By default, the `convert_hf_checkpoint` step will use the data type of the HF checkpoint's parameters. In cases where RAM +or disk size is constrained, it might be useful to pass `--dtype bfloat16` to convert all parameters into this smaller precision before continuing. + +You're done! 
To execute the model just run: + +```bash +python chat/base.py --checkpoint_dir checkpoints/google/gemma-2b +``` From 5c1b02954b7c19e2a0674dffbf3826f419a8d658 Mon Sep 17 00:00:00 2001 From: rasbt Date: Wed, 21 Feb 2024 16:31:30 +0000 Subject: [PATCH 03/34] update query head config --- lit_gpt/config.py | 2 +- lit_gpt/model.py | 22 ++-------------------- 2 files changed, 3 insertions(+), 21 deletions(-) diff --git a/lit_gpt/config.py b/lit_gpt/config.py index 399b24c280..d6263e3e2d 100644 --- a/lit_gpt/config.py +++ b/lit_gpt/config.py @@ -794,7 +794,7 @@ def norm_class(self) -> Type: n_embd=3072, n_layer=28, n_head=16, - n_query_groups=1, + n_query_groups=16, rotary_percentage=1.0, parallel_residual=False, bias=False, diff --git a/lit_gpt/model.py b/lit_gpt/model.py index 402b2c0d72..3c958bc94c 100644 --- a/lit_gpt/model.py +++ b/lit_gpt/model.py @@ -291,31 +291,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.proj(x) -class GEGLU(nn.Module): - """ - Source: https://github.com/pfnet-research/deep-table/blob/237c8be8a405349ce6ab78075234c60d9bfe60b7/deep_table/nn/layers/activation.py#L22 - License: MIT, https://github.com/pfnet-research/deep-table/blob/237c8be8a405349ce6ab78075234c60d9bfe60b7/LICENSE - References: - Shazeer et al., "GLU Variants Improve Transformer," 2020. - https://arxiv.org/abs/2002.05202 - """ - - def geglu(self, x: Tensor) -> Tensor: - assert x.shape[-1] % 2 == 0 - a, b = x.chunk(2, dim=-1) - return a * torch.nn.functional.gelu(b) - - def forward(self, x: Tensor) -> Tensor: - return self.geglu(x) - - class GemmaMLP(nn.Module): def __init__(self, config: Config) -> None: super().__init__() self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) - self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) + self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size , bias=config.bias) self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias) - self.geglu = torch.nn.GELU(approximate=True) # GEGLU() + self.geglu = torch.nn.GELU(approximate="tanh") def forward(self, x: torch.Tensor) -> torch.Tensor: x_fc_1 = self.fc_1(x) From d77bd4a5ddd7ec0299743753ab7d0a5e674351fc Mon Sep 17 00:00:00 2001 From: rasbt Date: Wed, 21 Feb 2024 16:55:37 +0000 Subject: [PATCH 04/34] apply keras geglu workaround --- lit_gpt/config.py | 4 ++-- lit_gpt/model.py | 12 +++--------- scripts/convert_hf_checkpoint.py | 8 ++++++++ 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/lit_gpt/config.py b/lit_gpt/config.py index d6263e3e2d..a4184c06ff 100644 --- a/lit_gpt/config.py +++ b/lit_gpt/config.py @@ -800,7 +800,7 @@ def norm_class(self) -> Type: bias=False, _norm_class="RMSNorm", _mlp_class="GemmaMLP", - intermediate_size=24576, + intermediate_size=24576 // 2, ), # https://huggingface.co/google/gemma-2b/blob/main/config.json dict( @@ -817,7 +817,7 @@ def norm_class(self) -> Type: bias=False, _norm_class="RMSNorm", _mlp_class="GemmaMLP", - intermediate_size=16384, + intermediate_size=16384 // 2, ), ] configs.extend(gemma) diff --git a/lit_gpt/model.py b/lit_gpt/model.py index 3c958bc94c..a89f0a169c 100644 --- a/lit_gpt/model.py +++ b/lit_gpt/model.py @@ -291,18 +291,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.proj(x) -class GemmaMLP(nn.Module): - def __init__(self, config: Config) -> None: - super().__init__() - self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) - self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size , 
bias=config.bias) - self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias) - self.geglu = torch.nn.GELU(approximate="tanh") - +class GemmaMLP(LLaMAMLP): def forward(self, x: torch.Tensor) -> torch.Tensor: + # the intermediate size for fc_{1,2} is halved when compared to LLaMAMLP, thus implementing GeGLU x_fc_1 = self.fc_1(x) x_fc_2 = self.fc_2(x) - x = self.geglu(x_fc_1) * x_fc_2 + x = torch.nn.functional.gelu(x_fc_1, approximate="tanh") * x_fc_2 return self.proj(x) diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py index 0332858932..1d0180f633 100644 --- a/scripts/convert_hf_checkpoint.py +++ b/scripts/convert_hf_checkpoint.py @@ -167,6 +167,14 @@ def copy_weights_hf_llama( else: to_name = weight_map[name] param = load_param(param, name, dtype) + + if config._mlp_class == "GemmaMLP": + # select only half the FC layers to match the Gemma GeGLU workaround in the original Keras implementation + if "mlp.gate_proj" in name or "mlp.up_proj.weight" in name: + param = param[:param.size(0) // 2] + elif "mlp.down_proj" in name: + param = param[:, :param.size(1) // 2] + if saver is not None: param = saver.store_early(param) state_dict[to_name] = param From 5ae2ad6efc12a9ea402e78e36ccdb854773fe841 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 21 Feb 2024 18:09:25 +0100 Subject: [PATCH 05/34] Carlos --- README.md | 2 +- lit_gpt/config.py | 2 - lit_gpt/model.py | 1 - scripts/convert_hf_checkpoint.py | 95 +++++++++++++++++++++++++++----- 4 files changed, 82 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 206b01ca8d..f66c8d5f7a 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,7 @@ Supports the following popular model checkpoints: | [StableLM](tutorials/download_stablelm.md) by Stability AI | 3B, 7B | [Stability AI 2023](https://github.com/Stability-AI/StableLM) | | [StableLM Zephyr](tutorials/download_stablelm.md) by Stability AI | 3B | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | | [TinyLlama](tutorials/download_tinyllama.md) by Zhang et al. | 1.1B | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama) | -| [Vicuna](tutorials/download_vicuna.md) by LMSYS | 7B, 13B, 33B | [Li et al. 2023](https://lmsys.org/blog/2023-03-30-vicuna/) | +| [Vicuna](tutorials/download_vicuna.md) by LMSYS | 7B, 13B, 33B | [Li et al. 2023](https://lmsys.org/blog/2023-03-30-vicuna/) | This implementation extends on [Lit-LLaMA](https://github.com/lightning-AI/lit-llama) and [nanoGPT](https://github.com/karpathy/nanoGPT), and it's **powered by [Lightning Fabric](https://lightning.ai/docs/fabric/stable/) ⚡**. 
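
The GeGLU handling that PATCH 01-06 iterate on boils down to two equivalent ways of storing the same MLP: a fused projection of width 2 × intermediate that is split before gating, versus the separate gate_proj/up_proj layout used by the HF Gemma checkpoints. The sketch below is illustrative only and not part of the patch series; the class names are invented, and bias-free linears are assumed as in the Gemma configs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FusedGEGLU(nn.Module):
    """One projection of width 2 * intermediate, split into gate and value halves."""

    def __init__(self, n_embd: int, intermediate: int) -> None:
        super().__init__()
        self.fc = nn.Linear(n_embd, 2 * intermediate, bias=False)
        self.proj = nn.Linear(intermediate, n_embd, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a, b = self.fc(x).chunk(2, dim=-1)  # gate half and value half
        return self.proj(F.gelu(a) * b)


class SplitGEGLU(nn.Module):
    """Separate gate_proj/up_proj of width intermediate, as in the HF checkpoint layout."""

    def __init__(self, n_embd: int, intermediate: int) -> None:
        super().__init__()
        self.fc_1 = nn.Linear(n_embd, intermediate, bias=False)  # gate_proj
        self.fc_2 = nn.Linear(n_embd, intermediate, bias=False)  # up_proj
        self.proj = nn.Linear(intermediate, n_embd, bias=False)  # down_proj

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(F.gelu(self.fc_1(x)) * self.fc_2(x))
```

The two modules are numerically identical when `fc_1`/`fc_2` hold the two halves of the fused weight, which is why the intermediate-size halving introduced above (and revisited in later patches) hinges entirely on which layout the source checkpoint actually uses.
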
diff --git a/lit_gpt/config.py b/lit_gpt/config.py index a4184c06ff..152a426041 100644 --- a/lit_gpt/config.py +++ b/lit_gpt/config.py @@ -794,7 +794,6 @@ def norm_class(self) -> Type: n_embd=3072, n_layer=28, n_head=16, - n_query_groups=16, rotary_percentage=1.0, parallel_residual=False, bias=False, @@ -823,7 +822,6 @@ def norm_class(self) -> Type: configs.extend(gemma) - ########################## # Stability AI FreeWilly2 ########################## diff --git a/lit_gpt/model.py b/lit_gpt/model.py index a89f0a169c..035b73151c 100644 --- a/lit_gpt/model.py +++ b/lit_gpt/model.py @@ -11,7 +11,6 @@ import torch import torch.nn as nn -from torch import Tensor from typing_extensions import Self from lit_gpt.config import Config diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py index 1d0180f633..e93796b9d8 100644 --- a/scripts/convert_hf_checkpoint.py +++ b/scripts/convert_hf_checkpoint.py @@ -138,7 +138,7 @@ def copy_weights_hf_llama( "model.layers.{}.block_sparse_moe.experts.{}.w3.weight": "transformer.h.{l}.mlp.experts.{e}.fc_2.weight", "model.layers.{}.block_sparse_moe.experts.{}.w2.weight": "transformer.h.{l}.mlp.experts.{e}.proj.weight", }) - elif config._mlp_class in ("LLaMAMLP", "GemmaMLP"): + elif config._mlp_class == "LLaMAMLP": weight_map.update({ "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{l}.mlp.fc_1.weight", "model.layers.{}.mlp.up_proj.weight": "transformer.h.{l}.mlp.fc_2.weight", @@ -167,23 +167,88 @@ def copy_weights_hf_llama( else: to_name = weight_map[name] param = load_param(param, name, dtype) - - if config._mlp_class == "GemmaMLP": - # select only half the FC layers to match the Gemma GeGLU workaround in the original Keras implementation - if "mlp.gate_proj" in name or "mlp.up_proj.weight" in name: - param = param[:param.size(0) // 2] - elif "mlp.down_proj" in name: - param = param[:, :param.size(1) // 2] - if saver is not None: param = saver.store_early(param) state_dict[to_name] = param - # If model uses weight tying: - if "lm_head.weight" not in state_dict.keys(): - state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"] + # convert separate q, k, v matrices into an interleaved qkv + for i, (q, k, v) in list(qkv_weights.items()): + if q is None or k is None or v is None: + # split across different .bin files + continue + q = load_param(q, f"layer {i} q", dtype) + k = load_param(k, f"layer {i} k", dtype) + v = load_param(v, f"layer {i} v", dtype) + q_per_kv = config.n_head // config.n_query_groups + qs = torch.split(q, config.head_size * q_per_kv) + ks = torch.split(k, config.head_size) + vs = torch.split(v, config.head_size) + cycled = [t for group in zip(qs, ks, vs) for t in group] + qkv = torch.cat(cycled) + state_dict[f"transformer.h.{i}.attn.attn.weight"] = qkv + del qkv_weights[i] +def copy_weights_hf_gemma( + config: Config, + qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]], + state_dict: Dict[str, torch.Tensor], + hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], + saver: Optional[incremental_save] = None, + dtype: Optional[torch.dtype] = None, +) -> None: + weight_map = { + "model.embed_tokens.weight": "transformer.wte.weight", + "model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", + "model.layers.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", + "model.layers.{}.self_attn.q_proj.weight": None, + "model.layers.{}.self_attn.k_proj.weight": None, + "model.layers.{}.self_attn.v_proj.weight": None, + "model.layers.{}.self_attn.o_proj.weight": 
"transformer.h.{}.attn.proj.weight", + "model.layers.{}.self_attn.rotary_emb.inv_freq": None, + "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight", + "model.layers.{}.post_attention_layernorm.bias": "transformer.h.{}.norm_2.bias", + "model.norm.weight": "transformer.ln_f.weight", + "model.norm.bias": "transformer.ln_f.bias", + "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight", + "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight", + "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight", + } + if config.bias: + raise NotImplementedError("bias halving not implemented") + + for name, param in hf_weights.items(): + if "model.layers" in name: + from_name, l = layer_template(name, 2) + qkv = qkv_weights.setdefault(l, [None, None, None]) + if "q_proj" in name: + qkv[0] = param + elif "k_proj" in name: + qkv[1] = param + elif "v_proj" in name: + qkv[2] = param + to_name = weight_map[from_name] + if to_name is None: + continue + to_name = to_name.format(l) + else: + to_name = weight_map[name] + param = load_param(param, name, dtype) + + # halve the mlp linears: they are double the size in the HF implementation compared to Keras + halving_name_2_dim = {"mlp.fc_1": 0, "mlp.fc_2": 0, "mlp.proj": 1} + for key, dim in halving_name_2_dim.items(): + if key in to_name: + param, _ = torch.chunk(param, 2, dim=dim) + + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param + + # weight tying + state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"] + + # convert separate q, k, v matrices into an interleaved qkv for i, (q, k, v) in list(qkv_weights.items()): if q is None or k is None or v is None: # split across different .bin files @@ -312,14 +377,16 @@ def convert_hf_checkpoint( if "falcon" in model_name: copy_fn = partial(copy_weights_falcon, model_name) - elif config._mlp_class in ("LLaMAMLP", "LLaMAMoE", "GemmaMLP"): + elif config._mlp_class in ("LLaMAMLP", "LLaMAMoE"): # holder to reconstitute the split q, k, v qkv_weights = {} copy_fn = partial(copy_weights_hf_llama, config, qkv_weights) elif "phi" in model_name: - # holder to reconstitute the split q, k, v qkv_weights = {} copy_fn = partial(copy_weights_phi, config, qkv_weights) + elif config._mlp_class == "GemmaMLP": + qkv_weights = {} + copy_fn = partial(copy_weights_hf_gemma, config, qkv_weights) else: copy_fn = copy_weights_gpt_neox From e04be8a2384d6719f1696274c5dc99748ae77066 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Wed, 21 Feb 2024 20:55:10 +0000 Subject: [PATCH 06/34] An unfinished, but working 2b variant. 
--- lit_gpt/config.py | 4 ++-- lit_gpt/model.py | 7 ++++++- lit_gpt/rmsnorm.py | 10 ++++++++-- scripts/convert_hf_checkpoint.py | 8 +------- tutorials/download_gemma.md | 2 +- 5 files changed, 18 insertions(+), 13 deletions(-) diff --git a/lit_gpt/config.py b/lit_gpt/config.py index 152a426041..4b2a96fed8 100644 --- a/lit_gpt/config.py +++ b/lit_gpt/config.py @@ -799,7 +799,7 @@ def norm_class(self) -> Type: bias=False, _norm_class="RMSNorm", _mlp_class="GemmaMLP", - intermediate_size=24576 // 2, + intermediate_size=24576, ), # https://huggingface.co/google/gemma-2b/blob/main/config.json dict( @@ -816,7 +816,7 @@ def norm_class(self) -> Type: bias=False, _norm_class="RMSNorm", _mlp_class="GemmaMLP", - intermediate_size=16384 // 2, + intermediate_size=16384, ), ] configs.extend(gemma) diff --git a/lit_gpt/model.py b/lit_gpt/model.py index 035b73151c..197fb2cce1 100644 --- a/lit_gpt/model.py +++ b/lit_gpt/model.py @@ -87,6 +87,9 @@ def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None) - mask = None x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + # NOTE: this is a secret sauce (Gemma) + x = x * (self.config.n_embd**0.5) + for block in self.transformer.h: x = block(x, cos, sin, mask, input_pos) x = self.transformer.ln_f(x) @@ -295,7 +298,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # the intermediate size for fc_{1,2} is halved when compared to LLaMAMLP, thus implementing GeGLU x_fc_1 = self.fc_1(x) x_fc_2 = self.fc_2(x) - x = torch.nn.functional.gelu(x_fc_1, approximate="tanh") * x_fc_2 + # x = torch.nn.functional.gelu(x_fc_1, approximate="tanh") * x_fc_2 + # NOTE: in HF they don't use approximation + x = torch.nn.functional.gelu(x_fc_1) * x_fc_2 return self.proj(x) diff --git a/lit_gpt/rmsnorm.py b/lit_gpt/rmsnorm.py index 4f6d36a106..7de89630c6 100644 --- a/lit_gpt/rmsnorm.py +++ b/lit_gpt/rmsnorm.py @@ -10,19 +10,25 @@ class RMSNorm(torch.nn.Module): https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE. 
""" - def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None: + # TODO: make `add_unit_offset` be dependent by a config + # def __init__(self, size: int, dim: int = -1, eps: float = 1e-5, add_unit_offset: bool = True) -> None: + def __init__(self, size: int, dim: int = -1, eps: float = 1e-6, add_unit_offset: bool = True) -> None: super().__init__() self.weight = torch.nn.Parameter(torch.ones(size)) self.eps = eps self.dim = dim + self.add_unit_offset = add_unit_offset + # NOTE: output now closer to the official gemma implementation + # https://github.com/google/gemma_pytorch/blob/ca890c7abaa41ce7ab0eeda9aa8a52c0796b3a16/gemma/model.py#L170-L179 def forward(self, x: torch.Tensor) -> torch.Tensor: dtype = x.dtype x = x.float() # NOTE: the original RMSNorm paper implementation is not equivalent norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) x_normed = x * torch.rsqrt(norm_x + self.eps) - return (self.weight * x_normed).to(dtype=dtype) + x_normed = x_normed.to(dtype=dtype) + return x_normed * (self.add_unit_offset + self.weight) def reset_parameters(self) -> None: torch.nn.init.ones_(self.weight) diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py index e93796b9d8..2fa74e0058 100644 --- a/scripts/convert_hf_checkpoint.py +++ b/scripts/convert_hf_checkpoint.py @@ -189,6 +189,7 @@ def copy_weights_hf_llama( del qkv_weights[i] +# TODO: probably we can simply reuse Llama weights copy def copy_weights_hf_gemma( config: Config, qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]], @@ -234,13 +235,6 @@ def copy_weights_hf_gemma( else: to_name = weight_map[name] param = load_param(param, name, dtype) - - # halve the mlp linears: they are double the size in the HF implementation compared to Keras - halving_name_2_dim = {"mlp.fc_1": 0, "mlp.fc_2": 0, "mlp.proj": 1} - for key, dim in halving_name_2_dim.items(): - if key in to_name: - param, _ = torch.chunk(param, 2, dim=dim) - if saver is not None: param = saver.store_early(param) state_dict[to_name] = param diff --git a/tutorials/download_gemma.md b/tutorials/download_gemma.md index a344051809..363922e164 100644 --- a/tutorials/download_gemma.md +++ b/tutorials/download_gemma.md @@ -26,7 +26,7 @@ After access is granted, you can find your HF hub token in Date: Thu, 22 Feb 2024 11:36:08 +0000 Subject: [PATCH 07/34] Gemma-7b now works. --- lit_gpt/config.py | 7 +++++-- lit_gpt/model.py | 4 ++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/lit_gpt/config.py b/lit_gpt/config.py index 4b2a96fed8..8769ebef78 100644 --- a/lit_gpt/config.py +++ b/lit_gpt/config.py @@ -23,6 +23,7 @@ class Config: padded_vocab_size: Optional[int] = None n_layer: int = 16 n_head: int = 32 + head_size: Optional[int] = None n_embd: int = 4096 rotary_percentage: float = 0.25 parallel_residual: bool = True @@ -64,8 +65,9 @@ def __post_init__(self): if not self.name: self.name = self.hf_config.get("name", self.name) - assert self.n_embd % self.n_head == 0 - self.head_size = self.n_embd // self.n_head + if self.head_size is None: + assert self.n_embd % self.n_head == 0 + self.head_size = self.n_embd // self.n_head # vocab size should be a power of 2 to be optimal on hardware. 
compute the closest value if self.padded_vocab_size is None: @@ -794,6 +796,7 @@ def norm_class(self) -> Type: n_embd=3072, n_layer=28, n_head=16, + head_size=256, rotary_percentage=1.0, parallel_residual=False, bias=False, diff --git a/lit_gpt/model.py b/lit_gpt/model.py index 197fb2cce1..2f4bd3b1be 100644 --- a/lit_gpt/model.py +++ b/lit_gpt/model.py @@ -177,7 +177,7 @@ def __init__(self, config: Config) -> None: # key, query, value projections for all heads, but in a batch self.attn = nn.Linear(config.n_embd, shape, bias=config.bias) # output projection - self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + self.proj = nn.Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias) # disabled by default self.kv_cache: Optional[KVCache] = None @@ -227,7 +227,7 @@ def forward( y = self.scaled_dot_product_attention(q, k, v, mask) - y = y.reshape(B, T, self.config.n_embd) # re-assemble all head outputs side by side + y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side # output projection return self.proj(y) From 9a1f23c45a6b481c3764c9461cca699945c314e0 Mon Sep 17 00:00:00 2001 From: rasbt Date: Thu, 22 Feb 2024 13:58:51 +0000 Subject: [PATCH 08/34] add instruction-finetuned version --- lit_gpt/config.py | 34 ++++++++++++++++++++++++++++++++++ tutorials/download_gemma.md | 6 +++++- 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/lit_gpt/config.py b/lit_gpt/config.py index 8769ebef78..a3d3b5b459 100644 --- a/lit_gpt/config.py +++ b/lit_gpt/config.py @@ -821,6 +821,40 @@ def norm_class(self) -> Type: _mlp_class="GemmaMLP", intermediate_size=16384, ), + # https://huggingface.co/google/gemma-7b-it/blob/main/config.json + dict( + name="Gemma-7b-it-hf", + hf_config=dict(org="google", name="gemma-7b-it"), + vocab_size=256000, + padding_multiple=64, + n_embd=3072, + n_layer=28, + n_head=16, + head_size=256, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="GemmaMLP", + intermediate_size=24576, + ), + # https://huggingface.co/google/gemma-2b-it/blob/main/config.json + dict( + name="Gemma-2b-it-hf", + hf_config=dict(org="google", name="gemma-2b-it"), + vocab_size=256000, + padding_multiple=64, + n_embd=2048, + n_layer=18, + n_head=8, + n_query_groups=1, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="GemmaMLP", + intermediate_size=16384, + ), ] configs.extend(gemma) diff --git a/tutorials/download_gemma.md b/tutorials/download_gemma.md index 363922e164..f740538389 100644 --- a/tutorials/download_gemma.md +++ b/tutorials/download_gemma.md @@ -14,10 +14,14 @@ python scripts/download.py | grep gemma which will print ```text -google/gemma-2b google/gemma-7b +google/gemma-2b +google/gemma-7b-it +google/gemma-2b-it ``` +In the list above, `gemma-2b` and `gemma-7b` are the pretrained models, and `gemma-2b-it` and `gemma-7b-it` are the instruction-finetuned models. + In order to use a specific checkpoint, for instance [gemma-2b](https://huggingface.co/google/gemma-2b), download the weights and convert the checkpoint to the lit-gpt format. This requires that you've been granted access to the weights on the HuggingFace hub. You can do so by following the steps at . 
From 3ec329f45e4fafadbead9a10365736dac3747ebe Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Thu, 22 Feb 2024 15:50:07 +0000 Subject: [PATCH 09/34] A test for config to check head_size --- tests/test_config.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/test_config.py b/tests/test_config.py index 0ab8016f9d..d1b7bd9865 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -127,3 +127,12 @@ def test_from_checkpoint(tmp_path): assert config.name == "pythia-14m" assert config.block_size == 24 assert config.n_layer == 2 + + +@pytest.mark.parametrize("head_size", [None, 128]) +def test_head_size(head_size): + from lit_gpt import Config + + config = Config(head_size) + + assert config.head_size == head_size or config.n_embd // config.n_head From 29a00c20b4d98cb7cf0166118c9d7a3ec20b9638 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Thu, 22 Feb 2024 16:09:33 +0000 Subject: [PATCH 10/34] Update Gemma config --- lit_gpt/config.py | 47 +++++++++-------------------------------------- 1 file changed, 9 insertions(+), 38 deletions(-) diff --git a/lit_gpt/config.py b/lit_gpt/config.py index a3d3b5b459..cfa07d2b6d 100644 --- a/lit_gpt/config.py +++ b/lit_gpt/config.py @@ -787,26 +787,9 @@ def norm_class(self) -> Type: # Google Gemma ############### gemma = [ - # https://huggingface.co/google/gemma-7b/blob/main/config.json - dict( - name="Gemma-7b-hf", - hf_config=dict(org="google", name="gemma-7b"), - vocab_size=256000, - padding_multiple=64, - n_embd=3072, - n_layer=28, - n_head=16, - head_size=256, - rotary_percentage=1.0, - parallel_residual=False, - bias=False, - _norm_class="RMSNorm", - _mlp_class="GemmaMLP", - intermediate_size=24576, - ), # https://huggingface.co/google/gemma-2b/blob/main/config.json dict( - name="Gemma-2b-hf", + name="Gemma-2b", hf_config=dict(org="google", name="gemma-2b"), vocab_size=256000, padding_multiple=64, @@ -821,10 +804,10 @@ def norm_class(self) -> Type: _mlp_class="GemmaMLP", intermediate_size=16384, ), - # https://huggingface.co/google/gemma-7b-it/blob/main/config.json + # https://huggingface.co/google/gemma-7b/blob/main/config.json dict( - name="Gemma-7b-it-hf", - hf_config=dict(org="google", name="gemma-7b-it"), + name="Gemma-7b", + hf_config=dict(org="google", name="gemma-7b"), vocab_size=256000, padding_multiple=64, n_embd=3072, @@ -838,25 +821,13 @@ def norm_class(self) -> Type: _mlp_class="GemmaMLP", intermediate_size=24576, ), - # https://huggingface.co/google/gemma-2b-it/blob/main/config.json - dict( - name="Gemma-2b-it-hf", - hf_config=dict(org="google", name="gemma-2b-it"), - vocab_size=256000, - padding_multiple=64, - n_embd=2048, - n_layer=18, - n_head=8, - n_query_groups=1, - rotary_percentage=1.0, - parallel_residual=False, - bias=False, - _norm_class="RMSNorm", - _mlp_class="GemmaMLP", - intermediate_size=16384, - ), ] configs.extend(gemma) +for c in gemma: + copy = deepcopy(c) + copy["name"] = f"{c['name']}-it" + copy["hf_config"]["name"] = f"{c['hf_config']['name']}-it" + configs.append(copy) ########################## From bf9d711d41aa4b1bff2e68c77be3dd909c08e7f6 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Thu, 22 Feb 2024 16:11:20 +0000 Subject: [PATCH 11/34] Adapter_v2 and LoRA: attn.proj size is head_size * num_heads --- lit_gpt/adapter_v2.py | 2 +- lit_gpt/lora.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lit_gpt/adapter_v2.py b/lit_gpt/adapter_v2.py index 7dcfc5b6b0..759b4128c8 100644 --- a/lit_gpt/adapter_v2.py +++ b/lit_gpt/adapter_v2.py @@ -123,7 +123,7 @@ def 
__init__(self, config: Config, block_idx: int) -> None: # key, query, value projections for all heads, but in a batch self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias) # output projection - self.proj = AdapterV2Linear(config.n_embd, config.n_embd, bias=config.bias) + self.proj = AdapterV2Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias) # disabled by default self.kv_cache: Optional[KVCache] = None diff --git a/lit_gpt/lora.py b/lit_gpt/lora.py index fed1a11cb1..fea0ba7c5a 100644 --- a/lit_gpt/lora.py +++ b/lit_gpt/lora.py @@ -599,7 +599,7 @@ def __init__(self, config: Config) -> None: ) # output projection self.proj = LoRALinear( - config.n_embd, + config.head_size * config.n_head, config.n_embd, bias=config.bias, r=(config.r if config.to_projection else 0), From 072c9f64af1ea18dfd828d0ea02dfa105abb547b Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Thu, 22 Feb 2024 16:12:12 +0000 Subject: [PATCH 12/34] Adapter_v2 and LoRA: gemmamlp class --- lit_gpt/adapter_v2.py | 8 ++++++++ lit_gpt/lora.py | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/lit_gpt/adapter_v2.py b/lit_gpt/adapter_v2.py index 759b4128c8..bf8781a835 100644 --- a/lit_gpt/adapter_v2.py +++ b/lit_gpt/adapter_v2.py @@ -194,6 +194,14 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) +class GemmaMLP(LLaMAMLP): + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_fc_1 = self.fc_1(x) + x_fc_2 = self.fc_2(x) + x = torch.nn.functional.gelu(x_fc_1) * x_fc_2 + return self.proj(x) + + class LLaMAMoE(lit_gpt.model.LLaMAMoE): def __init__(self, config: Config) -> None: nn.Module.__init__(self) diff --git a/lit_gpt/lora.py b/lit_gpt/lora.py index fea0ba7c5a..ad5746d2cc 100644 --- a/lit_gpt/lora.py +++ b/lit_gpt/lora.py @@ -699,6 +699,14 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) +class GemmaMLP(LLaMAMLP): + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_fc_1 = self.fc_1(x) + x_fc_2 = self.fc_2(x) + x = torch.nn.functional.gelu(x_fc_1) * x_fc_2 + return self.proj(x) + + class LLaMAMoE(lit_gpt.model.LLaMAMoE): def __init__(self, config: Config) -> None: nn.Module.__init__(self) From bed71f1de8c111267592200a0dd02c59a00ddbd1 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Thu, 22 Feb 2024 16:13:10 +0000 Subject: [PATCH 13/34] RMSNorm: unit offset is configurable --- lit_gpt/config.py | 7 ++++++- lit_gpt/rmsnorm.py | 6 +----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/lit_gpt/config.py b/lit_gpt/config.py index cfa07d2b6d..6e146c5c32 100644 --- a/lit_gpt/config.py +++ b/lit_gpt/config.py @@ -52,6 +52,7 @@ class Config: n_query_groups: Optional[int] = None shared_attention_norm: bool = False _norm_class: Literal["LayerNorm", "RMSNorm"] = "LayerNorm" + rmsnorm_add_unit_offset: bool = False norm_eps: float = 1e-5 _mlp_class: Literal["GptNeoxMLP", "LLaMAMLP"] = "GptNeoxMLP" gelu_approximate: str = "none" @@ -140,9 +141,11 @@ def mlp_class(self) -> Type: def norm_class(self) -> Type: # `self._norm_class` cannot be the type to keep the config json serializable if self._norm_class == "RMSNorm": + from functools import partial + from lit_gpt.rmsnorm import RMSNorm - return RMSNorm + return partial(RMSNorm, add_unit_offset=self.rmsnorm_add_unit_offset) return getattr(torch.nn, self._norm_class) @@ -801,6 +804,7 @@ def 
norm_class(self) -> Type: parallel_residual=False, bias=False, _norm_class="RMSNorm", + rmsnorm_add_unit_offset=True, _mlp_class="GemmaMLP", intermediate_size=16384, ), @@ -818,6 +822,7 @@ def norm_class(self) -> Type: parallel_residual=False, bias=False, _norm_class="RMSNorm", + rmsnorm_add_unit_offset=True, _mlp_class="GemmaMLP", intermediate_size=24576, ), diff --git a/lit_gpt/rmsnorm.py b/lit_gpt/rmsnorm.py index 7de89630c6..e353b36c81 100644 --- a/lit_gpt/rmsnorm.py +++ b/lit_gpt/rmsnorm.py @@ -10,17 +10,13 @@ class RMSNorm(torch.nn.Module): https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE. """ - # TODO: make `add_unit_offset` be dependent by a config - # def __init__(self, size: int, dim: int = -1, eps: float = 1e-5, add_unit_offset: bool = True) -> None: - def __init__(self, size: int, dim: int = -1, eps: float = 1e-6, add_unit_offset: bool = True) -> None: + def __init__(self, size: int, dim: int = -1, eps: float = 1e-6, add_unit_offset: bool = False) -> None: super().__init__() self.weight = torch.nn.Parameter(torch.ones(size)) self.eps = eps self.dim = dim self.add_unit_offset = add_unit_offset - # NOTE: output now closer to the official gemma implementation - # https://github.com/google/gemma_pytorch/blob/ca890c7abaa41ce7ab0eeda9aa8a52c0796b3a16/gemma/model.py#L170-L179 def forward(self, x: torch.Tensor) -> torch.Tensor: dtype = x.dtype x = x.float() From ec7d01e7c0af0a407f85ce54b703200611d96d7f Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Thu, 22 Feb 2024 16:13:46 +0000 Subject: [PATCH 14/34] Configurable wte output scaling --- lit_gpt/config.py | 5 ++++- lit_gpt/model.py | 7 ++----- lit_gpt/utils.py | 5 ++--- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/lit_gpt/config.py b/lit_gpt/config.py index 6e146c5c32..5070c98de2 100644 --- a/lit_gpt/config.py +++ b/lit_gpt/config.py @@ -17,6 +17,7 @@ class Config: name: str = "" hf_config: dict = field(default_factory=dict) + scale_wte_output: bool = False block_size: int = 4096 vocab_size: int = 50254 padding_multiple: int = 512 @@ -54,7 +55,7 @@ class Config: _norm_class: Literal["LayerNorm", "RMSNorm"] = "LayerNorm" rmsnorm_add_unit_offset: bool = False norm_eps: float = 1e-5 - _mlp_class: Literal["GptNeoxMLP", "LLaMAMLP"] = "GptNeoxMLP" + _mlp_class: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP" gelu_approximate: str = "none" intermediate_size: Optional[int] = None rope_condense_ratio: int = 1 @@ -794,6 +795,7 @@ def norm_class(self) -> Type: dict( name="Gemma-2b", hf_config=dict(org="google", name="gemma-2b"), + scale_wte_output=True, vocab_size=256000, padding_multiple=64, n_embd=2048, @@ -812,6 +814,7 @@ def norm_class(self) -> Type: dict( name="Gemma-7b", hf_config=dict(org="google", name="gemma-7b"), + scale_wte_output=True, vocab_size=256000, padding_multiple=64, n_embd=3072, diff --git a/lit_gpt/model.py b/lit_gpt/model.py index 2f4bd3b1be..a2145f9d45 100644 --- a/lit_gpt/model.py +++ b/lit_gpt/model.py @@ -87,8 +87,8 @@ def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None) - mask = None x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) - # NOTE: this is a secret sauce (Gemma) - x = x * (self.config.n_embd**0.5) + if self.config.scale_wte_output: + x = x * (self.config.n_embd**0.5) for block in self.transformer.h: x = block(x, cos, sin, mask, input_pos) @@ -295,11 +295,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class GemmaMLP(LLaMAMLP): def forward(self, x: torch.Tensor) -> torch.Tensor: - # the 
intermediate size for fc_{1,2} is halved when compared to LLaMAMLP, thus implementing GeGLU x_fc_1 = self.fc_1(x) x_fc_2 = self.fc_2(x) - # x = torch.nn.functional.gelu(x_fc_1, approximate="tanh") * x_fc_2 - # NOTE: in HF they don't use approximation x = torch.nn.functional.gelu(x_fc_1) * x_fc_2 return self.proj(x) diff --git a/lit_gpt/utils.py b/lit_gpt/utils.py index b2977bd16f..fd03b6f4c5 100644 --- a/lit_gpt/utils.py +++ b/lit_gpt/utils.py @@ -45,9 +45,8 @@ def check_valid_checkpoint_dir(checkpoint_dir: Path) -> None: files = { "lit_model.pth": (checkpoint_dir / "lit_model.pth").is_file(), "lit_config.json": (checkpoint_dir / "lit_config.json").is_file(), - "tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file() or ( - checkpoint_dir / "tokenizer.model" - ).is_file(), + "tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file() + or (checkpoint_dir / "tokenizer.model").is_file(), "tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(), } if checkpoint_dir.is_dir(): From bd0864c32c0accfce13dec0e0b447874a0a97ab2 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Thu, 22 Feb 2024 16:14:14 +0000 Subject: [PATCH 15/34] Update tests to supports changes in Config class --- tests/test_generate_adapter.py | 2 +- tests/test_generate_lora.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_generate_adapter.py b/tests/test_generate_adapter.py index b39828d123..d296a68613 100644 --- a/tests/test_generate_adapter.py +++ b/tests/test_generate_adapter.py @@ -44,7 +44,7 @@ def test_main(fake_checkpoint_dir, monkeypatch, version, tensor_like): # only the generated result is printed to stdout assert out.getvalue() == "foo bar baz\n" * num_samples - assert "'padded_vocab_size': 512, 'n_layer': 2, 'n_head': 4, 'n_embd': 8" in err.getvalue() + assert "'padded_vocab_size': 512, 'n_layer': 2, 'n_head': 4, 'head_size': 2, 'n_embd': 8" in err.getvalue() @pytest.mark.parametrize("version", ("", "_v2")) diff --git a/tests/test_generate_lora.py b/tests/test_generate_lora.py index 3984eb5a52..4b511008ae 100644 --- a/tests/test_generate_lora.py +++ b/tests/test_generate_lora.py @@ -49,7 +49,7 @@ def test_main(fake_checkpoint_dir, monkeypatch, tensor_like): # only the generated result is printed to stdout assert out.getvalue() == "foo bar baz\n" * num_samples - assert "'padded_vocab_size': 512, 'n_layer': 2, 'n_head': 4, 'n_embd': 8" in err.getvalue() + assert "'padded_vocab_size': 512, 'n_layer': 2, 'n_head': 4, 'head_size': 2, 'n_embd': 8" in err.getvalue() def test_lora_variables_exist(): From bf8c9b5260c2c5467d63a7b5532a748d82bad6b8 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Thu, 22 Feb 2024 17:08:22 +0000 Subject: [PATCH 16/34] Test for Gemma --- tests/test_model.py | 60 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/tests/test_model.py b/tests/test_model.py index ed6ad82968..d9c65bdbe9 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -551,6 +551,66 @@ def test_against_original_stablelm_zephyr_3b(device, dtype): torch.testing.assert_close(ours_y, theirs_y) +@torch.inference_mode() +@pytest.mark.parametrize("model_name", ["gemma-2b", "gemma-7b"]) +@pytest.mark.parametrize( + ("device", "dtype"), + [ + (torch.device("cpu"), torch.float32), + pytest.param( + torch.device("cuda"), + torch.float16, + marks=[ + # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input + # is slightly different + pytest.mark.xfail(raises=AssertionError, strict=False), + RunIf(min_cuda_gpus=1), + ], + ), + ], +) +def test_against_original_gemma(model_name, device, dtype): + from transformers.models.gemma.configuration_gemma import GemmaConfig + from transformers.models.gemma.modeling_gemma import GemmaForCausalLM + + from lit_gpt import GPT, Config + from scripts.convert_hf_checkpoint import copy_weights_hf_gemma + + torch.set_default_dtype(dtype) + + T = 5 + ours_config = Config.from_name(model_name, n_layer=2, n_head=16, n_embd=32, intermediate_size=86) + theirs_config = GemmaConfig( + vocab_size=ours_config.padded_vocab_size, + hidden_size=ours_config.n_embd, + head_dim=ours_config.head_size, + num_attention_heads=ours_config.n_head, + num_hidden_layers=ours_config.n_layer, + intermediate_size=ours_config.intermediate_size, + max_position_embeddings=T, + rms_norm_eps=ours_config.norm_eps, + num_key_value_heads=ours_config.n_query_groups, + rope_theta=ours_config.rope_base, + attention_bias=ours_config.bias, + tie_word_embeddings=True, + ) + assert ours_config.intermediate_size == theirs_config.intermediate_size + + theirs_model = GemmaForCausalLM(theirs_config).to(device) + theirs_state_dict = theirs_model.state_dict() + state_dict = {} + copy_weights_hf_gemma(ours_config, {}, state_dict, theirs_state_dict) + ours_model = GPT(ours_config).to(device) + ours_model.load_state_dict(state_dict) + + # test end to end + x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) + assert x.size(1) == T + ours_y = ours_model(x) + theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float + torch.testing.assert_close(ours_y, theirs_y) + + @RunIf(dynamo=True) @torch.inference_mode() def test_model_compile(): From c78bd6ed83ffe3b3cbdbe4b1fe04d4e06b0a23d0 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Thu, 22 Feb 2024 17:29:39 +0000 Subject: [PATCH 17/34] conver_hf: reuse llama copy function --- scripts/convert_hf_checkpoint.py | 133 +++++++++---------------------- 1 file changed, 36 insertions(+), 97 deletions(-) diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py index 2fa74e0058..98fd273e7b 100644 --- a/scripts/convert_hf_checkpoint.py +++ b/scripts/convert_hf_checkpoint.py @@ -82,17 +82,21 @@ def copy_weights_falcon( } # the original model definition is different for each size if "7b" in model_name: - weight_map.update({ - "transformer.h.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", - "transformer.h.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", - }) + weight_map.update( + { + "transformer.h.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", + "transformer.h.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", + } + ) elif "40b" in model_name or "180B" in model_name: - weight_map.update({ - "transformer.h.{}.ln_attn.bias": "transformer.h.{}.norm_1.bias", - "transformer.h.{}.ln_attn.weight": "transformer.h.{}.norm_1.weight", - "transformer.h.{}.ln_mlp.bias": "transformer.h.{}.norm_2.bias", - "transformer.h.{}.ln_mlp.weight": "transformer.h.{}.norm_2.weight", - }) + weight_map.update( + { + "transformer.h.{}.ln_attn.bias": "transformer.h.{}.norm_1.bias", + "transformer.h.{}.ln_attn.weight": "transformer.h.{}.norm_1.weight", + "transformer.h.{}.ln_mlp.bias": "transformer.h.{}.norm_2.bias", + "transformer.h.{}.ln_mlp.weight": "transformer.h.{}.norm_2.weight", + } + ) else: raise NotImplementedError @@ -113,6 +117,7 @@ def 
copy_weights_hf_llama( qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]], state_dict: Dict[str, torch.Tensor], hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], + tie_weights: bool = False, saver: Optional[incremental_save] = None, dtype: Optional[torch.dtype] = None, ) -> None: @@ -132,18 +137,22 @@ def copy_weights_hf_llama( "lm_head.weight": "lm_head.weight", } if config._mlp_class == "LLaMAMoE": - weight_map.update({ - "model.layers.{}.block_sparse_moe.gate.weight": "transformer.h.{l}.mlp.gate.weight", - "model.layers.{}.block_sparse_moe.experts.{}.w1.weight": "transformer.h.{l}.mlp.experts.{e}.fc_1.weight", - "model.layers.{}.block_sparse_moe.experts.{}.w3.weight": "transformer.h.{l}.mlp.experts.{e}.fc_2.weight", - "model.layers.{}.block_sparse_moe.experts.{}.w2.weight": "transformer.h.{l}.mlp.experts.{e}.proj.weight", - }) - elif config._mlp_class == "LLaMAMLP": - weight_map.update({ - "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{l}.mlp.fc_1.weight", - "model.layers.{}.mlp.up_proj.weight": "transformer.h.{l}.mlp.fc_2.weight", - "model.layers.{}.mlp.down_proj.weight": "transformer.h.{l}.mlp.proj.weight", - }) + weight_map.update( + { + "model.layers.{}.block_sparse_moe.gate.weight": "transformer.h.{l}.mlp.gate.weight", + "model.layers.{}.block_sparse_moe.experts.{}.w1.weight": "transformer.h.{l}.mlp.experts.{e}.fc_1.weight", + "model.layers.{}.block_sparse_moe.experts.{}.w3.weight": "transformer.h.{l}.mlp.experts.{e}.fc_2.weight", + "model.layers.{}.block_sparse_moe.experts.{}.w2.weight": "transformer.h.{l}.mlp.experts.{e}.proj.weight", + } + ) + elif config._mlp_class in ("LLaMAMLP", "GemmaMLP"): + weight_map.update( + { + "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{l}.mlp.fc_1.weight", + "model.layers.{}.mlp.up_proj.weight": "transformer.h.{l}.mlp.fc_2.weight", + "model.layers.{}.mlp.down_proj.weight": "transformer.h.{l}.mlp.proj.weight", + } + ) else: raise NotImplementedError @@ -171,76 +180,8 @@ def copy_weights_hf_llama( param = saver.store_early(param) state_dict[to_name] = param - # convert separate q, k, v matrices into an interleaved qkv - for i, (q, k, v) in list(qkv_weights.items()): - if q is None or k is None or v is None: - # split across different .bin files - continue - q = load_param(q, f"layer {i} q", dtype) - k = load_param(k, f"layer {i} k", dtype) - v = load_param(v, f"layer {i} v", dtype) - q_per_kv = config.n_head // config.n_query_groups - qs = torch.split(q, config.head_size * q_per_kv) - ks = torch.split(k, config.head_size) - vs = torch.split(v, config.head_size) - cycled = [t for group in zip(qs, ks, vs) for t in group] - qkv = torch.cat(cycled) - state_dict[f"transformer.h.{i}.attn.attn.weight"] = qkv - del qkv_weights[i] - - -# TODO: probably we can simply reuse Llama weights copy -def copy_weights_hf_gemma( - config: Config, - qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]], - state_dict: Dict[str, torch.Tensor], - hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], - saver: Optional[incremental_save] = None, - dtype: Optional[torch.dtype] = None, -) -> None: - weight_map = { - "model.embed_tokens.weight": "transformer.wte.weight", - "model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", - "model.layers.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", - "model.layers.{}.self_attn.q_proj.weight": None, - "model.layers.{}.self_attn.k_proj.weight": None, - "model.layers.{}.self_attn.v_proj.weight": None, - "model.layers.{}.self_attn.o_proj.weight": 
"transformer.h.{}.attn.proj.weight", - "model.layers.{}.self_attn.rotary_emb.inv_freq": None, - "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight", - "model.layers.{}.post_attention_layernorm.bias": "transformer.h.{}.norm_2.bias", - "model.norm.weight": "transformer.ln_f.weight", - "model.norm.bias": "transformer.ln_f.bias", - "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight", - "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight", - "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight", - } - if config.bias: - raise NotImplementedError("bias halving not implemented") - - for name, param in hf_weights.items(): - if "model.layers" in name: - from_name, l = layer_template(name, 2) - qkv = qkv_weights.setdefault(l, [None, None, None]) - if "q_proj" in name: - qkv[0] = param - elif "k_proj" in name: - qkv[1] = param - elif "v_proj" in name: - qkv[2] = param - to_name = weight_map[from_name] - if to_name is None: - continue - to_name = to_name.format(l) - else: - to_name = weight_map[name] - param = load_param(param, name, dtype) - if saver is not None: - param = saver.store_early(param) - state_dict[to_name] = param - - # weight tying - state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"] + if tie_weights: + state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"] # convert separate q, k, v matrices into an interleaved qkv for i, (q, k, v) in list(qkv_weights.items()): @@ -371,16 +312,14 @@ def convert_hf_checkpoint( if "falcon" in model_name: copy_fn = partial(copy_weights_falcon, model_name) - elif config._mlp_class in ("LLaMAMLP", "LLaMAMoE"): + elif config._mlp_class in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"): # holder to reconstitute the split q, k, v qkv_weights = {} - copy_fn = partial(copy_weights_hf_llama, config, qkv_weights) + tie_weights = "Gemma" in config.name + copy_fn = partial(copy_weights_hf_llama, config, qkv_weights, tie_weights=tie_weights) elif "phi" in model_name: qkv_weights = {} copy_fn = partial(copy_weights_phi, config, qkv_weights) - elif config._mlp_class == "GemmaMLP": - qkv_weights = {} - copy_fn = partial(copy_weights_hf_gemma, config, qkv_weights) else: copy_fn = copy_weights_gpt_neox From 50ad50995bc7782edd15c0939e7531d3234813f4 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Thu, 22 Feb 2024 18:46:17 +0000 Subject: [PATCH 18/34] Test Gemma model: use llama weights copying --- tests/test_model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index d9c65bdbe9..8e294d36bb 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -574,7 +574,7 @@ def test_against_original_gemma(model_name, device, dtype): from transformers.models.gemma.modeling_gemma import GemmaForCausalLM from lit_gpt import GPT, Config - from scripts.convert_hf_checkpoint import copy_weights_hf_gemma + from scripts.convert_hf_checkpoint import copy_weights_hf_llama torch.set_default_dtype(dtype) @@ -598,8 +598,10 @@ def test_against_original_gemma(model_name, device, dtype): theirs_model = GemmaForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() + # Gemma weights are shipped without `lm_head.weight` + theirs_state_dict.pop("lm_head.weight") state_dict = {} - copy_weights_hf_gemma(ours_config, {}, state_dict, theirs_state_dict) + copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict, tie_weights=True) ours_model = GPT(ours_config).to(device) 
ours_model.load_state_dict(state_dict) From d98df3bdec077c38a208f62a41a65c77dc112815 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Thu, 22 Feb 2024 19:34:42 +0000 Subject: [PATCH 19/34] Update convert_lit + test --- scripts/convert_lit_checkpoint.py | 60 ++++++++------ tests/test_convert_lit_checkpoint.py | 115 +++++++++++++++++++++------ 2 files changed, 127 insertions(+), 48 deletions(-) diff --git a/scripts/convert_lit_checkpoint.py b/scripts/convert_lit_checkpoint.py index 8ad27edc9c..67d55a504f 100644 --- a/scripts/convert_lit_checkpoint.py +++ b/scripts/convert_lit_checkpoint.py @@ -36,17 +36,21 @@ def copy_weights_falcon( } # the original model definition is different for each size if "7b" in model_name: - weight_map.update({ - "transformer.h.{}.norm_1.bias": "transformer.h.{}.input_layernorm.bias", - "transformer.h.{}.norm_1.weight": "transformer.h.{}.input_layernorm.weight", - }) + weight_map.update( + { + "transformer.h.{}.norm_1.bias": "transformer.h.{}.input_layernorm.bias", + "transformer.h.{}.norm_1.weight": "transformer.h.{}.input_layernorm.weight", + } + ) elif "40b" in model_name or "180B" in model_name: - weight_map.update({ - "transformer.h.{}.norm_1.bias": "transformer.h.{}.ln_attn.bias", - "transformer.h.{}.norm_1.weight": "transformer.h.{}.ln_attn.weight", - "transformer.h.{}.norm_2.bias": "transformer.h.{}.ln_mlp.bias", - "transformer.h.{}.norm_2.weight": "transformer.h.{}.ln_mlp.weight", - }) + weight_map.update( + { + "transformer.h.{}.norm_1.bias": "transformer.h.{}.ln_attn.bias", + "transformer.h.{}.norm_1.weight": "transformer.h.{}.ln_attn.weight", + "transformer.h.{}.norm_2.bias": "transformer.h.{}.ln_mlp.bias", + "transformer.h.{}.norm_2.weight": "transformer.h.{}.ln_mlp.weight", + } + ) else: raise NotImplementedError @@ -102,6 +106,7 @@ def copy_weights_llama( config: Config, state_dict: Dict[str, torch.Tensor], lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], + untie_weights: bool = False, saver: Optional[incremental_save] = None, ) -> None: weight_map = { @@ -116,22 +121,28 @@ def copy_weights_llama( "lm_head.weight": "lm_head.weight", } if config._mlp_class == "LLaMAMoE": - weight_map.update({ - "transformer.h.{}.mlp.gate.weight": "model.layers.{l}.block_sparse_moe.gate.weight", - "transformer.h.{}.mlp.experts.{}.fc_1.weight": "model.layers.{l}.block_sparse_moe.experts.{e}.w1.weight", - "transformer.h.{}.mlp.experts.{}.fc_2.weight": "model.layers.{l}.block_sparse_moe.experts.{e}.w3.weight", - "transformer.h.{}.mlp.experts.{}.proj.weight": "model.layers.{l}.block_sparse_moe.experts.{e}.w2.weight", - }) - elif config._mlp_class == "LLaMAMLP": - weight_map.update({ - "transformer.h.{}.mlp.fc_1.weight": "model.layers.{l}.mlp.gate_proj.weight", - "transformer.h.{}.mlp.fc_2.weight": "model.layers.{l}.mlp.up_proj.weight", - "transformer.h.{}.mlp.proj.weight": "model.layers.{l}.mlp.down_proj.weight", - }) + weight_map.update( + { + "transformer.h.{}.mlp.gate.weight": "model.layers.{l}.block_sparse_moe.gate.weight", + "transformer.h.{}.mlp.experts.{}.fc_1.weight": "model.layers.{l}.block_sparse_moe.experts.{e}.w1.weight", + "transformer.h.{}.mlp.experts.{}.fc_2.weight": "model.layers.{l}.block_sparse_moe.experts.{e}.w3.weight", + "transformer.h.{}.mlp.experts.{}.proj.weight": "model.layers.{l}.block_sparse_moe.experts.{e}.w2.weight", + } + ) + elif config._mlp_class in ("LLaMAMLP", "GemmaMLP"): + weight_map.update( + { + "transformer.h.{}.mlp.fc_1.weight": "model.layers.{l}.mlp.gate_proj.weight", + "transformer.h.{}.mlp.fc_2.weight": 
"model.layers.{l}.mlp.up_proj.weight", + "transformer.h.{}.mlp.proj.weight": "model.layers.{l}.mlp.down_proj.weight", + } + ) else: raise NotImplementedError for name, param in lit_weights.items(): + if name == "lm_head.weight" and untie_weights: + continue if name.endswith(".attn.attn.weight"): from_name, l = layer_template(name, 2) q = "model.layers.{}.self_attn.q_proj.weight".format(l) @@ -238,8 +249,9 @@ def convert_lit_checkpoint(checkpoint_path: Path, output_path: Path, config_path if "falcon" in config.name: copy_fn = partial(copy_weights_falcon, config.name) - elif config._mlp_class in ("LLaMAMLP", "LLaMAMoE"): - copy_fn = partial(copy_weights_llama, config) + elif config._mlp_class in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"): + untie_weights = "Gemma" in config.name + copy_fn = partial(copy_weights_llama, config, untie_weights=untie_weights) elif "phi" in config.name: copy_fn = partial(copy_weights_phi, config) else: diff --git a/tests/test_convert_lit_checkpoint.py b/tests/test_convert_lit_checkpoint.py index c0d40e3d79..7e0fca8c63 100644 --- a/tests/test_convert_lit_checkpoint.py +++ b/tests/test_convert_lit_checkpoint.py @@ -9,6 +9,7 @@ import pytest import torch +from conftest import RunIf wd = Path(__file__).parent.parent.absolute() @@ -371,6 +372,68 @@ def test_against_original_stablelm_zephyr_3b(): torch.testing.assert_close(ours_y, theirs_y) +@torch.inference_mode() +@pytest.mark.parametrize("model_name", ["gemma-2b", "gemma-7b"]) +@pytest.mark.parametrize( + ("device", "dtype"), + [ + (torch.device("cpu"), torch.float32), + pytest.param( + torch.device("cuda"), + torch.float16, + marks=[ + # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input + # is slightly different + pytest.mark.xfail(raises=AssertionError, strict=False), + RunIf(min_cuda_gpus=1), + ], + ), + ], +) +def test_against_original_gemma(model_name, device, dtype): + from transformers.models.gemma.configuration_gemma import GemmaConfig + from transformers.models.gemma.modeling_gemma import GemmaForCausalLM + + from lit_gpt import GPT, Config + from scripts.convert_lit_checkpoint import copy_weights_llama + + torch.set_default_dtype(dtype) + + T = 5 + ours_config = Config.from_name(model_name, n_layer=2, n_head=16, n_embd=32, intermediate_size=86) + theirs_config = GemmaConfig( + vocab_size=ours_config.padded_vocab_size, + hidden_size=ours_config.n_embd, + head_dim=ours_config.head_size, + num_attention_heads=ours_config.n_head, + num_hidden_layers=ours_config.n_layer, + intermediate_size=ours_config.intermediate_size, + max_position_embeddings=T, + rms_norm_eps=ours_config.norm_eps, + num_key_value_heads=ours_config.n_query_groups, + rope_theta=ours_config.rope_base, + attention_bias=ours_config.bias, + tie_word_embeddings=True, + ) + assert ours_config.intermediate_size == theirs_config.intermediate_size + + ours_model = GPT(ours_config).to(device) + # tie weights + ours_model.lm_head.weight = ours_model.transformer.wte.weight + ours_state_dict = ours_model.state_dict() + theirs_state_dict = {} + copy_weights_llama(ours_config, theirs_state_dict, ours_state_dict, untie_weights=True) + theirs_model = GemmaForCausalLM(theirs_config).to(device) + theirs_model.load_state_dict(theirs_state_dict, strict=False) + + # test end to end + x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) + assert x.size(1) == T + ours_y = ours_model(x) + theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float + 
torch.testing.assert_close(ours_y, theirs_y) + + def test_check_conversion_supported_adapter(): from scripts.convert_lit_checkpoint import check_conversion_supported @@ -397,20 +460,22 @@ def test_qkv_split(): # MHA config = Config(n_embd=4, n_head=4) - qkv = torch.tensor([ - [0, 1, 2, 3], - [4, 5, 6, 7], - [8, 9, 10, 11], - [12, 13, 14, 15], - [16, 17, 18, 19], - [20, 21, 22, 23], - [24, 25, 26, 27], - [28, 29, 30, 31], - [32, 33, 34, 35], - [36, 37, 38, 39], - [40, 41, 42, 43], - [44, 45, 46, 47], - ]) + qkv = torch.tensor( + [ + [0, 1, 2, 3], + [4, 5, 6, 7], + [8, 9, 10, 11], + [12, 13, 14, 15], + [16, 17, 18, 19], + [20, 21, 22, 23], + [24, 25, 26, 27], + [28, 29, 30, 31], + [32, 33, 34, 35], + [36, 37, 38, 39], + [40, 41, 42, 43], + [44, 45, 46, 47], + ] + ) q, k, v = qkv_split(qkv, config) torch.testing.assert_close(q, torch.tensor([[0, 1, 2, 3], [12, 13, 14, 15], [24, 25, 26, 27], [36, 37, 38, 39]])) torch.testing.assert_close(k, torch.tensor([[4, 5, 6, 7], [16, 17, 18, 19], [28, 29, 30, 31], [40, 41, 42, 43]])) @@ -418,16 +483,18 @@ def test_qkv_split(): # GQA config = Config(n_embd=4, n_head=4, n_query_groups=2) - qkv = torch.tensor([ - [0, 1, 2, 3], - [4, 5, 6, 7], - [8, 9, 10, 11], - [12, 13, 14, 15], - [16, 17, 18, 19], - [20, 21, 22, 23], - [24, 25, 26, 27], - [28, 29, 30, 31], - ]) + qkv = torch.tensor( + [ + [0, 1, 2, 3], + [4, 5, 6, 7], + [8, 9, 10, 11], + [12, 13, 14, 15], + [16, 17, 18, 19], + [20, 21, 22, 23], + [24, 25, 26, 27], + [28, 29, 30, 31], + ] + ) q, k, v = qkv_split(qkv, config) torch.testing.assert_close(q, torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [16, 17, 18, 19], [20, 21, 22, 23]])) torch.testing.assert_close(k, torch.tensor([[8, 9, 10, 11], [24, 25, 26, 27]])) From 628c7bcab3c54c85902b5c91f68b1114cad4ff44 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Thu, 22 Feb 2024 19:53:12 +0000 Subject: [PATCH 20/34] Restore accidently deleted comment line --- scripts/convert_hf_checkpoint.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py index 98fd273e7b..3a50c4b300 100644 --- a/scripts/convert_hf_checkpoint.py +++ b/scripts/convert_hf_checkpoint.py @@ -318,6 +318,7 @@ def convert_hf_checkpoint( tie_weights = "Gemma" in config.name copy_fn = partial(copy_weights_hf_llama, config, qkv_weights, tie_weights=tie_weights) elif "phi" in model_name: + # holder to reconstitute the split q, k, v qkv_weights = {} copy_fn = partial(copy_weights_phi, config, qkv_weights) else: From 1a2f9f8afbb07b4b860f5e3490415ba6e5e839e5 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Thu, 22 Feb 2024 20:37:39 +0000 Subject: [PATCH 21/34] Prompt for Gemma it (instruct models) --- chat/base.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/chat/base.py b/chat/base.py index 26963dd813..c6431e97d8 100644 --- a/chat/base.py +++ b/chat/base.py @@ -361,6 +361,15 @@ def prompt_config(checkpoint_dir: Path, tokenizer: Tokenizer) -> Tuple[str, Tupl stop_tokens = ([tokenizer.eos_id],) return system_prompt, stop_tokens + if re.search(r"gemma.*-it", checkpoint_name): + system_prompt = ( + "user\n" + "{prompt}\n" + "model\n" + ) + stop_tokens = ([tokenizer.eos_id],) + return system_prompt, stop_tokens + # default format return "{prompt}", ([tokenizer.eos_id],) From d002695ba07a3c79857185a247ad24b8096116c0 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Fri, 23 Feb 2024 13:30:57 +0300 Subject: [PATCH 22/34] RMSNorm: reduce computations when self.add_unit_offset is False --- chat/base.py | 6 +----- 
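The `qkv_split` expectations above reflect the interleaved layout of the fused QKV matrix: each query group stores its query heads first, followed by one key head and one value head. A small standalone sketch that reproduces the tested behaviour (written independently of the repository code):

```python
import torch


def split_interleaved_qkv(qkv: torch.Tensor, n_head: int, n_query_groups: int, head_size: int):
    # Rows are laid out per query group as [q_1, ..., q_{n_head // n_query_groups}, k, v].
    q_per_group = n_head // n_query_groups
    grouped = qkv.view(n_query_groups, q_per_group + 2, head_size, -1)
    q, k, v = grouped.split([q_per_group, 1, 1], dim=1)
    return (
        q.reshape(n_head * head_size, -1),
        k.reshape(n_query_groups * head_size, -1),
        v.reshape(n_query_groups * head_size, -1),
    )


# Matches the GQA case above: 4 heads, 2 query groups, head_size 1.
qkv = torch.arange(32).view(8, 4)
q, k, v = split_interleaved_qkv(qkv, n_head=4, n_query_groups=2, head_size=1)
assert k.tolist() == [[8, 9, 10, 11], [24, 25, 26, 27]]
assert v.tolist() == [[12, 13, 14, 15], [28, 29, 30, 31]]
```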
lit_gpt/rmsnorm.py | 4 +++- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/chat/base.py b/chat/base.py index c6431e97d8..0f12958cfb 100644 --- a/chat/base.py +++ b/chat/base.py @@ -362,11 +362,7 @@ def prompt_config(checkpoint_dir: Path, tokenizer: Tokenizer) -> Tuple[str, Tupl return system_prompt, stop_tokens if re.search(r"gemma.*-it", checkpoint_name): - system_prompt = ( - "user\n" - "{prompt}\n" - "model\n" - ) + system_prompt = "user\n{prompt}\nmodel\n" stop_tokens = ([tokenizer.eos_id],) return system_prompt, stop_tokens diff --git a/lit_gpt/rmsnorm.py b/lit_gpt/rmsnorm.py index e353b36c81..49bc09d823 100644 --- a/lit_gpt/rmsnorm.py +++ b/lit_gpt/rmsnorm.py @@ -24,7 +24,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) x_normed = x * torch.rsqrt(norm_x + self.eps) x_normed = x_normed.to(dtype=dtype) - return x_normed * (self.add_unit_offset + self.weight) + if self.add_unit_offset: + return x_normed * (1 + self.weight) + return x_normed * self.weight def reset_parameters(self) -> None: torch.nn.init.ones_(self.weight) From c915d574f9f77c02bd2f808da17cb28f56148c15 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Fri, 23 Feb 2024 13:33:04 +0300 Subject: [PATCH 23/34] Auto markdown formatting --- tutorials/download_function_calling_llama_2.md | 2 +- tutorials/inference.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tutorials/download_function_calling_llama_2.md b/tutorials/download_function_calling_llama_2.md index 2c5d497fe7..f86799e240 100644 --- a/tutorials/download_function_calling_llama_2.md +++ b/tutorials/download_function_calling_llama_2.md @@ -29,4 +29,4 @@ Is strongly recommended to visit the model [repository](https://huggingface.co/T The chat script has a generic use case with a single function defined, feel free to play with it to fit your needs, for instance to make HTTP requests with the model outputs. -Have fun! \ No newline at end of file +Have fun! diff --git a/tutorials/inference.md b/tutorials/inference.md index c3b9c313f7..3d1017303b 100644 --- a/tutorials/inference.md +++ b/tutorials/inference.md @@ -40,7 +40,7 @@ We offer two scripts to leverage multiple devices for inference. Allows you to run models that wouldn't fit in a single card by partitioning the transformer blocks across all your devices and running them sequentially. For instance, `meta-llama/Llama-2-70b-chat-hf` would require ~140 GB of GPU memory to load on a single device, plus the memory for activations. -With 80 transformer layers, we could partition them across 8, 5, 4, or 2 devices. +With 80 transformer layers, we could partition them across 8, 5, 4, or 2 devices. 
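The RMSNorm change above keeps the common case cheap: only Gemma needs the unit offset, because its reference implementation applies `(1 + weight)` and therefore stores the learned scale centered around zero. A functional restatement of the forward pass for clarity, with the normalization dimension fixed to the last axis:

```python
import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5, add_unit_offset: bool = False) -> torch.Tensor:
    # Normalize by the root mean square over the last dimension, then rescale.
    norm_x = torch.mean(x * x, dim=-1, keepdim=True)
    x_normed = x * torch.rsqrt(norm_x + eps)
    scale = (1 + weight) if add_unit_offset else weight
    return x_normed * scale


x = torch.randn(2, 4)
# Zero-initialized weights plus the unit offset behave like plain RMS normalization.
torch.testing.assert_close(
    rms_norm(x, torch.zeros(4), add_unit_offset=True),
    rms_norm(x, torch.ones(4)),
)
```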
```shell python generate/sequentially.py \ From 32260e69b8a05c5d3bb3bff5ba687efa9febc9a7 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Fri, 23 Feb 2024 20:18:59 +0300 Subject: [PATCH 24/34] Drop `tie_weights` in convert_hf --- scripts/convert_hf_checkpoint.py | 6 ++---- tests/test_model.py | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py index 3a50c4b300..3839a8796d 100644 --- a/scripts/convert_hf_checkpoint.py +++ b/scripts/convert_hf_checkpoint.py @@ -117,7 +117,6 @@ def copy_weights_hf_llama( qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]], state_dict: Dict[str, torch.Tensor], hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], - tie_weights: bool = False, saver: Optional[incremental_save] = None, dtype: Optional[torch.dtype] = None, ) -> None: @@ -180,7 +179,7 @@ def copy_weights_hf_llama( param = saver.store_early(param) state_dict[to_name] = param - if tie_weights: + if "lm_head.weight" not in state_dict: state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"] # convert separate q, k, v matrices into an interleaved qkv @@ -315,8 +314,7 @@ def convert_hf_checkpoint( elif config._mlp_class in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"): # holder to reconstitute the split q, k, v qkv_weights = {} - tie_weights = "Gemma" in config.name - copy_fn = partial(copy_weights_hf_llama, config, qkv_weights, tie_weights=tie_weights) + copy_fn = partial(copy_weights_hf_llama, config, qkv_weights) elif "phi" in model_name: # holder to reconstitute the split q, k, v qkv_weights = {} diff --git a/tests/test_model.py b/tests/test_model.py index 8e294d36bb..ca85d89d52 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -601,7 +601,7 @@ def test_against_original_gemma(model_name, device, dtype): # Gemma weights are shipped without `lm_head.weight` theirs_state_dict.pop("lm_head.weight") state_dict = {} - copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict, tie_weights=True) + copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) From 78ad64345a4eb8f9b72632aa11afc787836a4de2 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Fri, 23 Feb 2024 20:36:23 +0300 Subject: [PATCH 25/34] Comment explaining why head_size*num_head instead of n_embd --- lit_gpt/adapter_v2.py | 1 + lit_gpt/lora.py | 1 + lit_gpt/model.py | 1 + 3 files changed, 3 insertions(+) diff --git a/lit_gpt/adapter_v2.py b/lit_gpt/adapter_v2.py index bf8781a835..1a0af39ea2 100644 --- a/lit_gpt/adapter_v2.py +++ b/lit_gpt/adapter_v2.py @@ -123,6 +123,7 @@ def __init__(self, config: Config, block_idx: int) -> None: # key, query, value projections for all heads, but in a batch self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias) # output projection + # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size` * `n_head` self.proj = AdapterV2Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias) # disabled by default self.kv_cache: Optional[KVCache] = None diff --git a/lit_gpt/lora.py b/lit_gpt/lora.py index ad5746d2cc..1ec32bc10e 100644 --- a/lit_gpt/lora.py +++ b/lit_gpt/lora.py @@ -598,6 +598,7 @@ def __init__(self, config: Config) -> None: n_query_groups=config.n_query_groups, ) # output projection + # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size` * `n_head` 
self.proj = LoRALinear( config.head_size * config.n_head, config.n_embd, diff --git a/lit_gpt/model.py b/lit_gpt/model.py index a2145f9d45..d5cab83893 100644 --- a/lit_gpt/model.py +++ b/lit_gpt/model.py @@ -177,6 +177,7 @@ def __init__(self, config: Config) -> None: # key, query, value projections for all heads, but in a batch self.attn = nn.Linear(config.n_embd, shape, bias=config.bias) # output projection + # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size` * `n_head` self.proj = nn.Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias) # disabled by default self.kv_cache: Optional[KVCache] = None From 6f154abe10c3249b7b394629aba0d9d048cbb7a7 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Fri, 23 Feb 2024 20:38:27 +0300 Subject: [PATCH 26/34] scale_wte_output --> scale_embeddings --- lit_gpt/config.py | 6 +++--- lit_gpt/model.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/lit_gpt/config.py b/lit_gpt/config.py index 5070c98de2..9c2ef68a04 100644 --- a/lit_gpt/config.py +++ b/lit_gpt/config.py @@ -17,7 +17,7 @@ class Config: name: str = "" hf_config: dict = field(default_factory=dict) - scale_wte_output: bool = False + scale_embeddings: bool = False block_size: int = 4096 vocab_size: int = 50254 padding_multiple: int = 512 @@ -795,7 +795,7 @@ def norm_class(self) -> Type: dict( name="Gemma-2b", hf_config=dict(org="google", name="gemma-2b"), - scale_wte_output=True, + scale_embeddings=True, vocab_size=256000, padding_multiple=64, n_embd=2048, @@ -814,7 +814,7 @@ def norm_class(self) -> Type: dict( name="Gemma-7b", hf_config=dict(org="google", name="gemma-7b"), - scale_wte_output=True, + scale_embeddings=True, vocab_size=256000, padding_multiple=64, n_embd=3072, diff --git a/lit_gpt/model.py b/lit_gpt/model.py index d5cab83893..81ff1bf8c3 100644 --- a/lit_gpt/model.py +++ b/lit_gpt/model.py @@ -87,7 +87,7 @@ def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None) - mask = None x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) - if self.config.scale_wte_output: + if self.config.scale_embeddings: x = x * (self.config.n_embd**0.5) for block in self.transformer.h: From cfe68bb7f3cd048f8258a9a6a4b37b15a41c80e6 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Fri, 23 Feb 2024 20:41:12 +0300 Subject: [PATCH 27/34] Config: drop `self.rmsnorm_add_unit_offset` --- lit_gpt/config.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/lit_gpt/config.py b/lit_gpt/config.py index 9c2ef68a04..4c73dc6beb 100644 --- a/lit_gpt/config.py +++ b/lit_gpt/config.py @@ -53,7 +53,6 @@ class Config: n_query_groups: Optional[int] = None shared_attention_norm: bool = False _norm_class: Literal["LayerNorm", "RMSNorm"] = "LayerNorm" - rmsnorm_add_unit_offset: bool = False norm_eps: float = 1e-5 _mlp_class: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP" gelu_approximate: str = "none" @@ -146,7 +145,7 @@ def norm_class(self) -> Type: from lit_gpt.rmsnorm import RMSNorm - return partial(RMSNorm, add_unit_offset=self.rmsnorm_add_unit_offset) + return partial(RMSNorm, add_unit_offset="Gemma" in self.name) return getattr(torch.nn, self._norm_class) @@ -806,7 +805,6 @@ def norm_class(self) -> Type: parallel_residual=False, bias=False, _norm_class="RMSNorm", - rmsnorm_add_unit_offset=True, _mlp_class="GemmaMLP", intermediate_size=16384, ), @@ -825,7 +823,6 @@ def norm_class(self) -> Type: parallel_residual=False, bias=False, 
_norm_class="RMSNorm", - rmsnorm_add_unit_offset=True, _mlp_class="GemmaMLP", intermediate_size=24576, ), From 57db71064a4d156990fda03c947d0b2fb353a5cf Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Fri, 23 Feb 2024 20:44:01 +0300 Subject: [PATCH 28/34] Comment why do we need a unit offset in RMSNorm --- lit_gpt/rmsnorm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lit_gpt/rmsnorm.py b/lit_gpt/rmsnorm.py index 49bc09d823..dcaab677a8 100644 --- a/lit_gpt/rmsnorm.py +++ b/lit_gpt/rmsnorm.py @@ -25,6 +25,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x_normed = x * torch.rsqrt(norm_x + self.eps) x_normed = x_normed.to(dtype=dtype) if self.add_unit_offset: + # Gemma model requires a unit offset + # https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L176 return x_normed * (1 + self.weight) return x_normed * self.weight From 4c44085f78e3775c3886ab6b2d5f295bec52f925 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Fri, 23 Feb 2024 20:45:19 +0300 Subject: [PATCH 29/34] Bump up min version of transformers in github CI --- .github/workflows/cpu-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/cpu-tests.yml b/.github/workflows/cpu-tests.yml index 9c1522ce47..4619928dbf 100644 --- a/.github/workflows/cpu-tests.yml +++ b/.github/workflows/cpu-tests.yml @@ -61,7 +61,7 @@ jobs: - name: Install all dependencies run: | - pip install -r requirements-all.txt pytest pytest-rerunfailures pytest-timeout transformers>=4.36.0 einops protobuf + pip install -r requirements-all.txt pytest pytest-rerunfailures pytest-timeout transformers>=4.38.0 einops protobuf pip list - name: Run tests without the package installed From e9b0c5a0cf05f7e83730a33b04b46863b418ddd7 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Fri, 23 Feb 2024 21:42:32 +0300 Subject: [PATCH 30/34] Update convert_hf test --- tests/test_convert_hf_checkpoint.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_convert_hf_checkpoint.py b/tests/test_convert_hf_checkpoint.py index c32f992d42..97d42b4f9e 100644 --- a/tests/test_convert_hf_checkpoint.py +++ b/tests/test_convert_hf_checkpoint.py @@ -97,6 +97,7 @@ def test_llama2_70b_conversion(): "transformer.h.5.attn.proj.weight": (8192, 8192), "transformer.h.5.mlp.fc_1.weight": (28672, 8192), "transformer.wte.weight": (32000, 8192), + "lm_head.weight": (32000, 8192), # due to weight tying lm_head is in the converted weights } From d17bb34dce83356f5b4bac20ece7107f717f837f Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov <58434077+Andrei-Aksionov@users.noreply.github.com> Date: Fri, 23 Feb 2024 21:51:39 +0300 Subject: [PATCH 31/34] Update lit_gpt/adapter_v2.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- lit_gpt/adapter_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lit_gpt/adapter_v2.py b/lit_gpt/adapter_v2.py index 1a0af39ea2..51b826a224 100644 --- a/lit_gpt/adapter_v2.py +++ b/lit_gpt/adapter_v2.py @@ -123,7 +123,7 @@ def __init__(self, config: Config, block_idx: int) -> None: # key, query, value projections for all heads, but in a batch self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias) # output projection - # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size` * `n_head` + # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head` self.proj = 
AdapterV2Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias) # disabled by default self.kv_cache: Optional[KVCache] = None From 86b5b7abf7d10cef846775753e1e1eba5d164c42 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov <58434077+Andrei-Aksionov@users.noreply.github.com> Date: Fri, 23 Feb 2024 21:51:48 +0300 Subject: [PATCH 32/34] Update lit_gpt/lora.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- lit_gpt/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lit_gpt/lora.py b/lit_gpt/lora.py index 1ec32bc10e..bfc7adc122 100644 --- a/lit_gpt/lora.py +++ b/lit_gpt/lora.py @@ -598,7 +598,7 @@ def __init__(self, config: Config) -> None: n_query_groups=config.n_query_groups, ) # output projection - # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size` * `n_head` + # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head` self.proj = LoRALinear( config.head_size * config.n_head, config.n_embd, From 8854d146c21889d5f6e084321aac60662c73ffcd Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov <58434077+Andrei-Aksionov@users.noreply.github.com> Date: Fri, 23 Feb 2024 21:51:56 +0300 Subject: [PATCH 33/34] Update lit_gpt/model.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- lit_gpt/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lit_gpt/model.py b/lit_gpt/model.py index 81ff1bf8c3..ed33664fa2 100644 --- a/lit_gpt/model.py +++ b/lit_gpt/model.py @@ -177,7 +177,7 @@ def __init__(self, config: Config) -> None: # key, query, value projections for all heads, but in a batch self.attn = nn.Linear(config.n_embd, shape, bias=config.bias) # output projection - # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size` * `n_head` + # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head` self.proj = nn.Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias) # disabled by default self.kv_cache: Optional[KVCache] = None From d219965dfe50ce4ed1ee698ca89ae706e6943315 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Fri, 23 Feb 2024 21:52:28 +0300 Subject: [PATCH 34/34] Bump up min transformers version in Azure workflow --- .github/azure-gpu-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/azure-gpu-test.yml b/.github/azure-gpu-test.yml index ddfd83c0e2..71e0aecfd2 100644 --- a/.github/azure-gpu-test.yml +++ b/.github/azure-gpu-test.yml @@ -39,7 +39,7 @@ jobs: displayName: "Image info & NVIDIA" - script: | - pip install -r requirements-all.txt pytest pytest-rerunfailures transformers>=4.36.0 einops protobuf + pip install -r requirements-all.txt pytest pytest-rerunfailures transformers>=4.38.0 einops protobuf displayName: 'Install dependencies' - bash: |
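Taken together, the recurring `head_size * n_head` comments and the `scale_embeddings` flag capture two remaining Gemma quirks: the attention width need not equal the residual width, and token embeddings are multiplied by `sqrt(n_embd)` before the first block. A toy shape check with illustrative sizes in the spirit of the 7B configuration (16 heads of size 256 against a 3072-wide residual stream):

```python
import torch
from torch import nn

# Illustrative sizes only: head_size * n_head (4096) != n_embd (3072).
n_embd, n_head, head_size, vocab_size = 3072, 16, 256, 1000

wte = nn.Embedding(vocab_size, n_embd)
proj = nn.Linear(head_size * n_head, n_embd, bias=False)  # output projection maps 4096 -> 3072

idx = torch.randint(0, vocab_size, (2, 5))
x = wte(idx) * (n_embd ** 0.5)                    # what scale_embeddings=True does
attn_out = torch.randn(2, 5, n_head * head_size)  # concatenated per-head attention outputs
assert proj(attn_out).shape == x.shape            # both end up (2, 5, n_embd)
```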