diff --git a/library/model_util.py b/library/model_util.py
index be410a026..ea23d1b42 100644
--- a/library/model_util.py
+++ b/library/model_util.py
@@ -9,6 +9,7 @@
 init_ipex()
 
 import diffusers
+import importlib.metadata
 from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
 from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline  # , UNet2DConditionModel
 from safetensors.torch import load_file, save_file
@@ -572,6 +573,17 @@ def convert_ldm_clip_checkpoint_v1(checkpoint):
         if key.startswith("cond_stage_model.transformer"):
             text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
 
+    try:
+        vers = importlib.metadata.version("transformers").split(".")
+    except Exception:
+        vers = None
+
+    if vers is not None and tuple(vers) <= ('4', '30', '2'):
+        # support checkpoint without position_ids (invalid checkpoint)
+        if "text_model.embeddings.position_ids" not in text_model_dict:
+            text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)  # 77 is the max length of the text
+        return text_model_dict
+
     # remove position_ids for newer transformer, which causes error :(
     if "text_model.embeddings.position_ids" in text_model_dict:
         text_model_dict.pop("text_model.embeddings.position_ids")
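
The version gate above exists because transformers 4.31.0 changed `text_model.embeddings.position_ids` from a persistent to a non-persistent buffer: older releases expect the key in a strictly loaded state dict, newer ones reject it as unexpected. Below is a minimal sketch, not part of the diff, that demonstrates the difference; it assumes network access to fetch `openai/clip-vit-large-patch14`, the CLIP text encoder used by SD v1 checkpoints.

```python
import torch
from transformers import CLIPTextModel

text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# Rebuild a state dict and force the position_ids key in, the way an old-style
# checkpoint would carry it (shape (1, 77) = max text length).
state_dict = {k: v.clone() for k, v in text_model.state_dict().items()}
state_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)

try:
    # transformers <= 4.30.2: position_ids is a persistent buffer, so the key is
    # expected and the strict load succeeds.
    text_model.load_state_dict(state_dict)
    print("position_ids accepted (older transformers)")
except RuntimeError:
    # transformers >= 4.31.0: the buffer is non-persistent, the key is unexpected,
    # and strict loading fails -- hence the pop() in the fall-through path above.
    print("position_ids rejected (newer transformers); it must be popped first")
```

One caveat: the tuple comparison in the diff is lexicographic on strings, so it only behaves like a numeric version check while the minor version has two digits (a hypothetical "4.100.0" would compare as older than "4.30.2"); `packaging.version.parse` (packaging ships as a transformers dependency) would be a stricter gate if that ever matters.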