From c874063408f48d0e34add2667ef1923bbb4ac68e Mon Sep 17 00:00:00 2001
From: Erik Scholz <Green-Sky@users.noreply.github.com>
Date: Mon, 20 Nov 2023 15:34:17 +0100
Subject: [PATCH] fix: support bf16 lora weights (#82)

---
 models/convert.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/models/convert.py b/models/convert.py
index 503b10da..62665eb9 100644
--- a/models/convert.py
+++ b/models/convert.py
@@ -101,7 +101,7 @@ def quantize_q5_1(x):
 def quantize_q8_0(x):
     assert x.shape[-1] % QK8_0 == 0 and x.shape[-1] > QK8_0
     x = x.reshape(-1, QK8_0)
-    amax = np.max(np.abs(x), axis=-1, keepdims=True) 
+    amax = np.max(np.abs(x), axis=-1, keepdims=True)
     d = amax / ((1 << 7) - 1)
     qs = (x / d).round().clip(min=-128, max=127).astype(np.int8)
     d = d.astype(np.float16).view(np.int8)
@@ -178,7 +178,7 @@ def preprocess(state_dict):
         print("no alphas_cumprod in file, generate new one")
         alphas_cumprod = get_alpha_comprod()
         state_dict["alphas_cumprod"] = alphas_cumprod
-    
+
     new_state_dict = {}
     for name, w in state_dict.items():
         # ignore unused tensors
@@ -192,7 +192,7 @@ def preprocess(state_dict):
         if skip:
             continue
 
-        # # convert BF16 to FP16
+        # convert BF16 to FP16
         if w.dtype == torch.bfloat16:
             w = w.to(torch.float16)
 
@@ -251,7 +251,7 @@ def preprocess(state_dict):
                 new_state_dict[new_name] = w
                 print(f"preprocess {name} => {new_name}")
             continue
-        
+
         # convert unet transformer linear to conv2d 1x1
         if name.startswith("model.diffusion_model.") and (name.endswith("proj_in.weight") or name.endswith("proj_out.weight")):
             if len(w.shape) == 2:
@@ -342,6 +342,11 @@ def preprocess_lora(state_dict):
     for name, w in state_dict.items():
         if not isinstance(w, torch.Tensor):
             continue
+
+        # convert BF16 to FP16
+        if w.dtype == torch.bfloat16:
+            w = w.to(torch.float16)
+
         name_without_network_parts, network_part = name.split(".", 1)
         new_name_without_network_parts = convert_diffusers_name_to_compvis(name_without_network_parts)
         if new_name_without_network_parts == None:
@@ -421,6 +426,7 @@ def convert(model_path, out_type = None, out_file=None, lora=False):
                 continue
             if name in unused_tensors:
                 continue
+
             data = state_dict[name].numpy()
 
             n_dims = len(data.shape)
@@ -452,7 +458,7 @@ def convert(model_path, out_type = None, out_file=None, lora=False):
             else:
                 data = data.astype(np.float32)
                 ttype = "f32"
-            
+
             print("Processing tensor: {} with shape {}, {} -> {}".format(name, data.shape, old_type, ttype))
 
             # header