From ad050107b3e1b2c19fc263ceae799189735299cd Mon Sep 17 00:00:00 2001
From: Jin Qiao <89779290+JinBridger@users.noreply.github.com>
Date: Mon, 5 Feb 2024 10:17:07 +0800
Subject: [PATCH] LLM: fix mpt load_low_bit issue (#10075)

* fix

* retry

* retry
---
 python/llm/src/bigdl/llm/transformers/model.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py
index 93afdbe9c6a..1e8e431b4a7 100644
--- a/python/llm/src/bigdl/llm/transformers/model.py
+++ b/python/llm/src/bigdl/llm/transformers/model.py
@@ -361,7 +361,11 @@ def load_convert(cls, q_k, optimize_model, *args, **kwargs):
                                       cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
                                       torch_dtype=kwargs.get("torch_dtype", 'auto'))
         model.config.update({"bigdl_transformers_low_bit": q_k})
-        model.config.update({"tie_word_embeddings": False})
+
+        # enable tie_word_embeddings for MPT
+        # refer to https://huggingface.co/mosaicml/mpt-7b-chat/blob/main/modeling_mpt.py#L232
+        if model.config.architectures[0] != 'MPTForCausalLM':
+            model.config.update({"tie_word_embeddings": False})
 
         # add save_low_bit to pretrained model dynamically
         import types
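
For context, a minimal sketch of the save/load round trip this patch affects, assuming BigDL-LLM's AutoModelForCausalLM interface with save_low_bit/load_low_bit; the checkpoint id and local paths below are illustrative only, not taken from the patch.

    # Sketch only: exercises the load_convert path patched above.
    from bigdl.llm.transformers import AutoModelForCausalLM

    # First run: convert the MPT checkpoint to low-bit and save it.
    model = AutoModelForCausalLM.from_pretrained(
        "mosaicml/mpt-7b-chat",   # illustrative MPT checkpoint
        load_in_4bit=True,
        trust_remote_code=True,
    )
    model.save_low_bit("./mpt-7b-chat-4bit")   # illustrative path

    # Later runs: reload the saved low-bit weights. Before this patch the
    # config was unconditionally rewritten with tie_word_embeddings=False,
    # which breaks MPT, whose modeling code relies on tied embeddings.
    model = AutoModelForCausalLM.load_low_bit(
        "./mpt-7b-chat-4bit",
        trust_remote_code=True,
    )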