Skip to content

Commit

Permalink
fix: NaN when train SDv21 model, added ControlNet
Browse files Browse the repository at this point in the history
  • Loading branch information
Linaqruf committed Feb 25, 2023
1 parent 60d32c4 commit 279ea6c
Show file tree
Hide file tree
Showing 8 changed files with 522 additions and 41 deletions.
2 changes: 1 addition & 1 deletion finetune/merge_captions_to_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,4 @@ def main(args):
if args.caption_extention is not None:
args.caption_extension = args.caption_extention

main(args)
main(args)
2 changes: 1 addition & 1 deletion finetune/merge_dd_tags_to_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,4 @@ def main(args):
parser.add_argument("--debug", action="store_true", help="debug mode, print tags")

args = parser.parse_args()
main(args)
main(args)
181 changes: 149 additions & 32 deletions gen_img_diffusers.py

Large diffs are not rendered by default.

15 changes: 10 additions & 5 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1372,8 +1372,8 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):


def add_optimizer_arguments(parser: argparse.ArgumentParser):
parser.add_argument("--optimizer_type", type=str, default="AdamW",
help="Optimizer to use / オプティマイザの種類: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor")
parser.add_argument("--optimizer_type", type=str, default="",
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor")

# backward compatibility
parser.add_argument("--use_8bit_adam", action="store_true",
Expand Down Expand Up @@ -1532,11 +1532,16 @@ def get_optimizer(args, trainable_params):

optimizer_type = args.optimizer_type
if args.use_8bit_adam:
print(f"*** use_8bit_adam option is specified. optimizer_type is ignored / use_8bit_adamオプションが指定されているためoptimizer_typeは無視されます")
assert not args.use_lion_optimizer, "both option use_8bit_adam and use_lion_optimizer are specified / use_8bit_adamとuse_lion_optimizerの両方のオプションが指定されています"
assert optimizer_type is None or optimizer_type == "", "both option use_8bit_adam and optimizer_type are specified / use_8bit_adamとoptimizer_typeの両方のオプションが指定されています"
optimizer_type = "AdamW8bit"

elif args.use_lion_optimizer:
print(f"*** use_lion_optimizer option is specified. optimizer_type is ignored / use_lion_optimizerオプションが指定されているためoptimizer_typeは無視されます")
assert optimizer_type is None or optimizer_type == "", "both option use_lion_optimizer and optimizer_type are specified / use_lion_optimizerとoptimizer_typeの両方のオプションが指定されています"
optimizer_type = "Lion"

if optimizer_type is None or optimizer_type == "":
optimizer_type = "AdamW"
optimizer_type = optimizer_type.lower()

# 引数を分解する:boolとfloat、tupleのみ対応
Expand All @@ -1557,7 +1562,7 @@ def get_optimizer(args, trainable_params):
value = tuple(value)

optimizer_kwargs[key] = value
print("optkwargs:", optimizer_kwargs)
# print("optkwargs:", optimizer_kwargs)

lr = args.learning_rate

Expand Down
5 changes: 5 additions & 0 deletions networks/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules)
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name)

def set_multiplier(self, multiplier):
self.multiplier = multiplier
for lora in self.text_encoder_loras + self.unet_loras:
lora.multiplier = self.multiplier

def load_weights(self, file):
if os.path.splitext(file)[1] == '.safetensors':
from safetensors.torch import load_file, safe_open
Expand Down
24 changes: 24 additions & 0 deletions tools/canny.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import argparse
import cv2


def canny(args):
img = cv2.imread(args.input)
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

canny_img = cv2.Canny(img, args.thres1, args.thres2)
# canny_img = 255 - canny_img

cv2.imwrite(args.output, canny_img)
print("done!")


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--input", type=str, default=None, help="input path")
parser.add_argument("--output", type=str, default=None, help="output path")
parser.add_argument("--thres1", type=int, default=32, help="thres1")
parser.add_argument("--thres2", type=int, default=224, help="thres2")

args = parser.parse_args()
canny(args)
Loading

0 comments on commit 279ea6c

Please sign in to comment.