From c874063408f48d0e34add2667ef1923bbb4ac68e Mon Sep 17 00:00:00 2001
From: Erik Scholz <Green-Sky@users.noreply.github.com>
Date: Mon, 20 Nov 2023 15:34:17 +0100
Subject: [PATCH] fix: support bf16 lora weights (#82)

---
 models/convert.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/models/convert.py b/models/convert.py
index 503b10da..62665eb9 100644
--- a/models/convert.py
+++ b/models/convert.py
@@ -101,7 +101,7 @@ def quantize_q5_1(x):
 def quantize_q8_0(x):
     assert x.shape[-1] % QK8_0 == 0 and x.shape[-1] > QK8_0
     x = x.reshape(-1, QK8_0)
-    amax = np.max(np.abs(x), axis=-1, keepdims=True) 
+    amax = np.max(np.abs(x), axis=-1, keepdims=True)
     d = amax / ((1 << 7) - 1)
     qs = (x / d).round().clip(min=-128, max=127).astype(np.int8)
     d = d.astype(np.float16).view(np.int8)
@@ -178,7 +178,7 @@ def preprocess(state_dict):
         print("no alphas_cumprod in file, generate new one")
         alphas_cumprod = get_alpha_comprod()
         state_dict["alphas_cumprod"] = alphas_cumprod
-    
+
     new_state_dict = {}
     for name, w in state_dict.items():
         # ignore unused tensors
@@ -192,7 +192,7 @@ def preprocess(state_dict):
         if skip:
             continue
 
-        # # convert BF16 to FP16
+        # convert BF16 to FP16
         if w.dtype == torch.bfloat16:
             w = w.to(torch.float16)
 
@@ -251,7 +251,7 @@ def preprocess(state_dict):
             new_state_dict[new_name] = w
             print(f"preprocess {name} => {new_name}")
             continue
-        
+
         # convert unet transformer linear to conv2d 1x1
         if name.startswith("model.diffusion_model.") and (name.endswith("proj_in.weight") or name.endswith("proj_out.weight")):
             if len(w.shape) == 2:
@@ -342,6 +342,11 @@ def preprocess_lora(state_dict):
     for name, w in state_dict.items():
         if not isinstance(w, torch.Tensor):
             continue
+
+        # convert BF16 to FP16
+        if w.dtype == torch.bfloat16:
+            w = w.to(torch.float16)
+
         name_without_network_parts, network_part = name.split(".", 1)
         new_name_without_network_parts = convert_diffusers_name_to_compvis(name_without_network_parts)
         if new_name_without_network_parts == None:
@@ -421,6 +426,7 @@ def convert(model_path, out_type = None, out_file=None, lora=False):
             continue
         if name in unused_tensors:
             continue
+
         data = state_dict[name].numpy()
 
         n_dims = len(data.shape)
@@ -452,7 +458,7 @@ def convert(model_path, out_type = None, out_file=None, lora=False):
         else:
             data = data.astype(np.float32)
             ttype = "f32"
-        
+
         print("Processing tensor: {} with shape {}, {} -> {}".format(name, data.shape, old_type, ttype))
 
         # header
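
A minimal sketch (not part of the patch) of the failure this fixes: NumPy has
no bfloat16 dtype, so `state_dict[name].numpy()` in convert() raises a
TypeError when a LoRA file stores its weights in BF16. Downcasting to FP16
first, as preprocess_lora() now does (mirroring the existing handling in
preprocess()), avoids the error. The tensor below is a made-up stand-in for a
LoRA weight; its shape is arbitrary.

    import torch

    # A bf16 tensor standing in for a LoRA weight loaded from a state dict.
    w = torch.randn(4, 4, dtype=torch.bfloat16)

    try:
        w.numpy()  # NumPy has no bfloat16 dtype, so torch raises here
    except TypeError as e:
        print("bf16 -> numpy fails:", e)

    # The fix applied in preprocess_lora(): downcast to fp16 first.
    w = w.to(torch.float16)
    print(w.numpy().dtype)  # float16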