diff --git a/docs/source/en/gguf.md b/docs/source/en/gguf.md
index 359ed4d5e1e8..bc71630711ab 100644
--- a/docs/source/en/gguf.md
+++ b/docs/source/en/gguf.md
@@ -64,6 +64,7 @@ For now the supported model architectures are the architectures that have been v
 - LLaMa
 - Mistral
 - Qwen2
+- Qwen2Moe
 
 ## Example usage
 
diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py
index 7da09be841e1..c37fa5123335 100644
--- a/src/transformers/integrations/ggml.py
+++ b/src/transformers/integrations/ggml.py
@@ -117,6 +117,21 @@
         "output.weight": "lm_head.weight",
         "output_norm": "model.norm",
     },
+    "qwen2moe": {
+        "token_embd": "model.embed_tokens",
+        "blk": "model.layers",
+        "ffn_up": "mlp.up_proj",
+        "ffn_down": "mlp.down_proj",
+        "ffn_gate": "mlp.gate_proj",
+        "ffn_norm": "post_attention_layernorm",
+        "attn_norm": "input_layernorm",
+        "attn_q": "self_attn.q_proj",
+        "attn_v": "self_attn.v_proj",
+        "attn_k": "self_attn.k_proj",
+        "attn_output": "self_attn.o_proj",
+        "output.weight": "lm_head.weight",
+        "output_norm": "model.norm",
+    },
 }
 
 
@@ -161,6 +176,18 @@
         "attention.layer_norm_rms_epsilon": "rms_norm_eps",
         "vocab_size": "vocab_size",
     },
+    "qwen2moe": {
+        "context_length": "max_position_embeddings",
+        "block_count": "num_hidden_layers",
+        "feed_forward_length": "intermediate_size",
+        "embedding_length": "hidden_size",
+        "rope.dimension_count": None,
+        "rope.freq_base": "rope_theta",
+        "attention.head_count": "num_attention_heads",
+        "attention.head_count_kv": "num_key_value_heads",
+        "attention.layer_norm_rms_epsilon": "rms_norm_eps",
+        "vocab_size": "vocab_size",
+    },
     "tokenizer": {
         "ggml.bos_token_id": "bos_token_id",
         "ggml.eos_token_id": "eos_token_id",
@@ -579,7 +606,15 @@ def tokenizer(self, proto):
         bos_token = proto.tokens[proto.bos_token_id] if getattr(proto, "bos_token_id", None) is not None else None
         eos_token = proto.tokens[proto.bos_token_id] if getattr(proto, "eos_token_id", None) is not None else None
 
-        tokenizer = Tokenizer(BPE(bpe_vocab, merges, unk_token=unk_token, fuse_unk=True, byte_fallback=True))
+        tokenizer = Tokenizer(
+            BPE(
+                bpe_vocab,
+                merges,
+                unk_token=unk_token,
+                fuse_unk=True,
+                byte_fallback=True,
+            )
+        )
 
         special_tokens = []
 
@@ -693,6 +728,7 @@ def converted(self) -> Tokenizer:
 GGUF_TO_FAST_CONVERTERS = {
     "llama": GGUFLlamaConverter,
     "qwen2": GGUFQwen2Converter,
+    "qwen2_moe": GGUFQwen2Converter,
 }
 
 
diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py
index 52b1068e003f..1d79f1afcbd1 100644
--- a/src/transformers/modeling_gguf_pytorch_utils.py
+++ b/src/transformers/modeling_gguf_pytorch_utils.py
@@ -96,6 +96,9 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
     else:
         updated_architecture = architecture
 
+    if "qwen2moe" in architecture:
+        updated_architecture = "qwen2_moe"
+
     if architecture not in GGUF_SUPPORTED_ARCHITECTURES:
         raise ValueError(f"Architecture {architecture} not supported")
 
diff --git a/tests/quantization/ggml/test_ggml.py b/tests/quantization/ggml/test_ggml.py
index e42900a1d51b..0bc1f797aa05 100644
--- a/tests/quantization/ggml/test_ggml.py
+++ b/tests/quantization/ggml/test_ggml.py
@@ -16,7 +16,12 @@
 import unittest
 
 from transformers import AddedToken, AutoModelForCausalLM, AutoTokenizer
-from transformers.testing_utils import require_gguf, require_torch_gpu, slow, torch_device
+from transformers.testing_utils import (
+    require_gguf,
+    require_torch_gpu,
+    slow,
+    torch_device,
+)
 from transformers.utils import is_torch_available
 
 
@@ -32,6 +37,7 @@ class GgufIntegrationTests(unittest.TestCase):
     model_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
     mistral_model_id = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
     qwen2_model_id = "Qwen/Qwen1.5-0.5B-Chat-GGUF"
+    qwen2_moe_model_id = "RichardErkhov/Qwen_-_Qwen1.5-MoE-A2.7B-Chat-gguf"
     llama3_model_id = "NousResearch/Meta-Llama-3-8B-GGUF"
     tinyllama_model_id = "PenutChen/TinyLlama-1.1B-Chat-v1.0-GGUF"
 
@@ -45,6 +51,7 @@ class GgufIntegrationTests(unittest.TestCase):
 
     q4_0_mistral_model_id = "mistral-7b-instruct-v0.2.Q4_0.gguf"
     q4_0_qwen2_model_id = "qwen1_5-0_5b-chat-q4_0.gguf"
+    q4_0_qwen2_moe_model_id = "Qwen1.5-MoE-A2.7B-Chat.Q4_0.gguf"
     q4_llama3_model_id = "Meta-Llama-3-8B-Q4_K_M.gguf"
     f16_tinyllama_model_id = "TinyLlama-1.1B-Chat-v1.0.FP16.gguf"
 
@@ -166,7 +173,10 @@ def test_f16(self):
     def test_mistral_q4_0(self):
         tokenizer = AutoTokenizer.from_pretrained(self.mistral_model_id, gguf_file=self.q4_0_mistral_model_id)
         model = AutoModelForCausalLM.from_pretrained(
-            self.mistral_model_id, gguf_file=self.q4_0_mistral_model_id, device_map="auto", torch_dtype=torch.float16
+            self.mistral_model_id,
+            gguf_file=self.q4_0_mistral_model_id,
+            device_map="auto",
+            torch_dtype=torch.float16,
         )
 
         text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
@@ -178,7 +188,10 @@ def test_mistral_q4_0(self):
     def test_qwen2_q4_0(self):
         tokenizer = AutoTokenizer.from_pretrained(self.qwen2_model_id, gguf_file=self.q4_0_qwen2_model_id)
         model = AutoModelForCausalLM.from_pretrained(
-            self.qwen2_model_id, gguf_file=self.q4_0_qwen2_model_id, device_map="auto", torch_dtype=torch.float16
+            self.qwen2_model_id,
+            gguf_file=self.q4_0_qwen2_model_id,
+            device_map="auto",
+            torch_dtype=torch.float16,
         )
 
         text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
@@ -187,6 +200,21 @@ def test_qwen2_q4_0(self):
         EXPECTED_TEXT = "Hello.jsoup\n\nI am a beginner"
         self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
 
+    def test_qwen2_moe_q4_0(self):
+        tokenizer = AutoTokenizer.from_pretrained(self.qwen2_moe_model_id, gguf_file=self.q4_0_qwen2_moe_model_id)
+        model = AutoModelForCausalLM.from_pretrained(
+            self.qwen2_moe_model_id,
+            gguf_file=self.q4_0_qwen2_moe_model_id,
+            device_map="auto",
+            torch_dtype=torch.float16,
+        )
+
+        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
+        out = model.generate(**text, max_new_tokens=10)
+
+        EXPECTED_TEXT = "Hello everyone, I'm a newbie here and would like"
+        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
+
     def test_llama3_q4_0_tokenizer(self):
         tokenizer = AutoTokenizer.from_pretrained(self.llama3_model_id, gguf_file=self.q4_llama3_model_id)
         with tempfile.TemporaryDirectory() as tmpdirname:
@@ -199,7 +227,10 @@ def test_llama3_q4_0_tokenizer(self):
     def test_llama3_q4_0(self):
         tokenizer = AutoTokenizer.from_pretrained(self.llama3_model_id, gguf_file=self.q4_llama3_model_id)
         model = AutoModelForCausalLM.from_pretrained(
-            self.llama3_model_id, gguf_file=self.q4_llama3_model_id, device_map="auto", torch_dtype=torch.float16
+            self.llama3_model_id,
+            gguf_file=self.q4_llama3_model_id,
+            device_map="auto",
+            torch_dtype=torch.float16,
         )
 
         text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
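
With these changes, a Qwen2Moe GGUF checkpoint loads the same way as the other supported architectures: pass `gguf_file` to `from_pretrained` and the quantized weights are dequantized and mapped onto the `qwen2_moe` layout. Below is a minimal sketch mirroring the new `test_qwen2_moe_q4_0` test; the repository and file names come from the test above, while the `"Hello"` prompt is just an illustrative input (not taken from the diff), `device_map="auto"` assumes `accelerate` is available, and GGUF loading additionally requires the `gguf` package.

```python
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer

# Repository and quantized file names taken from the new test case above.
model_id = "RichardErkhov/Qwen_-_Qwen1.5-MoE-A2.7B-Chat-gguf"
gguf_file = "Qwen1.5-MoE-A2.7B-Chat.Q4_0.gguf"

# The GGUF weights are dequantized into torch tensors and mapped onto the
# qwen2_moe architecture at load time.
tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    gguf_file=gguf_file,
    device_map="auto",
    torch_dtype=torch.float16,
)

# "Hello" is an arbitrary example prompt, not part of the test suite.
inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```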