From 757765c77bac62c0292a45f41de2ce30332d10ab Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Wed, 15 Jan 2025 15:43:07 +0000 Subject: [PATCH 1/2] fix falcon tie_word_embeddings --- src/transformers/modeling_gguf_pytorch_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py index 0da06a1f582a..3024a523e2bd 100644 --- a/src/transformers/modeling_gguf_pytorch_utils.py +++ b/src/transformers/modeling_gguf_pytorch_utils.py @@ -400,9 +400,10 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False, model_to_lo # Handle tie_word_embeddings, if lm_head.weight is not present in tensors, # tie_word_embeddings is true otherwise false + exceptions = ["falcon"] parsed_parameters["config"]["tie_word_embeddings"] = all( "output.weight" != tensor.name for tensor in reader.tensors - ) + ) or architecture in exceptions # List all key-value pairs in a columnized format for gguf_key, field in reader.fields.items(): From 2c59f0b89e897fe34203c8b62008e334302d638e Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Wed, 15 Jan 2025 15:56:51 +0000 Subject: [PATCH 2/2] fix style --- src/transformers/modeling_gguf_pytorch_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py index 3024a523e2bd..21385233a779 100644 --- a/src/transformers/modeling_gguf_pytorch_utils.py +++ b/src/transformers/modeling_gguf_pytorch_utils.py @@ -401,9 +401,9 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False, model_to_lo # Handle tie_word_embeddings, if lm_head.weight is not present in tensors, # tie_word_embeddings is true otherwise false exceptions = ["falcon"] - parsed_parameters["config"]["tie_word_embeddings"] = all( - "output.weight" != tensor.name for tensor in reader.tensors - ) or architecture in exceptions + parsed_parameters["config"]["tie_word_embeddings"] = ( + all("output.weight" != tensor.name for tensor in reader.tensors) or architecture in exceptions + ) # List all key-value pairs in a columnized format for gguf_key, field in reader.fields.items():