diff --git a/README.md b/README.md
index c270bddf8e..130950c72f 100644
--- a/README.md
+++ b/README.md
@@ -124,6 +124,7 @@ Every model is written from scratch to maximize performance and remove layers of
 | Llama 2 | 7B, 13B, 70B | Meta AI | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) |
 | Llama 3.1 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) |
 | Llama 3.2 | 1B, 3B | Meta AI | [Meta AI 2024](https://ai.meta.com/blog/llama-3-2-connect-2024-vision-edge-mobile-devices/) |
+| Llama 3.3 | 70B | Meta AI | [Meta AI 2024](https://ai.meta.com/blog/llama-3-3-large-language-model-family/) |
 | Mathstral | 7B | Mistral AI | [Mistral AI 2024](https://mistral.ai/news/mathstral/) |
 | MicroLlama | 300M | Ken Wang | [MicroLlama repo](https://github.com/keeeeenw/MicroLlama) |
 | Mixtral MoE | 8x7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/) |
@@ -139,6 +140,7 @@ Every model is written from scratch to maximize performance and remove layers of
 | Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186) |
 | QwQ | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) |
 | SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) |
+| Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra) |
 | StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
 | StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) |
 | StableLM Zephyr | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
diff --git a/litgpt/config.py b/litgpt/config.py
index 3a9370d2fd..9c64d2ae48 100644
--- a/litgpt/config.py
+++ b/litgpt/config.py
@@ -700,8 +700,31 @@ def norm_class(self) -> Type:
         rope_base=500000,
         rope_adjustments=dict(factor=32.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192)
     ),
+    # https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct/blob/main/config.json
+    dict(
+        name="Llama-3.3-70B-Instruct",
+        hf_config=dict(org="meta-llama", name="Llama-3.3-70B-Instruct"),
+        block_size=131072,
+        vocab_size=128000,
+        padded_vocab_size=128256,
+        n_layer=80,
+        n_head=64,
+        n_embd=8192,
+        n_query_groups=8,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="LLaMAMLP",
+        intermediate_size=28672,
+        rope_base=500000,
+        rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192)
+    ),
 ]
 for c in llama_3:
+    if c["name"] == "Llama-3.3-70B-Instruct":
+        configs.append(c)
+        continue
     for kind in ("", "-Instruct"):
         copy = deepcopy(c)
         copy["name"] = c["name"].format(kind)
@@ -1905,7 +1928,7 @@ def norm_class(self) -> Type:
     dict(
         name="Qwen2.5-Coder-1.5B{}",
         hf_config=dict(org="Qwen", name="Qwen2.5-Coder-1.5B{}"),
-        block_size=131072,
+        block_size=32768,
         vocab_size=151643,
         padded_vocab_size=151936,
         n_layer=28,
@@ -1947,7 +1970,7 @@ def norm_class(self) -> Type:
     dict(
         name="Qwen2.5-Coder-7B{}",
         hf_config=dict(org="Qwen", name="Qwen2.5-Coder-7B{}"),
-        block_size=131072,
+        block_size=32768,
         vocab_size=151643,
         padded_vocab_size=152064,
         n_layer=28,
@@ -1968,7 +1991,7 @@ def norm_class(self) -> Type:
     dict(
         name="Qwen2.5-Coder-14B{}",
name="Qwen2.5-Coder-14B{}"), - block_size=131072, + block_size=32768, vocab_size=151643, padded_vocab_size=152064, n_layer=48, @@ -1989,7 +2012,7 @@ def norm_class(self) -> Type: dict( name="Qwen2.5-Coder-32B{}", hf_config=dict(org="Qwen", name="Qwen2.5-Coder-32B{}"), - block_size=131072, + block_size=32768, vocab_size=151643, padded_vocab_size=152064, n_layer=64, @@ -2043,6 +2066,61 @@ def norm_class(self) -> Type: configs.extend(qwq) + +############# +# Salamandra +############# +salamandra = [ + # https://huggingface.co/BSC-LT/salamandra-2b-instruct/blob/main/config.json + dict( + name="salamandra-2b{}", + hf_config=dict(org="BSC-LT", name="salamandra-2b{}"), + block_size=8192, + vocab_size=256000, + padded_vocab_size=256000, + n_layer=24, + n_head=16, + n_embd=2048, + n_query_groups=16, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=5440, + norm_eps=1e-5, + rope_base=10000 + ), + # https://huggingface.co/BSC-LT/salamandra-7b-instruct/blob/main/config.json + dict( + name="salamandra-7b{}", + hf_config=dict(org="BSC-LT", name="salamandra-7b{}"), + block_size=8192, + vocab_size=256000, + padded_vocab_size=256000, + n_layer=32, + n_head=32, + n_embd=4096, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=11008, + norm_eps=1e-6, + rope_base=10000 + ), +] + +for c in salamandra: + for kind in ("", "-instruct"): + copy = deepcopy(c) + copy["name"] = c["name"].format(kind) + copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) + configs.append(copy) + + ############### # SmolLM2 ############### @@ -2116,4 +2194,5 @@ def norm_class(self) -> Type: copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) configs.append(copy) + name_to_config = {config["name"]: config for config in configs} diff --git a/litgpt/prompts.py b/litgpt/prompts.py index be433ad0d4..83a96ac43e 100644 --- a/litgpt/prompts.py +++ b/litgpt/prompts.py @@ -290,6 +290,11 @@ def apply(self, prompt: str, **kwargs: str) -> str: system_message = "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step." return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" +class Salamandra(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + system_message = "I am Salamandra, an AI language model developed at the Barcelona Supercomputing Centre (BSC) by the Language Technologies Unit. My knowledge base was last updated on August 2023. Today Date: 2024-09-30\nSoy Salamandra, un modelo lingüístico de IA desarrollado en el Barcelona Supercomputing Centre (BSC) por la Language Technologies Unit. Mi base de conocimientos se actualizó por última vez en agosto de 2023.\nSoc Salamandra, un model de llenguatge d'IA desenvolupat al Barcelona Supercomputing Centre (BSC) per la Language Technologies Unit. La meva base de coneixement es va actualitzar per última vegada l'agost de 2023." 
+ return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" + class SmolLM2(PromptStyle): def apply(self, prompt: str, **kwargs: str) -> str: @@ -323,6 +328,7 @@ def apply(self, prompt: str, **kwargs: str) -> str: "qwen2.5": Qwen2_5, "qwq": QwQ, "smollm2": SmolLM2 # SmolLM uses a different template + "salamandra": Salamandra, } @@ -367,6 +373,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle: return QwQ() if re.search(r"SmolLM2.*-Instruct", model_name): return SmolLM2() + if re.search(r"salamandra-.*-instruct", model_name): + return Salamandra() return Default() diff --git a/litgpt/tokenizer.py b/litgpt/tokenizer.py index acfc0493a7..10f7d031f6 100644 --- a/litgpt/tokenizer.py +++ b/litgpt/tokenizer.py @@ -145,6 +145,9 @@ def decode(self, tensor: torch.Tensor) -> str: if len(tokens) == 1 and self.apply_decoding_fix: dummy_token_id = 33 # \x1e dummy_token = self.processor.decode([dummy_token_id]) + if dummy_token != "\x1e": + dummy_token_id = 165 # \x1e is different in salamandra tokenizers + dummy_token = self.processor.decode([dummy_token_id]) return self.processor.decode([dummy_token_id] + tokens)[len(dummy_token) :] return self.processor.decode(tokens) diff --git a/tests/test_model.py b/tests/test_model.py index 867f60d635..1a997f3134 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -223,6 +223,7 @@ def test_against_original_open_llama_3b(device, dtype): {"name": "Llama-3.1-8B-Instruct"}, {"name": "Llama-3.2-1B"}, {"name": "Llama-3.2-3B"}, + {"name": "Llama-3.3-70B-Instruct"}, ], ) @pytest.mark.parametrize( @@ -852,6 +853,66 @@ def test_against_original_qwen_2_5(model_name, device, dtype): torch.testing.assert_close(ours_y, theirs_y) +@torch.inference_mode() +@pytest.mark.parametrize("model_name", ("salamandra-2b", "salamandra-7b")) +@pytest.mark.parametrize( + ("device", "dtype"), + [ + (torch.device("cpu"), torch.float32), + pytest.param( + torch.device("cuda"), + torch.float16, + marks=[ + # the reference does softmax upscaled to fp32 during attention. 
+                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
+                # is slightly different
+                pytest.mark.xfail(raises=AssertionError, strict=False),
+                RunIf(min_cuda_gpus=1),
+            ],
+        ),
+    ],
+)
+def test_against_original_salamandra(model_name, device, dtype):
+    torch.set_default_dtype(dtype)
+
+    ours_config = Config.from_name(
+        model_name,
+        padded_vocab_size=10000,
+        n_layer=2,
+        n_head=8,
+        n_embd=32,
+        n_query_groups=2,
+        intermediate_size=86,
+    )
+    T = 5
+    theirs_config = LlamaConfig(
+        vocab_size=ours_config.padded_vocab_size,
+        hidden_size=ours_config.n_embd,
+        num_attention_heads=ours_config.n_head,
+        num_hidden_layers=ours_config.n_layer,
+        intermediate_size=ours_config.intermediate_size,
+        max_position_embeddings=T,
+        rms_norm_eps=ours_config.norm_eps,
+        num_key_value_heads=ours_config.n_query_groups,
+        rope_theta=ours_config.rope_base,
+        attention_bias=ours_config.bias,
+    )
+    assert ours_config.intermediate_size == theirs_config.intermediate_size
+
+    theirs_model = LlamaForCausalLM(theirs_config).to(device)
+    theirs_state_dict = theirs_model.state_dict()
+    state_dict = {}
+    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
+    ours_model = GPT(ours_config).to(device)
+    ours_model.load_state_dict(state_dict)
+
+    # test end to end
+    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
+    assert x.size(1) == T
+    ours_y = ours_model(x)
+    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
+    torch.testing.assert_close(ours_y, theirs_y)
+
+
 @torch.inference_mode()
 @pytest.mark.parametrize("model_name", ("SmolLM2-135M", "SmolLM2-1.7B"))
 @pytest.mark.parametrize(
@@ -911,6 +972,7 @@ def test_against_original_smollm2(model_name, device, dtype):
     theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
     torch.testing.assert_close(ours_y, theirs_y)

+
 @RunIf(dynamo=True)
 @torch.inference_mode()
 def test_model_compile():
diff --git a/tutorials/download_model_weights.md b/tutorials/download_model_weights.md
index c18a40bdda..fe276d3eac 100644
--- a/tutorials/download_model_weights.md
+++ b/tutorials/download_model_weights.md
@@ -20,6 +20,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights.
 | Llama 3 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) |
 | Llama 3.1 | 8B, 70B, 405B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) |
 | Llama 3.2 | 1B, 3B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/MODEL_CARD.md) |
+| Llama 3.3 | 70B | Meta AI | [Meta AI 2024](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct) |
 | Llama 3.1 Nemotron | 70B | NVIDIA | [NVIDIA AI 2024](https://build.nvidia.com/nvidia/llama-3_1-nemotron-70b-instruct/modelcard) |
 | LongChat | 7B, 13B | LMSYS | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) |
 | Mathstral | 7B | Mistral AI | [Mistral AI 2024](https://mistral.ai/news/mathstral/) |
@@ -40,6 +41,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights.
 | RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1) |
 | SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) |
 | StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
+| Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra) |
 | StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) |
 | StableLM Zephyr | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
 | TinyLlama | 1.1B | Zhang et al. | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama) |
@@ -63,6 +65,10 @@ The output is shown below:
 allenai/OLMo-1B-hf
 allenai/OLMo-7B-hf
 allenai/OLMo-7B-Instruct-hf
+bsc-lt/salamandra-2b
+bsc-lt/salamandra-2b-instruct
+bsc-lt/salamandra-7b
+bsc-lt/salamandra-7b-instruct
 codellama/CodeLlama-13b-hf
 codellama/CodeLlama-13b-Instruct-hf
 codellama/CodeLlama-13b-Python-hf
@@ -141,6 +147,7 @@ meta-llama/Llama-3.2-1B
 meta-llama/Llama-3.2-1B-Instruct
 meta-llama/Llama-3.2-3B
 meta-llama/Llama-3.2-3B-Instruct
+meta-llama/Llama-3.3-70B-Instruct
 meta-llama/Meta-Llama-3-70B
 meta-llama/Meta-Llama-3-70B-Instruct
 meta-llama/Meta-Llama-3-8B
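
Note (not part of the diff): a minimal sketch of how the additions above can be exercised once the patch is applied. It only uses names introduced or touched by this PR — `Config.from_name`, `model_name_to_prompt_style`, and the newly registered `salamandra-*` and `Llama-3.3-70B-Instruct` config entries — and is meant as a quick local sanity check rather than an authoritative usage guide.

```python
# Quick check of the new registrations, assuming a litgpt checkout with this patch applied.
from litgpt.config import Config
from litgpt.prompts import model_name_to_prompt_style

# The new config entries are looked up by the names registered in litgpt/config.py.
for name in ("salamandra-2b-instruct", "salamandra-7b", "Llama-3.3-70B-Instruct"):
    cfg = Config.from_name(name)
    print(name, cfg.n_layer, cfg.n_embd, cfg.block_size)

# The -instruct checkpoints should route to the new ChatML-style Salamandra prompt.
style = model_name_to_prompt_style("salamandra-2b-instruct")
print(style.apply("Who are you?"))
```

The tokenizer change is similarly conservative: the fallback to token id 165 only triggers when token id 33 does not decode to `\x1e`, so tokenizers that already satisfy the existing assumption keep the original single-token decoding path.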