From 216e7218102ae3220d7cda0c295e1a596622a38c Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 29 Oct 2023 15:34:52 +0900 Subject: [PATCH 1/8] Add get_my_logger() --- library/utils.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/library/utils.py b/library/utils.py index 7d801a676..b491cd761 100644 --- a/library/utils.py +++ b/library/utils.py @@ -1,6 +1,21 @@ import threading from typing import * - +import logging def fire_in_thread(f, *args, **kwargs): - threading.Thread(target=f, args=args, kwargs=kwargs).start() \ No newline at end of file + threading.Thread(target=f, args=args, kwargs=kwargs).start() + +def get_my_logger(name: str): + logger = logging.getLogger(name) + logger.setLevel(logging.INFO) + + stream_handler = logging.StreamHandler() + stream_handler.setLevel(logging.INFO) + + myformat = '%(asctime)s\t[%(levelname)s]\t%(filename)s:%(lineno)d\t%(message)s' + date_format = '%Y-%m-%d %H:%M:%S' + formatter = logging.Formatter(myformat, date_format) + stream_handler.setFormatter(formatter) + if not logger.handlers: logger.addHandler(stream_handler)  # avoid attaching duplicate handlers when called again with the same name + + return logger From a88844321ceb43b320e210edc3d0af760d1b7fd5 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 29 Oct 2023 15:41:35 +0900 Subject: [PATCH 2/8] Use logger instead of print --- fine_tune.py | 12 +- finetune/blip/blip.py | 4 +- finetune/clean_captions_and_tags.py | 40 +-- finetune/make_captions.py | 20 +- finetune/make_captions_by_git.py | 21 +- finetune/merge_captions_to_metadata.py | 18 +- finetune/merge_dd_tags_to_metadata.py | 18 +- finetune/prepare_buckets_latents.py | 24 +- finetune/tag_images_by_wd14_tagger.py | 26 +- gen_img_diffusers.py | 152 +++++------ library/config_util.py | 35 +-- library/custom_train_functions.py | 15 +- library/huggingface_util.py | 15 +- library/ipex/hijacks.py | 4 +- library/lpw_stable_diffusion.py | 3 +- library/model_util.py | 28 +- library/original_unet.py | 6 +- library/sai_model_spec.py | 4 +- library/sdxl_model_util.py | 23 +- library/sdxl_original_unet.py | 27 +- library/sdxl_train_util.py | 36 +-- library/slicing_vae.py | 23 +- library/train_util.py | 313 ++++++++++++----------- networks/check_lora_weights.py | 11 +- networks/control_net_lllite.py | 33 +-- networks/control_net_lllite_for_train.py | 41 +-- networks/dylora.py | 25 +- networks/extract_lora_from_dylora.py | 15 +- networks/extract_lora_from_models.py | 23 +- networks/lora.py | 103 ++++---- networks/lora_diffusers.py | 67 ++--- networks/lora_fa.py | 101 ++++---- networks/lora_interrogator.py | 22 +- networks/merge_lora.py | 35 +-- networks/merge_lora_old.py | 23 +- networks/oft.py | 17 +- networks/resize_lora.py | 16 +- networks/sdxl_merge_lora.py | 33 +-- networks/svd_merge_lora.py | 19 +- sdxl_gen_img.py | 154 +++++------ sdxl_minimal_inference.py | 12 +- sdxl_train.py | 18 +- sdxl_train_control_net_lllite.py | 21 +- sdxl_train_control_net_lllite_old.py | 21 +- sdxl_train_network.py | 9 +- tools/cache_latents.py | 17 +- tools/cache_text_encoder_outputs.py | 17 +- tools/convert_diffusers20_original_sd.py | 15 +- tools/detect_face_rotate.py | 20 +- tools/latent_upscaler.py | 19 +- tools/merge_models.py | 25 +- tools/original_control_net.py | 19 +- tools/resize_images_to_resolution.py | 7 +- tools/show_metadata.py | 6 +- train_controlnet.py | 15 +- train_db.py | 14 +- train_network.py | 19 +- train_textual_inversion.py | 8 +- train_textual_inversion_XTI.py | 66 ++--- 59 files changed, 1017 insertions(+), 936 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 893066f70..2d54b6fb3
100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -20,6 +20,8 @@ from accelerate.utils import set_seed from diffusers import DDPMScheduler +from library.utils import get_my_logger +logger = get_my_logger(__name__) import library.train_util as train_util import library.config_util as config_util from library.config_util import ( @@ -51,11 +53,11 @@ def train(args): if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True)) if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) @@ -88,7 +90,7 @@ def train(args): train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: - print( + logger.error( "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" ) return @@ -99,7 +101,7 @@ def train(args): ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする @@ -460,7 +462,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): train_util.save_sd_model_on_train_end( args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae ) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: diff --git a/finetune/blip/blip.py b/finetune/blip/blip.py index 7851fb08b..606b8cc9d 100644 --- a/finetune/blip/blip.py +++ b/finetune/blip/blip.py @@ -21,6 +21,8 @@ import os from urllib.parse import urlparse from timm.models.hub import download_cached_file +from library.utils import get_my_logger +logger = get_my_logger(__name__) class BLIP_Base(nn.Module): def __init__(self, @@ -235,6 +237,6 @@ def load_checkpoint(model,url_or_filename): del state_dict[key] msg = model.load_state_dict(state_dict,strict=False) - print('load checkpoint from %s'%url_or_filename) + logger.info('load checkpoint from %s'%url_or_filename) return model,msg diff --git a/finetune/clean_captions_and_tags.py b/finetune/clean_captions_and_tags.py index 68839eccc..fa441ba55 100644 --- a/finetune/clean_captions_and_tags.py +++ b/finetune/clean_captions_and_tags.py @@ -8,6 +8,8 @@ import re from tqdm import tqdm +from library.utils import get_my_logger +logger = get_my_logger(__name__) PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ') PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ') @@ -36,13 +38,13 @@ def clean_tags(image_key, tags): tokens = tags.split(", rating") if len(tokens) == 1: # WD14 taggerのときはこちらになるのでメッセージは出さない - # print("no rating:") - # print(f"{image_key} {tags}") + # logger.info("no rating:") + # logger.info(f"{image_key} {tags}") pass else: if len(tokens) > 2: - print("multiple ratings:") - print(f"{image_key} {tags}") + logger.info("multiple ratings:") + logger.info(f"{image_key} {tags}") tags = tokens[0] tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策 @@ -124,43 +126,43 @@ def 
clean_caption(caption): def main(args): if os.path.exists(args.in_json): - print(f"loading existing metadata: {args.in_json}") + logger.info(f"loading existing metadata: {args.in_json}") with open(args.in_json, "rt", encoding='utf-8') as f: metadata = json.load(f) else: - print("no metadata / メタデータファイルがありません") + logger.error("no metadata / メタデータファイルがありません") return - print("cleaning captions and tags.") + logger.info("cleaning captions and tags.") image_keys = list(metadata.keys()) for image_key in tqdm(image_keys): tags = metadata[image_key].get('tags') if tags is None: - print(f"image does not have tags / メタデータにタグがありません: {image_key}") + logger.error(f"image does not have tags / メタデータにタグがありません: {image_key}") else: org = tags tags = clean_tags(image_key, tags) metadata[image_key]['tags'] = tags if args.debug and org != tags: - print("FROM: " + org) - print("TO: " + tags) + logger.info("FROM: " + org) + logger.info("TO: " + tags) caption = metadata[image_key].get('caption') if caption is None: - print(f"image does not have caption / メタデータにキャプションがありません: {image_key}") + logger.error(f"image does not have caption / メタデータにキャプションがありません: {image_key}") else: org = caption caption = clean_caption(caption) metadata[image_key]['caption'] = caption if args.debug and org != caption: - print("FROM: " + org) - print("TO: " + caption) + logger.info("FROM: " + org) + logger.info("TO: " + caption) # metadataを書き出して終わり - print(f"writing metadata: {args.out_json}") + logger.info(f"writing metadata: {args.out_json}") with open(args.out_json, "wt", encoding='utf-8') as f: json.dump(metadata, f, indent=2) - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: @@ -178,10 +180,10 @@ def setup_parser() -> argparse.ArgumentParser: args, unknown = parser.parse_known_args() if len(unknown) == 1: - print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.") - print("All captions and tags in the metadata are processed.") - print("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。") - print("メタデータ内のすべてのキャプションとタグが処理されます。") + logger.warning("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. 
Please specify two arguments: in_json and out_json.") + logger.warning("All captions and tags in the metadata are processed.") + logger.warning("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。") + logger.warning("メタデータ内のすべてのキャプションとタグが処理されます。") args.in_json = args.out_json args.out_json = unknown[0] elif len(unknown) > 0: diff --git a/finetune/make_captions.py b/finetune/make_captions.py index b20c41068..e18ee4b80 100644 --- a/finetune/make_captions.py +++ b/finetune/make_captions.py @@ -15,6 +15,8 @@ sys.path.append(os.path.dirname(__file__)) from blip.blip import blip_decoder import library.train_util as train_util +from library.utils import get_my_logger +logger = get_my_logger(__name__) DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -47,7 +49,7 @@ def __getitem__(self, idx): # convert to tensor temporarily so dataloader will accept it tensor = IMAGE_TRANSFORM(image) except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") return None return (tensor, img_path) @@ -74,19 +76,19 @@ def main(args): args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path cwd = os.getcwd() - print("Current Working Directory is: ", cwd) + logger.info(f"Current Working Directory is: {cwd}") os.chdir("finetune") - print(f"load images from {args.train_data_dir}") + logger.info(f"load images from {args.train_data_dir}") train_data_dir_path = Path(args.train_data_dir) image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - print(f"found {len(image_paths)} images.") + logger.info(f"found {len(image_paths)} images.") - print(f"loading BLIP caption: {args.caption_weights}") + logger.info(f"loading BLIP caption: {args.caption_weights}") model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json") model.eval() model = model.to(DEVICE) - print("BLIP loaded") + logger.info("BLIP loaded") # captioningする def run_batch(path_imgs): @@ -106,7 +108,7 @@ def run_batch(path_imgs): with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f: f.write(caption + "\n") if args.debug: - print(image_path, caption) + logger.info(f'{image_path} {caption}') # 読み込みの高速化のためにDataLoaderを使うオプション if args.max_data_loader_n_workers is not None: @@ -136,7 +138,7 @@ def run_batch(path_imgs): raw_image = raw_image.convert("RGB") img_tensor = IMAGE_TRANSFORM(raw_image) except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") continue b_imgs.append((image_path, img_tensor)) @@ -146,7 +148,7 @@ def run_batch(path_imgs): if len(b_imgs) > 0: run_batch(b_imgs) - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/finetune/make_captions_by_git.py b/finetune/make_captions_by_git.py index b3c5cc423..777c08dc8 100644 --- a/finetune/make_captions_by_git.py +++ b/finetune/make_captions_by_git.py @@ -10,7 +10,8 @@ from transformers.generation.utils import GenerationMixin import library.train_util as train_util - +from library.utils import get_my_logger +logger = get_my_logger(__name__) DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -35,8 +36,8 @@ def remove_words(captions, debug): for pat in PATTERN_REPLACE: cap = 
pat.sub("", cap) if debug and cap != caption: - print(caption) - print(cap) + logger.info(caption) + logger.info(cap) removed_caps.append(cap) return removed_caps @@ -70,16 +71,16 @@ def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs) GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch """ - print(f"load images from {args.train_data_dir}") + logger.info(f"load images from {args.train_data_dir}") train_data_dir_path = Path(args.train_data_dir) image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - print(f"found {len(image_paths)} images.") + logger.info(f"found {len(image_paths)} images.") # できればcacheに依存せず明示的にダウンロードしたい - print(f"loading GIT: {args.model_id}") + logger.info(f"loading GIT: {args.model_id}") git_processor = AutoProcessor.from_pretrained(args.model_id) git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE) - print("GIT loaded") + logger.info("GIT loaded") # captioningする def run_batch(path_imgs): @@ -97,7 +98,7 @@ def run_batch(path_imgs): with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f: f.write(caption + "\n") if args.debug: - print(image_path, caption) + logger.info(f"{image_path} {caption}") # 読み込みの高速化のためにDataLoaderを使うオプション if args.max_data_loader_n_workers is not None: @@ -126,7 +127,7 @@ def run_batch(path_imgs): if image.mode != "RGB": image = image.convert("RGB") except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") continue b_imgs.append((image_path, image)) @@ -137,7 +138,7 @@ def run_batch(path_imgs): if len(b_imgs) > 0: run_batch(b_imgs) - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/finetune/merge_captions_to_metadata.py b/finetune/merge_captions_to_metadata.py index 241f6f902..07c929a9b 100644 --- a/finetune/merge_captions_to_metadata.py +++ b/finetune/merge_captions_to_metadata.py @@ -5,26 +5,28 @@ from tqdm import tqdm import library.train_util as train_util import os +from library.utils import get_my_logger +logger = get_my_logger(__name__) def main(args): assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" train_data_dir_path = Path(args.train_data_dir) image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - print(f"found {len(image_paths)} images.") + logger.info(f"found {len(image_paths)} images.") if args.in_json is None and Path(args.out_json).is_file(): args.in_json = args.out_json if args.in_json is not None: - print(f"loading existing metadata: {args.in_json}") + logger.info(f"loading existing metadata: {args.in_json}") metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8')) - print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます") + logger.info("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます") else: - print("new metadata will be created / 新しいメタデータファイルが作成されます") + logger.info("new metadata will be created / 新しいメタデータファイルが作成されます") metadata = {} - print("merge caption texts to metadata json.") + logger.info("merge caption texts to metadata json.") for image_path in tqdm(image_paths): caption_path = image_path.with_suffix(args.caption_extension) caption = caption_path.read_text(encoding='utf-8').strip() @@ -38,12 +40,12 @@ def 
main(args): metadata[image_key]['caption'] = caption if args.debug: - print(image_key, caption) + logger.info(f"{image_key} {caption}") # metadataを書き出して終わり - print(f"writing metadata: {args.out_json}") + logger.info(f"writing metadata: {args.out_json}") Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8') - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/finetune/merge_dd_tags_to_metadata.py b/finetune/merge_dd_tags_to_metadata.py index db1bff6da..df9bbecea 100644 --- a/finetune/merge_dd_tags_to_metadata.py +++ b/finetune/merge_dd_tags_to_metadata.py @@ -5,26 +5,28 @@ from tqdm import tqdm import library.train_util as train_util import os +from library.utils import get_my_logger +logger = get_my_logger(__name__) def main(args): assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" train_data_dir_path = Path(args.train_data_dir) image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - print(f"found {len(image_paths)} images.") + logger.info(f"found {len(image_paths)} images.") if args.in_json is None and Path(args.out_json).is_file(): args.in_json = args.out_json if args.in_json is not None: - print(f"loading existing metadata: {args.in_json}") + logger.info(f"loading existing metadata: {args.in_json}") metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8')) - print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます") + logger.warning("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます") else: - print("new metadata will be created / 新しいメタデータファイルが作成されます") + logger.info("new metadata will be created / 新しいメタデータファイルが作成されます") metadata = {} - print("merge tags to metadata json.") + logger.info("merge tags to metadata json.") for image_path in tqdm(image_paths): tags_path = image_path.with_suffix(args.caption_extension) tags = tags_path.read_text(encoding='utf-8').strip() @@ -38,13 +40,13 @@ def main(args): metadata[image_key]['tags'] = tags if args.debug: - print(image_key, tags) + logger.info(f"{image_key} {tags}") # metadataを書き出して終わり - print(f"writing metadata: {args.out_json}") + logger.info(f"writing metadata: {args.out_json}") Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8') - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index 1bccb1d3b..8cc1449e7 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -13,6 +13,8 @@ import library.model_util as model_util import library.train_util as train_util +from library.utils import get_my_logger +logger = get_my_logger(__name__) DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -51,22 +53,22 @@ def get_npz_filename(data_dir, image_key, is_full_path, recursive): def main(args): # assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります" if args.bucket_reso_steps % 8 > 0: - print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります") + logger.warning(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります") if args.bucket_reso_steps % 32 > 0: - print( + logger.warning( f"WARNING: bucket_reso_steps is not divisible by 32. 
It is not working with SDXL / bucket_reso_stepsが32で割り切れません。SDXLでは動作しません" ) train_data_dir_path = Path(args.train_data_dir) image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)] - print(f"found {len(image_paths)} images.") + logger.info(f"found {len(image_paths)} images.") if os.path.exists(args.in_json): - print(f"loading existing metadata: {args.in_json}") + logger.info(f"loading existing metadata: {args.in_json}") with open(args.in_json, "rt", encoding="utf-8") as f: metadata = json.load(f) else: - print(f"no metadata / メタデータファイルがありません: {args.in_json}") + logger.error(f"no metadata / メタデータファイルがありません: {args.in_json}") return weight_dtype = torch.float32 @@ -89,7 +91,7 @@ def main(args): if not args.bucket_no_upscale: bucket_manager.make_buckets() else: - print( + logger.warning( "min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます" ) @@ -130,7 +132,7 @@ def process_batch(is_last): if image.mode != "RGB": image = image.convert("RGB") except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") continue image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0] @@ -183,15 +185,15 @@ def process_batch(is_last): for i, reso in enumerate(bucket_manager.resos): count = bucket_counts.get(reso, 0) if count > 0: - print(f"bucket {i} {reso}: {count}") + logger.info(f"bucket {i} {reso}: {count}") img_ar_errors = np.array(img_ar_errors) - print(f"mean ar error: {np.mean(img_ar_errors)}") + logger.info(f"mean ar error: {np.mean(img_ar_errors)}") # metadataを書き出して終わり - print(f"writing metadata: {args.out_json}") + logger.info(f"writing metadata: {args.out_json}") with open(args.out_json, "wt", encoding="utf-8") as f: json.dump(metadata, f, indent=2) - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 965edd7e2..fd98c296c 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -11,6 +11,8 @@ from tqdm import tqdm import library.train_util as train_util +from library.utils import get_my_logger +logger = get_my_logger(__name__) # from wd14 tagger IMAGE_SIZE = 448 @@ -58,7 +60,7 @@ def __getitem__(self, idx): image = preprocess_image(image) tensor = torch.tensor(image) except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") return None return (tensor, img_path) @@ -79,7 +81,7 @@ def main(args): # depreacatedの警告が出るけどなくなったらその時 # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 if not os.path.exists(args.model_dir) or args.force_download: - print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") + logger.info(f"downloading wd14 tagger model from hf_hub. 
id: {args.repo_id}") files = FILES if args.onnx: files += FILES_ONNX @@ -95,7 +97,7 @@ def main(args): force_filename=file, ) else: - print("using existing wd14 tagger model") + logger.info("using existing wd14 tagger model") # 画像を読み込む if args.onnx: @@ -103,8 +105,8 @@ def main(args): import onnxruntime as ort onnx_path = f"{args.model_dir}/model.onnx" - print("Running wd14 tagger with onnx") - print(f"loading onnx model: {onnx_path}") + logger.info("Running wd14 tagger with onnx") + logger.info(f"loading onnx model: {onnx_path}") if not os.path.exists(onnx_path): raise Exception( @@ -121,7 +123,7 @@ def main(args): if args.batch_size != batch_size and type(batch_size) != str: # some rebatch model may use 'N' as dynamic axes - print( + logger.info( f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}" ) args.batch_size = batch_size @@ -156,7 +158,7 @@ def main(args): train_data_dir_path = Path(args.train_data_dir) image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - print(f"found {len(image_paths)} images.") + logger.info(f"found {len(image_paths)} images.") tag_freq = {} @@ -235,7 +237,7 @@ def run_batch(path_imgs): with open(caption_file, "wt", encoding="utf-8") as f: f.write(tag_text + "\n") if args.debug: - print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}") + logger.info(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}") # 読み込みの高速化のためにDataLoaderを使うオプション if args.max_data_loader_n_workers is not None: @@ -267,7 +269,7 @@ def run_batch(path_imgs): image = image.convert("RGB") image = preprocess_image(image) except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") continue b_imgs.append((image_path, image)) @@ -282,11 +284,11 @@ def run_batch(path_imgs): if args.frequency_tags: sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True) - print("\nTag frequencies:") + logger.info("Tag frequencies:") for tag, freq in sorted_tags: - print(f"{tag}: {freq}") + logger.info(f"{tag}: {freq}") - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index a596a0494..fb4d87ca4 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -109,6 +109,8 @@ from library.original_unet import FlashAttentionFunction from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI +from library.utils import get_my_logger +logger = get_my_logger(__name__) # scheduler: SCHEDULER_LINEAR_START = 0.00085 @@ -144,12 +146,12 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): if mem_eff_attn: - print("Enable memory efficient attention for U-Net") + logger.info("Enable memory efficient attention for U-Net") # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い unet.set_use_memory_efficient_attention(False, True) elif xformers: - print("Enable xformers for U-Net") + logger.info("Enable xformers for U-Net") try: import xformers.ops except ImportError: @@ -157,7 +159,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio unet.set_use_memory_efficient_attention(True, False) elif sdpa: - print("Enable SDPA for U-Net") + logger.info("Enable SDPA for U-Net") unet.set_use_memory_efficient_attention(False, 
False) unet.set_use_sdpa(True) @@ -173,7 +175,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform def replace_vae_attn_to_memory_efficient(): - print("VAE Attention.forward has been replaced to FlashAttention (not xformers)") + logger.info("VAE Attention.forward has been replaced to FlashAttention (not xformers)") flash_func = FlashAttentionFunction def forward_flash_attn(self, hidden_states, **kwargs): @@ -229,7 +231,7 @@ def forward_flash_attn_0_14(self, hidden_states, **kwargs): def replace_vae_attn_to_xformers(): - print("VAE: Attention.forward has been replaced to xformers") + logger.info("VAE: Attention.forward has been replaced to xformers") import xformers.ops def forward_xformers(self, hidden_states, **kwargs): @@ -285,7 +287,7 @@ def forward_xformers_0_14(self, hidden_states, **kwargs): def replace_vae_attn_to_sdpa(): - print("VAE: Attention.forward has been replaced to sdpa") + logger.info("VAE: Attention.forward has been replaced to sdpa") def forward_sdpa(self, hidden_states, **kwargs): residual = hidden_states @@ -689,7 +691,7 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 if not do_classifier_free_guidance and negative_scale is not None: - print(f"negative_scale is ignored if guidance scalle <= 1.0") + logger.warning("negative_scale is ignored if guidance scale <= 1.0") negative_scale = None # get unconditional embeddings for classifier free guidance @@ -771,11 +773,11 @@ def __call__( clip_text_input = prompt_tokens if clip_text_input.shape[1] > self.tokenizer.model_max_length: # TODO 75文字を超えたら警告を出す? - print("trim text input", clip_text_input.shape) + logger.info(f"trim text input {clip_text_input.shape}") clip_text_input = torch.cat( [clip_text_input[:, : self.tokenizer.model_max_length - 1], clip_text_input[:, -1].unsqueeze(1)], dim=1 ) - print("trimmed", clip_text_input.shape) + logger.info(f"trimmed {clip_text_input.shape}") for i, clip_prompt in enumerate(clip_prompts): if clip_prompt is not None: # clip_promptがあれば上書きする @@ -1704,7 +1706,7 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: if word.strip() == "BREAK": # pad until next multiple of tokenizer's max token length pad_len = pipe.tokenizer.model_max_length - (len(text_token) % pipe.tokenizer.model_max_length) - print(f"BREAK pad_len: {pad_len}") + logger.info(f"BREAK pad_len: {pad_len}") for i in range(pad_len): # v2のときEOSをつけるべきかどうかわからないぜ # if i == 0: @@ -1734,7 +1736,7 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: tokens.append(text_token) weights.append(text_weight) if truncated: - print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + logger.warning("Prompt was truncated.
Try to shorten the prompt or increase max_embeddings_multiples") return tokens, weights @@ -2046,7 +2048,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): elif len(count_range) == 2: count_range = [int(count_range[0]), int(count_range[1])] else: - print(f"invalid count range: {count_range}") + logger.warning(f"invalid count range: {count_range}") count_range = [1, 1] if count_range[0] > count_range[1]: count_range = [count_range[1], count_range[0]] @@ -2116,7 +2118,7 @@ def replacer(): # def load_clip_l14_336(dtype): -# print(f"loading CLIP: {CLIP_ID_L14_336}") +# logger.info(f"loading CLIP: {CLIP_ID_L14_336}") # text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype) # return text_encoder @@ -2163,9 +2165,9 @@ def main(args): # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" if args.v_parameterization and not args.v2: - print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") + logger.warning("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") if args.v2 and args.clip_skip is not None: - print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") # モデルを読み込む if not os.path.isfile(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う @@ -2175,10 +2177,10 @@ def main(args): use_stable_diffusion_format = os.path.isfile(args.ckpt) if use_stable_diffusion_format: - print("load StableDiffusion checkpoint") + logger.info("load StableDiffusion checkpoint") text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt) else: - print("load Diffusers pretrained models") + logger.info("load Diffusers pretrained models") loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype) text_encoder = loading_pipe.text_encoder vae = loading_pipe.vae @@ -2200,21 +2202,21 @@ def main(args): # VAEを読み込む if args.vae is not None: vae = model_util.load_vae(args.vae, dtype) - print("additional VAE loaded") + logger.info("additional VAE loaded") # # 置換するCLIPを読み込む # if args.replace_clip_l14_336: # text_encoder = load_clip_l14_336(dtype) # - print(f"large clip {CLIP_ID_L14_336} is loaded") + # logger.info(f"large clip {CLIP_ID_L14_336} is loaded") if args.clip_guidance_scale > 0.0 or args.clip_image_guidance_scale: - print("prepare clip model") + logger.info("prepare clip model") clip_model = CLIPModel.from_pretrained(CLIP_MODEL_PATH, torch_dtype=dtype) else: clip_model = None if args.vgg16_guidance_scale > 0.0: - print("prepare resnet model") + logger.info("prepare vgg16 model") vgg16_model = torchvision.models.vgg16(torchvision.models.VGG16_Weights.IMAGENET1K_V1) else: vgg16_model = None @@ -2226,7 +2228,7 @@ def main(args): replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) # tokenizerを読み込む - print("loading tokenizer") + logger.info("loading tokenizer") if use_stable_diffusion_format: tokenizer = train_util.load_tokenizer(args) @@ -2285,7 +2287,7 @@ def reset_sampler_noises(self, noises): self.sampler_noises = noises def randn(self, shape, device=None, dtype=None, layout=None, generator=None): - # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) + # logger.info(f"replacing {shape} {len(self.sampler_noises)} {self.sampler_noise_index}") if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises):
noise = self.sampler_noises[self.sampler_noise_index] if shape != noise.shape: @@ -2294,7 +2296,7 @@ def randn(self, shape, device=None, dtype=None, layout=None, generator=None): noise = None if noise == None: - print(f"unexpected noise request: {self.sampler_noise_index}, {shape}") + logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}") noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) self.sampler_noise_index += 1 @@ -2325,7 +2327,7 @@ def __getattr__(self, item): # clip_sample=Trueにする if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: - print("set clip_sample to True") + logger.info("set clip_sample to True") scheduler.config.clip_sample = True # deviceを決定する @@ -2375,7 +2377,7 @@ def __getattr__(self, item): network_merge = 0 for i, network_module in enumerate(args.network_module): - print("import network module:", network_module) + logger.info(f"import network module: {network_module}") imported_module = importlib.import_module(network_module) network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] @@ -2393,7 +2395,7 @@ def __getattr__(self, item): raise ValueError("No weight. Weight is required.") network_weight = args.network_weights[i] - print("load network weights from:", network_weight) + logger.info(f"load network weights from: {network_weight}") if model_util.is_safetensors(network_weight) and args.network_show_meta: from safetensors.torch import safe_open with safe_open(network_weight, framework="pt") as f: metadata = f.metadata() if metadata is not None: - print(f"metadata for: {network_weight}: {metadata}") + logger.info(f"metadata for: {network_weight}: {metadata}") network, weights_sd = imported_module.create_network_from_weights( network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs ) @@ -2411,20 +2413,20 @@ def __getattr__(self, item): mergeable = network.is_mergeable() if network_merge and not mergeable: - print("network is not mergiable. ignore merge option.") + logger.warning("network is not mergeable.
ignore merge option.") if not mergeable or i >= network_merge: # not merging network.apply_to(text_encoder, unet) info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい - print(f"weights are loaded: {info}") + logger.info(f"weights are loaded: {info}") if args.opt_channels_last: network.to(memory_format=torch.channels_last) network.to(dtype).to(device) if network_pre_calc: - print("backup original weights") + logger.info("backup original weights") network.backup_weights() networks.append(network) @@ -2438,7 +2440,7 @@ def __getattr__(self, item): # upscalerの指定があれば取得する upscaler = None if args.highres_fix_upscaler: - print("import upscaler module:", args.highres_fix_upscaler) + logger.info(f"import upscaler module {args.highres_fix_upscaler}") imported_module = importlib.import_module(args.highres_fix_upscaler) us_kwargs = {} @@ -2447,7 +2449,7 @@ def __getattr__(self, item): key, value = net_arg.split("=") us_kwargs[key] = value - print("create upscaler") + logger.info("create upscaler") upscaler = imported_module.create_upscaler(**us_kwargs) upscaler.to(dtype).to(device) @@ -2464,7 +2466,7 @@ def __getattr__(self, item): control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) if args.opt_channels_last: - print(f"set optimizing: channels last") + logger.info(f"set optimizing: channels last") text_encoder.to(memory_format=torch.channels_last) vae.to(memory_format=torch.channels_last) unet.to(memory_format=torch.channels_last) @@ -2496,7 +2498,7 @@ def __getattr__(self, item): args.vgg16_guidance_layer, ) pipe.set_control_nets(control_nets) - print("pipeline is ready.") + logger.info("pipeline is ready.") if args.diffusers_xformers: pipe.enable_xformers_memory_efficient_attention() @@ -2535,7 +2537,7 @@ def __getattr__(self, item): ), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}" token_ids = tokenizer.convert_tokens_to_ids(token_strings) - print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}") + logger.info(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}") assert ( min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1 ), f"token ids is not ordered" @@ -2594,7 +2596,7 @@ def __getattr__(self, item): ), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}" token_ids = tokenizer.convert_tokens_to_ids(token_strings) - print(f"XTI embeddings `{token_string}` loaded. Tokens are added: {token_ids}") + logger.info(f"XTI embeddings `{token_string}` loaded. 
Tokens are added: {token_ids}") # if num_vectors_per_token > 1: pipe.add_token_replacement(token_ids[0], token_ids) @@ -2619,7 +2621,7 @@ def __getattr__(self, item): # promptを取得する if args.from_file is not None: - print(f"reading prompts from {args.from_file}") + logger.info(f"reading prompts from {args.from_file}") with open(args.from_file, "r", encoding="utf-8") as f: prompt_list = f.read().splitlines() prompt_list = [d for d in prompt_list if len(d.strip()) > 0] @@ -2648,7 +2650,7 @@ def load_images(path): for p in paths: image = Image.open(p) if image.mode != "RGB": - print(f"convert image to RGB from {image.mode}: {p}") + logger.info(f"convert image to RGB from {image.mode}: {p}") image = image.convert("RGB") images.append(image) @@ -2664,24 +2666,24 @@ def resize_images(imgs, size): return resized if args.image_path is not None: - print(f"load image for img2img: {args.image_path}") + logger.info(f"load image for img2img: {args.image_path}") init_images = load_images(args.image_path) assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" - print(f"loaded {len(init_images)} images for img2img") + logger.info(f"loaded {len(init_images)} images for img2img") else: init_images = None if args.mask_path is not None: - print(f"load mask for inpainting: {args.mask_path}") + logger.info(f"load mask for inpainting: {args.mask_path}") mask_images = load_images(args.mask_path) assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" - print(f"loaded {len(mask_images)} mask images for inpainting") + logger.info(f"loaded {len(mask_images)} mask images for inpainting") else: mask_images = None # promptがないとき、画像のPngInfoから取得する if init_images is not None and len(prompt_list) == 0 and not args.interactive: - print("get prompts from images' meta data") + logger.info("get prompts from images' meta data") for img in init_images: if "prompt" in img.text: prompt = img.text["prompt"] @@ -2710,17 +2712,17 @@ def resize_images(imgs, size): h = int(h * args.highres_fix_scale + 0.5) if init_images is not None: - print(f"resize img2img source images to {w}*{h}") + logger.info(f"resize img2img source images to {w}*{h}") init_images = resize_images(init_images, (w, h)) if mask_images is not None: - print(f"resize img2img mask images to {w}*{h}") + logger.info(f"resize img2img mask images to {w}*{h}") mask_images = resize_images(mask_images, (w, h)) regional_network = False if networks and mask_images: # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応 regional_network = True - print("use mask as region") + logger.info("use mask as region") size = None for i, network in enumerate(networks): @@ -2745,14 +2747,14 @@ def resize_images(imgs, size): prev_image = None # for VGG16 guided if args.guide_image_path is not None: - print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}") + logger.info(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}") guide_images = [] for p in args.guide_image_path: guide_images.extend(load_images(p)) - print(f"loaded {len(guide_images)} guide images for guidance") + logger.info(f"loaded {len(guide_images)} guide images for guidance") if len(guide_images) == 0: - print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}") + logger.info(f"No guide image, use previous generated image. 
/ ガイド画像がありません。直前に生成した画像を使います: {args.image_path}") guide_images = None else: guide_images = None @@ -2778,7 +2780,7 @@ def resize_images(imgs, size): max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples for gen_iter in range(args.n_iter): - print(f"iteration {gen_iter+1}/{args.n_iter}") + logger.info(f"iteration {gen_iter+1}/{args.n_iter}") iter_seed = random.randint(0, 0x7FFFFFFF) # shuffle prompt list @@ -2794,7 +2796,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling - print("process 1st stage") + logger.info("process 1st stage") batch_1st = [] for _, base, ext in batch: width_1st = int(ext.width * args.highres_fix_scale + 0.5) @@ -2820,7 +2822,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): images_1st = process_batch(batch_1st, True, True) # 2nd stageのバッチを作成して以下処理する - print("process 2nd stage") + logger.info("process 2nd stage") width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height if upscaler: @@ -2971,7 +2973,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): n.restore_weights() for n in networks: n.pre_calculation() - print("pre-calculation... done") + logger.info("pre-calculation... done") images = pipe( prompts, @@ -3038,7 +3040,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): cv2.waitKey() cv2.destroyAllWindows() except ImportError: - print("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません") + logger.info("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません") return images @@ -3051,7 +3053,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): # interactive valid = False while not valid: - print("\nType prompt:") + logger.info("\nType prompt:") try: raw_prompt = input() except EOFError: @@ -3087,38 +3089,38 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): prompt_args = raw_prompt.strip().split(" --") prompt = prompt_args[0] - print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") for parg in prompt_args[1:]: try: m = re.match(r"w (\d+)", parg, re.IGNORECASE) if m: width = int(m.group(1)) - print(f"width: {width}") + logger.info(f"width: {width}") continue m = re.match(r"h (\d+)", parg, re.IGNORECASE) if m: height = int(m.group(1)) - print(f"height: {height}") + logger.info(f"height: {height}") continue m = re.match(r"s (\d+)", parg, re.IGNORECASE) if m: # steps steps = max(1, min(1000, int(m.group(1)))) - print(f"steps: {steps}") + logger.info(f"steps: {steps}") continue m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) if m: # seed seeds = [int(d) for d in m.group(1).split(",")] - print(f"seeds: {seeds}") + logger.info(f"seeds: {seeds}") continue m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) if m: # scale scale = float(m.group(1)) - print(f"scale: {scale}") + logger.info(f"scale: {scale}") continue m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) @@ -3127,25 +3129,25 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): negative_scale = None else: negative_scale = float(m.group(1)) - print(f"negative scale: {negative_scale}") + logger.info(f"negative scale: {negative_scale}") continue m = re.match(r"t ([\d\.]+)", parg, 
re.IGNORECASE) if m: # strength strength = float(m.group(1)) - print(f"strength: {strength}") + logger.info(f"strength: {strength}") continue m = re.match(r"n (.+)", parg, re.IGNORECASE) if m: # negative prompt negative_prompt = m.group(1) - print(f"negative prompt: {negative_prompt}") + logger.info(f"negative prompt: {negative_prompt}") continue m = re.match(r"c (.+)", parg, re.IGNORECASE) if m: # clip prompt clip_prompt = m.group(1) - print(f"clip prompt: {clip_prompt}") + logger.info(f"clip prompt: {clip_prompt}") continue m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) if m: network_muls = [float(v) for v in m.group(1).split(",")] while len(network_muls) < len(networks): network_muls.append(network_muls[-1]) - print(f"network mul: {network_muls}") + logger.info(f"network mul: {network_muls}") continue except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) + logger.error(f"Exception in parsing / 解析エラー: {parg}") + logger.error(ex) # prepare seed if seeds is not None: # given in prompt @@ -3170,7 +3172,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): if len(predefined_seeds) > 0: seed = predefined_seeds.pop(0) else: - print("predefined seeds are exhausted") + logger.info("predefined seeds are exhausted") seed = None elif args.iter_same_seed: seed = iter_seed @@ -3180,7 +3182,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): if seed is None: seed = random.randint(0, 0x7FFFFFFF) if args.interactive: - print(f"seed: {seed}") + logger.info(f"seed: {seed}") # prepare init image, guide image and mask init_image = mask_image = guide_image = None @@ -3196,7 +3198,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): width = width - width % 32 height = height - height % 32 if width != init_image.size[0] or height != init_image.size[1]: - print( + logger.warning( f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" ) @@ -3212,9 +3214,9 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): guide_image = guide_images[global_step % len(guide_images)] elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0: if prev_image is None: - print("Generate 1st image without guide image.") + logger.info("Generate 1st image without guide image.") else: - print("Use previous image as guide image.") + logger.info("Use previous image as guide image.") guide_image = prev_image if regional_network: @@ -3256,7 +3258,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): process_batch(batch_data, highres_fix) batch_data.clear() - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/library/config_util.py b/library/config_util.py index e8e0fda7c..382906666 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -39,7 +39,8 @@ ControlNetDataset, DatasetGroup, ) - +from .utils import get_my_logger +logger = get_my_logger(__name__) def add_config_arguments(parser: argparse.ArgumentParser): parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル") @@ -318,7 +319,7 @@ def sanitize_user_config(self, user_config: dict) -> dict: return self.user_config_validator(user_config) except MultipleInvalid: # TODO: エラー発生時のメッセージをわかりやすくする - print("Invalid user config /
ユーザ設定の形式が正しくないようです") + logger.error("Invalid user config / ユーザ設定の形式が正しくないようです") raise # NOTE: In nature, argument parser result is not needed to be sanitize @@ -328,7 +329,7 @@ def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> return self.argparse_config_validator(argparse_namespace) except MultipleInvalid: # XXX: this should be a bug - print("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。") + logger.error("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。") raise # NOTE: value would be overwritten by latter dict if there is already the same key @@ -430,7 +431,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) datasets.append(dataset) - # print info + # log info info = "" for i, dataset in enumerate(datasets): is_dreambooth = isinstance(dataset, DreamBoothDataset) @@ -484,13 +485,13 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu metadata_file: {subset.metadata_file} \n"""), " ") - print(info) + logger.info(info) # make buckets first because it determines the length of dataset # and set the same seed for all datasets seed = random.randint(0, 2**31) # actual seed is seed + epoch_no for i, dataset in enumerate(datasets): - print(f"[Dataset {i}]") + logger.info(f"[Dataset {i}]") dataset.make_buckets() dataset.set_seed(seed) @@ -503,7 +504,7 @@ def extract_dreambooth_params(name: str) -> Tuple[int, str]: try: n_repeats = int(tokens[0]) except ValueError as e: - print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}") + logger.warning(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}") return 0, "" caption_by_folder = '_'.join(tokens[1:]) return n_repeats, caption_by_folder @@ -568,13 +569,13 @@ def load_user_config(file: str) -> dict: with open(file, 'r') as f: config = json.load(f) except Exception: - print(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") + logger.error(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") raise elif file.name.lower().endswith('.toml'): try: config = toml.load(file) except Exception: - print(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") + logger.error(f"Error on parsing TOML config file. Please check the format.
/ TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") raise else: raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}") @@ -597,21 +598,21 @@ def load_user_config(file: str) -> dict: argparse_namespace = parser.parse_args(remain) train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning) - print("[argparse_namespace]") - print(vars(argparse_namespace)) + logger.info("[argparse_namespace]") + logger.info(f'{vars(argparse_namespace)}') user_config = load_user_config(config_args.dataset_config) - print("\n[user_config]") - print(user_config) + logger.info("\n[user_config]") + logger.info(f'{user_config}') sanitizer = ConfigSanitizer(config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout) sanitized_user_config = sanitizer.sanitize_user_config(user_config) - print("\n[sanitized_user_config]") - print(sanitized_user_config) + logger.info("\n[sanitized_user_config]") + logger.info(f'{sanitized_user_config}') blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) - print("\n[blueprint]") - print(blueprint) + logger.info("\n[blueprint]") + logger.info(f'{blueprint}') diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 28b625d30..043d06e7f 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -3,7 +3,8 @@ import random import re from typing import List, Optional, Union - +from .utils import get_my_logger +logger = get_my_logger(__name__) def prepare_scheduler_for_custom_training(noise_scheduler, device): if hasattr(noise_scheduler, "all_snr"): @@ -21,7 +22,7 @@ def prepare_scheduler_for_custom_training(noise_scheduler, device): def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler): # fix beta: zero terminal SNR - print(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891") + logger.info(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891") def enforce_zero_terminal_snr(betas): # Convert betas to alphas_bar_sqrt @@ -49,8 +50,8 @@ def enforce_zero_terminal_snr(betas): alphas = 1.0 - betas alphas_cumprod = torch.cumprod(alphas, dim=0) - # print("original:", noise_scheduler.betas) - # print("fixed:", betas) + # logger.info(f"original: {noise_scheduler.betas}") + # logger.info(f"fixed: {betas}") noise_scheduler.betas = betas noise_scheduler.alphas = alphas @@ -76,13 +77,13 @@ def get_snr_scale(timesteps, noise_scheduler): snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 scale = snr_t / (snr_t + 1) # # show debug info - # print(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}") + # logger.info(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}") return scale def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss): scale = get_snr_scale(timesteps, noise_scheduler) - # print(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}") + # logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}") loss = loss + loss / scale * v_pred_like_loss return loss @@ -265,7 +266,7 @@ def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int): tokens.append(text_token) weights.append(text_weight) if truncated: - print("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + logger.warning("Prompt was truncated. 
Try to shorten the prompt or increase max_embeddings_multiples") return tokens, weights diff --git a/library/huggingface_util.py b/library/huggingface_util.py index 376fdb1e6..dd3d03bbb 100644 --- a/library/huggingface_util.py +++ b/library/huggingface_util.py @@ -4,7 +4,8 @@ import argparse import os from library.utils import fire_in_thread - +from library.utils import get_my_logger +logger = get_my_logger(__name__) def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None): api = HfApi( @@ -33,9 +34,9 @@ def upload( try: api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private) except Exception as e: # とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので - print("===========================================") - print(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}") - print("===========================================") + logger.error("===========================================") + logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}") + logger.error("===========================================") is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir()) @@ -56,9 +57,9 @@ def uploader(): path_in_repo=path_in_repo, ) except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので - print("===========================================") - print(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}") - print("===========================================") + logger.error("===========================================") + logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}") + logger.error("===========================================") if args.async_upload and not force_sync_upload: fire_in_thread(uploader) diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py index 77ed5419a..ba8af432a 100644 --- a/library/ipex/hijacks.py +++ b/library/ipex/hijacks.py @@ -2,6 +2,8 @@ import importlib import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +from library.utils import get_my_logger +logger = get_my_logger(__name__) # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return @@ -65,7 +67,7 @@ def _shutdown_workers(self): class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument if isinstance(device_ids, list) and len(device_ids) > 1: - print("IPEX backend doesn't support DataParallel on multiple XPU devices") + logger.error("IPEX backend doesn't support DataParallel on multiple XPU devices") return module.to("xpu") def return_null_context(*args, **kwargs): # pylint: disable=unused-argument diff --git a/library/lpw_stable_diffusion.py b/library/lpw_stable_diffusion.py index 9dce91a76..9fb63985d 100644 --- a/library/lpw_stable_diffusion.py +++ b/library/lpw_stable_diffusion.py @@ -17,7 +17,6 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker from diffusers.utils import logging - try: from diffusers.utils import PIL_INTERPOLATION except ImportError: @@ -646,7 +645,7 @@ def check_inputs(self, prompt, height, width, strength, callback_steps): raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") if height % 8 != 0 or width % 8 != 0: - print(height, width) + logger.info(f'{height} 
{width}') raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( diff --git a/library/model_util.py b/library/model_util.py index 00a3c0495..1d1e6c3de 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -16,6 +16,8 @@ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel from safetensors.torch import load_file, save_file from library.original_unet import UNet2DConditionModel +from library.utils import get_my_logger +logger = get_my_logger(__name__) # DiffUsers版StableDiffusionのモデルパラメータ NUM_TRAIN_TIMESTEPS = 1000 @@ -947,7 +949,7 @@ def convert_vae_state_dict(vae_state_dict): for k, v in new_state_dict.items(): for weight_name in weights_to_convert: if f"mid.attn_1.{weight_name}.weight" in k: - # print(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1") + # logger.info(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1") new_state_dict[k] = reshape_weight_for_sd(v) return new_state_dict @@ -1005,7 +1007,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt unet = UNet2DConditionModel(**unet_config).to(device) info = unet.load_state_dict(converted_unet_checkpoint) - print("loading u-net:", info) + logger.info(f"loading u-net: {info}") # Convert the VAE model. vae_config = create_vae_diffusers_config() @@ -1013,7 +1015,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt vae = AutoencoderKL(**vae_config).to(device) info = vae.load_state_dict(converted_vae_checkpoint) - print("loading vae:", info) + logger.info(f"loading vae: {info}") # convert text_model if v2: @@ -1047,7 +1049,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt # logging.set_verbosity_error() # don't show annoying warning # text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device) # logging.set_verbosity_warning() - # print(f"config: {text_model.config}") + # logger.info(f"config: {text_model.config}") cfg = CLIPTextConfig( vocab_size=49408, hidden_size=768, @@ -1070,7 +1072,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt ) text_model = CLIPTextModel._from_config(cfg) info = text_model.load_state_dict(converted_text_encoder_checkpoint) - print("loading text encoder:", info) + logger.info(f"loading text encoder: {info}") return text_model, vae, unet @@ -1145,7 +1147,7 @@ def convert_key(key): # 最後の層などを捏造するか if make_dummy_weights: - print("make dummy weights for resblock.23, text_projection and logit scale.") + logger.info("make dummy weights for resblock.23, text_projection and logit scale.") keys = list(new_sd.keys()) for key in keys: if key.startswith("transformer.resblocks.22."): @@ -1259,14 +1261,14 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod def load_vae(vae_id, dtype): - print(f"load VAE: {vae_id}") + logger.info(f"load VAE: {vae_id}") if os.path.isdir(vae_id) or not os.path.isfile(vae_id): # Diffusers local/remote try: vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype) except EnvironmentError as e: - print(f"exception occurs in loading vae: {e}") - print("retry with subfolder='vae'") + logger.error(f"exception occurs in loading vae: {e}") + logger.error("retry with subfolder='vae'") vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype) return vae @@ -1338,13 +1340,13 @@ def 
make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64) if __name__ == "__main__": resos = make_bucket_resolutions((512, 768)) - print(len(resos)) - print(resos) + logger.info(f"{len(resos)}") + logger.info(f"{resos}") aspect_ratios = [w / h for w, h in resos] - print(aspect_ratios) + logger.info(f"{aspect_ratios}") ars = set() for ar in aspect_ratios: if ar in ars: - print("error! duplicate ar:", ar) + logger.error(f"error! duplicate ar: {ar}") ars.add(ar) diff --git a/library/original_unet.py b/library/original_unet.py index 240b85951..363c42f1d 100644 --- a/library/original_unet.py +++ b/library/original_unet.py @@ -113,6 +113,8 @@ from torch import nn from torch.nn import functional as F from einops import rearrange +from library.utils import get_my_logger +logger = get_my_logger(__name__) BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280) TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0] @@ -1322,7 +1324,7 @@ def __init__( ): super().__init__() assert sample_size is not None, "sample_size must be specified" - print( + logger.info( f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, {use_linear_projection}, {upcast_attention}" ) @@ -1456,7 +1458,7 @@ def set_use_sdpa(self, sdpa: bool) -> None: def set_gradient_checkpointing(self, value=False): modules = self.down_blocks + [self.mid_block] + self.up_blocks for module in modules: - print(module.__class__.__name__, module.gradient_checkpointing, "->", value) + logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}") module.gradient_checkpointing = value # endregion diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index 472686ba4..d53a7959c 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -5,6 +5,8 @@ import os from typing import List, Optional, Tuple, Union import safetensors +from library.utils import get_my_logger +logger = get_my_logger(__name__) r""" # Metadata Example @@ -231,7 +233,7 @@ def build_metadata( # # assert all values are filled # assert all([v is not None for v in metadata.values()]), metadata if not all([v is not None for v in metadata.values()]): - print(f"Internal error: some metadata values are None: {metadata}") + logger.error(f"Internal error: some metadata values are None: {metadata}") return metadata diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index 2f0154cae..209c940db 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -7,7 +7,8 @@ from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel from library import model_util from library import sdxl_original_unet - +from .utils import get_my_logger +logger = get_my_logger(__name__) VAE_SCALE_FACTOR = 0.13025 MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0" @@ -184,20 +185,20 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty checkpoint = None # U-Net - print("building U-Net") + logger.info("building U-Net") with init_empty_weights(): unet = sdxl_original_unet.SdxlUNet2DConditionModel() - print("loading U-Net from checkpoint") + logger.info("loading U-Net from checkpoint") unet_sd = {} for k in list(state_dict.keys()): if k.startswith("model.diffusion_model."): unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k) info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype) - print("U-Net: ", info) + logger.info(f"U-Net: {info}") # Text Encoders - print("building text encoders") + logger.info("building text 
encoders") # Text Encoder 1 is same to Stability AI's SDXL text_model1_cfg = CLIPTextConfig( @@ -250,7 +251,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty with init_empty_weights(): text_model2 = CLIPTextModelWithProjection(text_model2_cfg) - print("loading text encoders from checkpoint") + logger.info("loading text encoders from checkpoint") te1_sd = {} te2_sd = {} for k in list(state_dict.keys()): @@ -264,22 +265,22 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32 - print("text encoder 1:", info1) + logger.info(f"text encoder 1: {info1}") converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77) info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location) # remain fp32 - print("text encoder 2:", info2) + logger.info(f"text encoder 2: {info2}") # prepare vae - print("building VAE") + logger.info("building VAE") vae_config = model_util.create_vae_diffusers_config() with init_empty_weights(): vae = AutoencoderKL(**vae_config) - print("loading VAE from checkpoint") + logger.info("loading VAE from checkpoint") converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config) info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype) - print("VAE:", info) + logger.info(f"VAE: {info}") ckpt_info = (epoch, global_step) if epoch is not None else None return text_model1, text_model2, vae, unet, logit_scale, ckpt_info diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py index 26a0af319..c0a4b3af3 100644 --- a/library/sdxl_original_unet.py +++ b/library/sdxl_original_unet.py @@ -30,7 +30,8 @@ from torch import nn from torch.nn import functional as F from einops import rearrange - +from .utils import get_my_logger +logger = get_my_logger(__name__) IN_CHANNELS: int = 4 OUT_CHANNELS: int = 4 @@ -315,7 +316,7 @@ def forward_body(self, x, emb): def forward(self, x, emb): if self.training and self.gradient_checkpointing: - # print("ResnetBlock2D: gradient_checkpointing") + # logger.info("ResnetBlock2D: gradient_checkpointing") def create_custom_forward(func): def custom_forward(*inputs): @@ -349,7 +350,7 @@ def forward_body(self, hidden_states): def forward(self, hidden_states): if self.training and self.gradient_checkpointing: - # print("Downsample2D: gradient_checkpointing") + # logger.info("Downsample2D: gradient_checkpointing") def create_custom_forward(func): def custom_forward(*inputs): @@ -636,7 +637,7 @@ def forward_body(self, hidden_states, context=None, timestep=None): def forward(self, hidden_states, context=None, timestep=None): if self.training and self.gradient_checkpointing: - # print("BasicTransformerBlock: checkpointing") + # logger.info("BasicTransformerBlock: checkpointing") def create_custom_forward(func): def custom_forward(*inputs): @@ -779,7 +780,7 @@ def forward_body(self, hidden_states, output_size=None): def forward(self, hidden_states, output_size=None): if self.training and self.gradient_checkpointing: - # print("Upsample2D: gradient_checkpointing") + # logger.info("Upsample2D: gradient_checkpointing") def create_custom_forward(func): def custom_forward(*inputs): @@ -1029,7 +1030,7 @@ def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> N for block in blocks: for module in 
block: if hasattr(module, "set_use_memory_efficient_attention"): - # print(module.__class__.__name__) + # logger.info(module.__class__.__name__) module.set_use_memory_efficient_attention(xformers, mem_eff) def set_use_sdpa(self, sdpa: bool) -> None: @@ -1044,7 +1045,7 @@ def set_gradient_checkpointing(self, value=False): for block in blocks: for module in block.modules(): if hasattr(module, "gradient_checkpointing"): - # print(module.__class__.__name__, module.gradient_checkpointing, "->", value) + # logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}") module.gradient_checkpointing = value # endregion @@ -1066,7 +1067,7 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs): def call_module(module, h, emb, context): x = h for layer in module: - # print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None) + # logger.info(f"{layer.__class__.__name__} {x.dtype} {emb.dtype} {context.dtype if context is not None else None}") if isinstance(layer, ResnetBlock2D): x = layer(x, emb) elif isinstance(layer, Transformer2DModel): @@ -1096,7 +1097,7 @@ def call_module(module, h, emb, context): if __name__ == "__main__": import time - print("create unet") + logger.info("create unet") unet = SdxlUNet2DConditionModel() unet.to("cuda") @@ -1105,7 +1106,7 @@ def call_module(module, h, emb, context): unet.train() # 使用メモリ量確認用の疑似学習ループ - print("preparing optimizer") + logger.info("preparing optimizer") # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working @@ -1120,12 +1121,12 @@ def call_module(module, h, emb, context): scaler = torch.cuda.amp.GradScaler(enabled=True) - print("start training") + logger.info("start training") steps = 10 batch_size = 1 for step in range(steps): - print(f"step {step}") + logger.info(f"step {step}") if step == 1: time_start = time.perf_counter() @@ -1145,4 +1146,4 @@ def call_module(module, h, emb, context): optimizer.zero_grad(set_to_none=True) time_end = time.perf_counter() - print(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps") + logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps") diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index f637d9931..aa7ad01d2 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -9,6 +9,8 @@ from transformers import CLIPTokenizer from library import model_util, sdxl_model_util, train_util, sdxl_original_unet from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline +from .utils import get_my_logger +logger = get_my_logger(__name__) TOKENIZER1_PATH = "openai/clip-vit-large-patch14" TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" @@ -21,7 +23,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16 for pi in range(accelerator.state.num_processes): if pi == accelerator.state.local_process_index: - print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") + logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") ( load_stable_diffusion_format, @@ -64,7 +66,7 @@ def _load_target_model( load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers if load_stable_diffusion_format: - print(f"load StableDiffusion checkpoint: {name_or_path}") + 
logger.info(f"load StableDiffusion checkpoint: {name_or_path}") ( text_encoder1, text_encoder2, @@ -78,7 +80,7 @@ def _load_target_model( from diffusers import StableDiffusionXLPipeline variant = "fp16" if weight_dtype == torch.float16 else None - print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}") + logger.info(f"load Diffusers pretrained models: {name_or_path}, variant={variant}") try: try: pipe = StableDiffusionXLPipeline.from_pretrained( @@ -86,12 +88,12 @@ def _load_target_model( ) except EnvironmentError as ex: if variant is not None: - print("try to load fp32 model") + logger.info("try to load fp32 model") pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None) else: raise ex except EnvironmentError as ex: - print( + logger.error( f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}" ) raise ex @@ -114,7 +116,7 @@ def _load_target_model( with init_empty_weights(): unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype) - print("U-Net converted to original U-Net") + logger.info("U-Net converted to original U-Net") logit_scale = None ckpt_info = None @@ -122,13 +124,13 @@ def _load_target_model( # VAEを読み込む if vae_path is not None: vae = model_util.load_vae(vae_path, weight_dtype) - print("additional VAE loaded") + logger.info("additional VAE loaded") return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info def load_tokenizers(args: argparse.Namespace): - print("prepare tokenizers") + logger.info("prepare tokenizers") original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH] tokeniers = [] @@ -137,14 +139,14 @@ def load_tokenizers(args: argparse.Namespace): if args.tokenizer_cache_dir: local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_")) if os.path.exists(local_tokenizer_path): - print(f"load tokenizer from cache: {local_tokenizer_path}") + logger.info(f"load tokenizer from cache: {local_tokenizer_path}") tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) if tokenizer is None: tokenizer = CLIPTokenizer.from_pretrained(original_path) if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): - print(f"save Tokenizer to cache: {local_tokenizer_path}") + logger.info(f"save Tokenizer to cache: {local_tokenizer_path}") tokenizer.save_pretrained(local_tokenizer_path) if i == 1: @@ -153,7 +155,7 @@ def load_tokenizers(args: argparse.Namespace): tokeniers.append(tokenizer) if hasattr(args, "max_token_length") and args.max_token_length is not None: - print(f"update token length: {args.max_token_length}") + logger.info(f"update token length: {args.max_token_length}") return tokeniers @@ -334,23 +336,23 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser): def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" if args.v_parameterization: - print("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります") + logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります") if args.clip_skip is not None: - print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません") + logger.warning("clip_skip will be unexpected / 
SDXL学習ではclip_skipは動作しません") # if args.multires_noise_iterations: - # print( + # logger.info( # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります" # ) # else: # if args.noise_offset is None: # args.noise_offset = DEFAULT_NOISE_OFFSET # elif args.noise_offset != DEFAULT_NOISE_OFFSET: - # print( + # logger.info( # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています" # ) - # print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") + # logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") assert ( not hasattr(args, "weighted_captions") or not args.weighted_captions @@ -359,7 +361,7 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin if supportTextEncoderCaching: if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: args.cache_text_encoder_outputs = True - print( + logger.warning( "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / " + "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました" ) diff --git a/library/slicing_vae.py b/library/slicing_vae.py index 31b2bd0a4..b8e75fa44 100644 --- a/library/slicing_vae.py +++ b/library/slicing_vae.py @@ -26,7 +26,8 @@ from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution from diffusers.models.autoencoder_kl import AutoencoderKLOutput - +from .utils import get_my_logger +logger = get_my_logger(__name__) def slice_h(x, num_slices): # slice with pad 1 both sides: to eliminate side effect of padding of conv2d @@ -89,7 +90,7 @@ def resblock_forward(_self, num_slices, input_tensor, temb): # sliced_tensor = torch.chunk(x, num_div, dim=1) # sliced_weight = torch.chunk(norm.weight, num_div, dim=0) # sliced_bias = torch.chunk(norm.bias, num_div, dim=0) - # print(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape) + # logger.info(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape) # normed_tensor = [] # for i in range(num_div): # n = torch.group_norm(sliced_tensor[i], norm.num_groups, sliced_weight[i], sliced_bias[i], norm.eps) @@ -243,7 +244,7 @@ def forward(*args, **kwargs): self.num_slices = num_slices div = num_slices / (2 ** (len(self.down_blocks) - 1)) # 深い層はそこまで分割しなくていいので適宜減らす - # print(f"initial divisor: {div}") + # logger.info(f"initial divisor: {div}") if div >= 2: div = int(div) for resnet in self.mid_block.resnets: @@ -253,11 +254,11 @@ def forward(*args, **kwargs): for i, down_block in enumerate(self.down_blocks[::-1]): if div >= 2: div = int(div) - # print(f"down block: {i} divisor: {div}") + # logger.info(f"down block: {i} divisor: {div}") for resnet in down_block.resnets: resnet.forward = wrapper(resblock_forward, resnet, div) if down_block.downsamplers is not None: - # print("has downsample") + # logger.info("has downsample") for downsample in down_block.downsamplers: downsample.forward = wrapper(self.downsample_forward, downsample, div * 2) div *= 2 @@ -307,7 +308,7 @@ def forward(self, x): def downsample_forward(self, _self, num_slices, hidden_states): assert hidden_states.shape[1] == _self.channels 
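For reference, a minimal sketch of the conversion pattern this series applies in every file it touches; get_my_logger() is the helper added to library/utils.py in this patch set, while the function name and messages below are invented purely for illustration:

from library.utils import get_my_logger

logger = get_my_logger(__name__)  # one module-level logger per converted file

def load_example_checkpoint(path: str) -> None:
    # hypothetical example: progress output that used print() now goes through logger.info()
    logger.info(f"load example checkpoint: {path}")
    if not path.endswith(".safetensors"):
        # recoverable issues use logger.warning(); hard failures use logger.error()
        logger.warning(f"checkpoint is not in safetensors format: {path}")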
assert _self.use_conv and _self.padding == 0 - print("downsample forward", num_slices, hidden_states.shape) + logger.info(f"downsample forward {num_slices} {hidden_states.shape}") org_device = hidden_states.device cpu_device = torch.device("cpu") @@ -350,7 +351,7 @@ def downsample_forward(self, _self, num_slices, hidden_states): hidden_states = torch.cat([hidden_states, x], dim=2) hidden_states = hidden_states.to(org_device) - # print("downsample forward done", hidden_states.shape) + # logger.info(f"downsample forward done {hidden_states.shape}") return hidden_states @@ -426,7 +427,7 @@ def forward(*args, **kwargs): self.num_slices = num_slices div = num_slices / (2 ** (len(self.up_blocks) - 1)) - print(f"initial divisor: {div}") + logger.info(f"initial divisor: {div}") if div >= 2: div = int(div) for resnet in self.mid_block.resnets: @@ -436,11 +437,11 @@ def forward(*args, **kwargs): for i, up_block in enumerate(self.up_blocks): if div >= 2: div = int(div) - # print(f"up block: {i} divisor: {div}") + # logger.info(f"up block: {i} divisor: {div}") for resnet in up_block.resnets: resnet.forward = wrapper(resblock_forward, resnet, div) if up_block.upsamplers is not None: - # print("has upsample") + # logger.info("has upsample") for upsample in up_block.upsamplers: upsample.forward = wrapper(self.upsample_forward, upsample, div * 2) div *= 2 @@ -528,7 +529,7 @@ def upsample_forward(self, _self, num_slices, hidden_states, output_size=None): del x hidden_states = torch.cat(sliced, dim=2) - # print("us hidden_states", hidden_states.shape) + # logger.info(f"us hidden_states {hidden_states.shape}") del sliced hidden_states = hidden_states.to(org_device) diff --git a/library/train_util.py b/library/train_util.py index 0f5033413..37eba9d23 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -64,7 +64,8 @@ import library.model_util as model_util import library.huggingface_util as huggingface_util import library.sai_model_spec as sai_model_spec - +from library.utils import get_my_logger +logger = get_my_logger(__name__) # from library.attention_processors import FlashAttnProcessor # from library.hypernetwork import replace_attentions_for_hypernetwork from library.original_unet import UNet2DConditionModel @@ -211,7 +212,7 @@ def add_if_new_reso(self, reso): self.reso_to_id[reso] = bucket_id self.resos.append(reso) self.buckets.append([]) - # print(reso, bucket_id, len(self.buckets)) + # logger.info(reso, bucket_id, len(self.buckets)) def round_to_steps(self, x): x = int(x + 0.5) @@ -237,7 +238,7 @@ def select_bucket(self, image_width, image_height): scale = reso[0] / image_width resized_size = (int(image_width * scale + 0.5), int(image_height * scale + 0.5)) - # print("use predef", image_width, image_height, reso, resized_size) + # logger.info(f"use predef, {image_width}, {image_height}, {reso}, {resized_size}") else: # 縮小のみを行う if image_width * image_height > self.max_area: @@ -256,21 +257,21 @@ def select_bucket(self, image_width, image_height): b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio) ar_height_rounded = b_width_in_hr / b_height_rounded - # print(b_width_rounded, b_height_in_wr, ar_width_rounded) - # print(b_width_in_hr, b_height_rounded, ar_height_rounded) + # logger.info(b_width_rounded, b_height_in_wr, ar_width_rounded) + # logger.info(b_width_in_hr, b_height_rounded, ar_height_rounded) if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio): resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + 0.5)) else: 
resized_size = (int(b_height_rounded * aspect_ratio + 0.5), b_height_rounded) - # print(resized_size) + # logger.info(resized_size) else: resized_size = (image_width, image_height) # リサイズは不要 # 画像のサイズ未満をbucketのサイズとする(paddingせずにcroppingする) bucket_width = resized_size[0] - resized_size[0] % self.reso_steps bucket_height = resized_size[1] - resized_size[1] % self.reso_steps - # print("use arbitrary", image_width, image_height, resized_size, bucket_width, bucket_height) + # logger.info(f"use arbitrary {image_width}, {image_height}, {resized_size}, {bucket_width}, {bucket_height}") reso = (bucket_width, bucket_height) @@ -749,15 +750,15 @@ def make_buckets(self): bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る) min_size and max_size are ignored when enable_bucket is False """ - print("loading image sizes.") + logger.info("loading image sizes.") for info in tqdm(self.image_data.values()): if info.image_size is None: info.image_size = self.get_image_size(info.absolute_path) if self.enable_bucket: - print("make buckets") + logger.info("make buckets") else: - print("prepare dataset") + logger.info("prepare dataset") # bucketを作成し、画像をbucketに振り分ける if self.enable_bucket: @@ -772,7 +773,7 @@ def make_buckets(self): if not self.bucket_no_upscale: self.bucket_manager.make_buckets() else: - print( + logger.warning( "min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます" ) @@ -783,7 +784,7 @@ def make_buckets(self): image_width, image_height ) - # print(image_info.image_key, image_info.bucket_reso) + # logger.info(image_info.image_key, image_info.bucket_reso) img_ar_errors.append(abs(ar_error)) self.bucket_manager.sort() @@ -801,17 +802,17 @@ def make_buckets(self): # bucket情報を表示、格納する if self.enable_bucket: self.bucket_info = {"buckets": {}} - print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)") + logger.info("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)") for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)): count = len(bucket) if count > 0: self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)} - print(f"bucket {i}: resolution {reso}, count: {len(bucket)}") + logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}") img_ar_errors = np.array(img_ar_errors) mean_img_ar_error = np.mean(np.abs(img_ar_errors)) self.bucket_info["mean_img_ar_error"] = mean_img_ar_error - print(f"mean ar error (without repeats): {mean_img_ar_error}") + logger.info(f"mean ar error (without repeats): {mean_img_ar_error}") # データ参照用indexを作る。このindexはdatasetのshuffleに用いられる self.buckets_indices: List(BucketBatchIndex) = [] @@ -831,7 +832,7 @@ def make_buckets(self): # num_of_image_types = len(set(bucket)) # bucket_batch_size = min(self.batch_size, num_of_image_types) # batch_count = int(math.ceil(len(bucket) / bucket_batch_size)) - # # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count) + # # logger.info(bucket_index, num_of_image_types, bucket_batch_size, batch_count) # for batch_index in range(batch_count): # self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index)) # ↑ここまで @@ -870,7 +871,7 @@ def is_text_encoder_output_cacheable(self): def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと - print("caching latents.") + 
logger.info("caching latents.") image_infos = list(self.image_data.values()) @@ -880,7 +881,7 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc # split by resolution batches = [] batch = [] - print("checking cache validity...") + logger.info("checking cache validity...") for info in tqdm(image_infos): subset = self.image_to_subset[info.image_key] @@ -917,7 +918,7 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc return # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded - print("caching latents...") + logger.info("caching latents...") for batch in tqdm(batches, smoothing=1, total=len(batches)): cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.random_crop) @@ -931,10 +932,10 @@ def cache_text_encoder_outputs( # latentsのキャッシュと同様に、ディスクへのキャッシュに対応する # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと - print("caching text encoder outputs.") + logger.info("caching text encoder outputs.") image_infos = list(self.image_data.values()) - print("checking cache existence...") + logger.info("checking cache existence...") image_infos_to_cache = [] for info in tqdm(image_infos): # subset = self.image_to_subset[info.image_key] @@ -975,7 +976,7 @@ def cache_text_encoder_outputs( batches.append(batch) # iterate batches: call text encoder and cache outputs for memory or disk - print("caching text encoder outputs...") + logger.info("caching text encoder outputs...") for batch in tqdm(batches): infos, input_ids1, input_ids2 = zip(*batch) input_ids1 = torch.stack(input_ids1, dim=0) @@ -1369,7 +1370,7 @@ def read_caption(img_path, caption_extension): try: lines = f.readlines() except UnicodeDecodeError as e: - print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") + logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") raise e assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" caption = lines[0].strip() @@ -1378,11 +1379,11 @@ def read_caption(img_path, caption_extension): def load_dreambooth_dir(subset: DreamBoothSubset): if not os.path.isdir(subset.image_dir): - print(f"not directory: {subset.image_dir}") + logger.error(f"not directory: {subset.image_dir}") return [], [] img_paths = glob_images(subset.image_dir, "*") - print(f"found directory {subset.image_dir} contains {len(img_paths)} image files") + logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う captions = [] @@ -1390,7 +1391,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset): for img_path in img_paths: cap_for_img = read_caption(img_path, subset.caption_extension) if cap_for_img is None and subset.class_tokens is None: - print( + logger.warning( f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}" ) captions.append("") @@ -1409,36 +1410,36 @@ def load_dreambooth_dir(subset: DreamBoothSubset): number_of_missing_captions_to_show = 5 remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show - print( + logger.warning( f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images. If class token exists, it will be used. 
/ {number_of_missing_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。" ) for i, missing_caption in enumerate(missing_captions): if i >= number_of_missing_captions_to_show: - print(missing_caption + f"... and {remaining_missing_captions} more") + logger.warning(missing_caption + f"... and {remaining_missing_captions} more") break - print(missing_caption) + logger.warning(missing_caption) return img_paths, captions - print("prepare images.") + logger.info("prepare images.") num_train_images = 0 num_reg_images = 0 reg_infos: List[ImageInfo] = [] for subset in subsets: if subset.num_repeats < 1: - print( + logger.warning( f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}" ) continue if subset in self.subsets: - print( + logger.warning( f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します" ) continue img_paths, captions = load_dreambooth_dir(subset) if len(img_paths) < 1: - print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します") + logger.warning(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します") continue if subset.is_reg: @@ -1456,15 +1457,15 @@ def load_dreambooth_dir(subset: DreamBoothSubset): subset.img_count = len(img_paths) self.subsets.append(subset) - print(f"{num_train_images} train images with repeating.") + logger.info(f"{num_train_images} train images with repeating.") self.num_train_images = num_train_images - print(f"{num_reg_images} reg images.") + logger.info(f"{num_reg_images} reg images.") if num_train_images < num_reg_images: - print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") + logger.info("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") if num_reg_images == 0: - print("no regularization images / 正則化画像が見つかりませんでした") + logger.info("no regularization images / 正則化画像が見つかりませんでした") else: # num_repeatsを計算する:どうせ大した数ではないのでループで処理する n = 0 @@ -1508,27 +1509,27 @@ def __init__( for subset in subsets: if subset.num_repeats < 1: - print( + logger.info( f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}" ) continue if subset in self.subsets: - print( + logger.info( f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します" ) continue # メタデータを読み込む if os.path.exists(subset.metadata_file): - print(f"loading existing metadata: {subset.metadata_file}") + logger.info(f"loading existing metadata: {subset.metadata_file}") with open(subset.metadata_file, "rt", encoding="utf-8") as f: metadata = json.load(f) else: raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}") if len(metadata) < 1: - print(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します") + logger.info(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します") continue tags_list = [] @@ -1606,14 +1607,14 @@ def __init__( if not npz_any: use_npz_latents = False - print(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します") + logger.info(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します") elif not npz_all: use_npz_latents = False - print(f"some of npz file does not exist. 
ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します") + logger.info(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します") if flip_aug_in_subset: - print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません") + logger.info("maybe no flipped files / 反転されたnpzファイルがないのかもしれません") # else: - # print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません") + # logger.info("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません") # check min/max bucket size sizes = set() @@ -1629,7 +1630,7 @@ def __init__( if sizes is None: if use_npz_latents: use_npz_latents = False - print(f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します") + logger.info(f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します") assert ( resolution is not None @@ -1643,8 +1644,8 @@ def __init__( self.bucket_no_upscale = bucket_no_upscale else: if not enable_bucket: - print("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします") - print("using bucket info in metadata / メタデータ内のbucket情報を使います") + logger.info("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします") + logger.info("using bucket info in metadata / メタデータ内のbucket情報を使います") self.enable_bucket = True assert ( @@ -1763,7 +1764,7 @@ def __init__( assert subset is not None, "internal error: subset not found" if not os.path.isdir(subset.conditioning_data_dir): - print(f"not directory: {subset.conditioning_data_dir}") + logger.info(f"not directory: {subset.conditioning_data_dir}") continue img_basename = os.path.basename(info.absolute_path) @@ -1883,14 +1884,14 @@ def enable_XTI(self, *args, **kwargs): def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): for i, dataset in enumerate(self.datasets): - print(f"[Dataset {i}]") + logger.info(f"[Dataset {i}]") dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) def cache_text_encoder_outputs( self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True ): for i, dataset in enumerate(self.datasets): - print(f"[Dataset {i}]") + logger.info(f"[Dataset {i}]") dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process) def set_caching_mode(self, caching_mode): @@ -1974,12 +1975,12 @@ def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, fli def debug_dataset(train_dataset, show_input_ids=False): - print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") - print("`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します") + logger.info(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") + logger.info("`S` for next step, `E` for next epoch no. , Escape for exit. 
/ Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します") epoch = 1 while True: - print(f"\nepoch: {epoch}") + logger.info(f"\nepoch: {epoch}") steps = (epoch - 1) * len(train_dataset) + 1 indices = list(range(len(train_dataset))) @@ -1989,11 +1990,11 @@ def debug_dataset(train_dataset, show_input_ids=False): for i, idx in enumerate(indices): train_dataset.set_current_epoch(epoch) train_dataset.set_current_step(steps) - print(f"steps: {steps} ({i + 1}/{len(train_dataset)})") + logger.info(f"steps: {steps} ({i + 1}/{len(train_dataset)})") example = train_dataset[idx] if example["latents"] is not None: - print(f"sample has latents from npz file: {example['latents'].size()}") + logger.info(f"sample has latents from npz file: {example['latents'].size()}") for j, (ik, cap, lw, iid, orgsz, crptl, trgsz, flpdz) in enumerate( zip( example["image_keys"], example["captions"], example["loss_weights"], example["input_ids"], example["original_sizes_hw"], example["crop_top_lefts"], example["target_sizes_hw"], example["flippeds"], ) ): - print( + logger.info( f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}' ) if show_input_ids: - print(f"input ids: {iid}") + logger.info(f"input ids: {iid}") if "input_ids2" in example: - print(f"input ids2: {example['input_ids2'][j]}") + logger.info(f"input ids2: {example['input_ids2'][j]}") if example["images"] is not None: im = example["images"][j] - print(f"image size: {im.size()}") + logger.info(f"image size: {im.size()}") im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c im = im[:, :, ::-1] # RGB -> BGR (OpenCV) if "conditioning_images" in example: cond_img = example["conditioning_images"][j] - print(f"conditioning image size: {cond_img.size()}") + logger.info(f"conditioning image size: {cond_img.size()}") cond_img = ((cond_img.numpy() + 1.0) * 127.5).astype(np.uint8) cond_img = np.transpose(cond_img, (1, 2, 0)) cond_img = cond_img[:, :, ::-1] @@ -2171,12 +2172,12 @@ def trim_and_resize_if_required( if image_width > reso[0]: trim_size = image_width - reso[0] p = trim_size // 2 if not random_crop else random.randint(0, trim_size) - # print("w", trim_size, p) + # logger.info(f"w {trim_size} {p}") image = image[:, p : p + reso[0]] if image_height > reso[1]: trim_size = image_height - reso[1] p = trim_size // 2 if not random_crop else random.randint(0, trim_size) - # print("h", trim_size, p) + # logger.info(f"h {trim_size} {p}") image = image[p : p + reso[1]] # random cropの場合のcropされた値をどうcrop left/topに反映するべきか全くアイデアがない @@ -2414,7 +2415,7 @@ def get_git_revision_hash() -> str: # def replace_unet_cross_attn_to_xformers(): -# print("CrossAttention.forward has been replaced to enable xformers.") +# logger.info("CrossAttention.forward has been replaced to enable xformers.") # try: # import xformers.ops # except ImportError: @@ -2457,10 +2458,10 @@ def get_git_revision_hash() -> str: # diffusers.models.attention.CrossAttention.forward = forward_xformers def replace_unet_modules(unet: UNet2DConditionModel, mem_eff_attn, xformers, sdpa): if mem_eff_attn: - print("Enable memory efficient attention for U-Net") + logger.info("Enable memory efficient attention for U-Net") unet.set_use_memory_efficient_attention(False, True) elif xformers: - print("Enable xformers for U-Net") + logger.info("Enable xformers for U-Net") try: import xformers.ops except ImportError: raise ImportError("No xformers / xformersがインストールされていないようです")
unet.set_use_memory_efficient_attention(True, False) elif sdpa: - print("Enable SDPA for U-Net") + logger.info("Enable SDPA for U-Net") unet.set_use_sdpa(True) @@ -2479,17 +2480,17 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform replace_vae_attn_to_memory_efficient() elif xformers: # とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ - print("Use Diffusers xformers for VAE") + logger.info("Use Diffusers xformers for VAE") vae.encoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) vae.decoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) def replace_vae_attn_to_memory_efficient(): - print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)") + logger.info("AttentionBlock.forward has been replaced to FlashAttention (not xformers)") flash_func = FlashAttentionFunction def forward_flash_attn(self, hidden_states): - print("forward_flash_attn") + logger.info("forward_flash_attn") q_bucket_size = 512 k_bucket_size = 1024 @@ -3053,13 +3054,13 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: def verify_training_args(args: argparse.Namespace): if args.v_parameterization and not args.v2: - print("v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません") + logger.info("v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません") if args.v2 and args.clip_skip is not None: - print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + logger.info("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") if args.cache_latents_to_disk and not args.cache_latents: args.cache_latents = True - print( + logger.info( "cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします" ) @@ -3090,7 +3091,7 @@ def verify_training_args(args: argparse.Namespace): ) if args.zero_terminal_snr and not args.v_parameterization: - print( + logger.info( f"zero_terminal_snr is enabled, but v_parameterization is not enabled. training will be unexpected" + " / zero_terminal_snrが有効ですが、v_parameterizationが有効ではありません。学習結果は想定外になる可能性があります" ) @@ -3254,7 +3255,7 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar if args.output_config: # check if config file exists if os.path.exists(config_path): - print(f"Config file already exists. Aborting... / 出力先の設定ファイルが既に存在します: {config_path}") + logger.info(f"Config file already exists. Aborting... 
/ 出力先の設定ファイルが既に存在します: {config_path}") exit(1) # convert args to dictionary @@ -3282,14 +3283,14 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar with open(config_path, "w") as f: toml.dump(args_dict, f) - print(f"Saved config file / 設定ファイルを保存しました: {config_path}") + logger.info(f"Saved config file / 設定ファイルを保存しました: {config_path}") exit(0) if not os.path.exists(config_path): - print(f"{config_path} not found.") + logger.info(f"{config_path} not found.") exit(1) - print(f"Loading settings from {config_path}...") + logger.info(f"Loading settings from {config_path}...") with open(config_path, "r") as f: config_dict = toml.load(f) @@ -3308,7 +3309,7 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar config_args = argparse.Namespace(**ignore_nesting_dict) args = parser.parse_args(namespace=config_args) args.config_file = os.path.splitext(args.config_file)[0] - print(args.config_file) + logger.info(args.config_file) return args @@ -3323,11 +3324,11 @@ def resume_from_local_or_hf_if_specified(accelerator, args): return if not args.resume_from_huggingface: - print(f"resume training from local state: {args.resume}") + logger.info(f"resume training from local state: {args.resume}") accelerator.load_state(args.resume) return - print(f"resume training from huggingface state: {args.resume}") + logger.info(f"resume training from huggingface state: {args.resume}") repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1] path_in_repo = "/".join(args.resume.split("/")[2:]) revision = None repo_type = None if ":" in path_in_repo: divided = path_in_repo.split(":") if len(divided) == 2: path_in_repo, revision = divided repo_type = "model" else: path_in_repo, revision, repo_type = divided - print(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}") + logger.info(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}") list_files = huggingface_util.list_dir( repo_id=repo_id, @@ -3411,7 +3412,7 @@ def get_optimizer(args, trainable_params): # value = tuple(value) optimizer_kwargs[key] = value - # print("optkwargs:", optimizer_kwargs) + # logger.info(f"optkwargs: {optimizer_kwargs}") lr = args.learning_rate optimizer = None @@ -3421,7 +3422,7 @@ def get_optimizer(args, trainable_params): import lion_pytorch except ImportError: raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです") - print(f"use Lion optimizer | {optimizer_kwargs}") + logger.info(f"use Lion optimizer | {optimizer_kwargs}") optimizer_class = lion_pytorch.Lion optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) @@ -3432,14 +3433,14 @@ def get_optimizer(args, trainable_params): raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです") if optimizer_type == "AdamW8bit".lower(): - print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}") + logger.info(f"use 8-bit AdamW optimizer | {optimizer_kwargs}") optimizer_class = bnb.optim.AdamW8bit optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "SGDNesterov8bit".lower(): - print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}") + logger.info(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}") if "momentum" not in optimizer_kwargs: - print( + logger.info( f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します" ) optimizer_kwargs["momentum"] = 0.9 @@ -3448,7 +3449,7 @@ def get_optimizer(args, trainable_params): optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) elif optimizer_type == "Lion8bit".lower(): - print(f"use 8-bit Lion optimizer | {optimizer_kwargs}") + logger.info(f"use 8-bit Lion optimizer | {optimizer_kwargs}") try: optimizer_class = bnb.optim.Lion8bit except AttributeError: @@ -3456,7 +3457,7 @@ def get_optimizer(args, trainable_params): "No Lion8bit. The version of bitsandbytes installed seems to be old. Please install 0.38.0 or later. / Lion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.38.0以上をインストールしてください" ) elif optimizer_type == "PagedAdamW8bit".lower(): - print(f"use 8-bit PagedAdamW optimizer | {optimizer_kwargs}") + logger.info(f"use 8-bit PagedAdamW optimizer | {optimizer_kwargs}") try: optimizer_class = bnb.optim.PagedAdamW8bit except AttributeError: @@ -3464,7 +3465,7 @@ def get_optimizer(args, trainable_params): "No PagedAdamW8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamW8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" ) elif optimizer_type == "PagedLion8bit".lower(): - print(f"use 8-bit Paged Lion optimizer | {optimizer_kwargs}") + logger.info(f"use 8-bit Paged Lion optimizer | {optimizer_kwargs}") try: optimizer_class = bnb.optim.PagedLion8bit except AttributeError: @@ -3475,7 +3476,7 @@ def get_optimizer(args, trainable_params): optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "PagedAdamW32bit".lower(): - print(f"use 32-bit PagedAdamW optimizer | {optimizer_kwargs}") + logger.info(f"use 32-bit PagedAdamW optimizer | {optimizer_kwargs}") try: import bitsandbytes as bnb except ImportError: @@ -3489,16 +3490,16 @@ def get_optimizer(args, trainable_params): optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "SGDNesterov".lower(): - print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}") + logger.info(f"use SGD with Nesterov optimizer | {optimizer_kwargs}") if "momentum" not in optimizer_kwargs: - print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します") + logger.info(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します") optimizer_kwargs["momentum"] = 0.9 optimizer_class = torch.optim.SGD optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) elif optimizer_type.startswith("DAdapt".lower()) or optimizer_type == "Prodigy".lower(): - # check lr and lr_count, and print warning + # check lr and lr_count, and log a warning actual_lr = lr lr_count = 1 if type(trainable_params) == list and type(trainable_params[0]) == dict: @@ -3509,12 +3510,12 @@ def get_optimizer(args, trainable_params): lr_count = len(lrs) if actual_lr <= 0.1: - print( + logger.info( f"learning rate is too low. If using D-Adaptation or Prodigy, set learning rate around 1.0 / 学習率が低すぎるようです。D-AdaptationまたはProdigyの使用時は1.0前後の値を指定してください: lr={actual_lr}" ) - print("recommend option: lr=1.0 / 推奨は1.0です") + logger.info("recommend option: lr=1.0 / 推奨は1.0です") if lr_count > 1: - print( + logger.info( f"when multiple learning rates are specified with dadaptation (e.g. 
for Text Encoder and U-Net), only the first one will take effect / D-AdaptationまたはProdigyで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}" ) @@ -3530,25 +3531,25 @@ def get_optimizer(args, trainable_params): # set optimizer if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower(): optimizer_class = experimental.DAdaptAdamPreprint - print(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptAdaGrad".lower(): optimizer_class = dadaptation.DAdaptAdaGrad - print(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptAdam".lower(): optimizer_class = dadaptation.DAdaptAdam - print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptAdan".lower(): optimizer_class = dadaptation.DAdaptAdan - print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptAdanIP".lower(): optimizer_class = experimental.DAdaptAdanIP - print(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptLion".lower(): optimizer_class = dadaptation.DAdaptLion - print(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptSGD".lower(): optimizer_class = dadaptation.DAdaptSGD - print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}") else: raise ValueError(f"Unknown optimizer type: {optimizer_type}") @@ -3561,7 +3562,7 @@ def get_optimizer(args, trainable_params): except ImportError: raise ImportError("No Prodigy / Prodigy がインストールされていないようです") - print(f"use Prodigy optimizer | {optimizer_kwargs}") + logger.info(f"use Prodigy optimizer | {optimizer_kwargs}") optimizer_class = prodigyopt.Prodigy optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) @@ -3570,14 +3571,14 @@ def get_optimizer(args, trainable_params): if "relative_step" not in optimizer_kwargs: optimizer_kwargs["relative_step"] = True # default if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False): - print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします") + logger.info(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします") optimizer_kwargs["relative_step"] = True - print(f"use Adafactor optimizer | {optimizer_kwargs}") + logger.info(f"use Adafactor optimizer | {optimizer_kwargs}") if optimizer_kwargs["relative_step"]: - print(f"relative_step is true / relative_stepがtrueです") + logger.info(f"relative_step is true / relative_stepがtrueです") if lr != 0.0: - print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます") + logger.info(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます") args.learning_rate = None # trainable_paramsがgroupだった時の処理:lrを削除する @@ -3589,37 +3590,37 @@ def get_optimizer(args, trainable_params): if has_group_lr: # 一応argsを無効にしておく TODO 依存関係が逆転してるのであまり望ましくない - print(f"unet_lr and text_encoder_lr 
are ignored / unet_lrとtext_encoder_lrは無視されます") + logger.warning(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます") args.unet_lr = None args.text_encoder_lr = None if args.lr_scheduler != "adafactor": - print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します") + logger.info(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します") args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど lr = None else: if args.max_grad_norm != 0.0: - print( + logger.warning( f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません" ) if args.lr_scheduler != "constant_with_warmup": - print(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません") + logger.warning(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません") if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0: - print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません") + logger.warning(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません") optimizer_class = transformers.optimization.Adafactor optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "AdamW".lower(): - print(f"use AdamW optimizer | {optimizer_kwargs}") + logger.info(f"use AdamW optimizer | {optimizer_kwargs}") optimizer_class = torch.optim.AdamW optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) if optimizer is None: # 任意のoptimizerを使う optimizer_type = args.optimizer_type # lowerでないやつ(微妙) - print(f"use {optimizer_type} | {optimizer_kwargs}") + logger.info(f"use {optimizer_type} | {optimizer_kwargs}") if "." not in optimizer_type: optimizer_module = torch.optim else: @@ -3665,7 +3666,7 @@ def wrap_check_needless_num_warmup_steps(return_vals): # using any lr_scheduler from other library if args.lr_scheduler_type: lr_scheduler_type = args.lr_scheduler_type - print(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler") + logger.info(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler") if "." 
not in lr_scheduler_type: # default to use torch.optim lr_scheduler_module = torch.optim.lr_scheduler else: @@ -3681,7 +3682,7 @@ def wrap_check_needless_num_warmup_steps(return_vals): type(optimizer) == transformers.optimization.Adafactor ), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください" initial_lr = float(name.split(":")[1]) - # print("adafactor scheduler init lr", initial_lr) + # logger.info(f"adafactor scheduler init lr {initial_lr}") return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr)) name = SchedulerType(name) @@ -3746,20 +3747,20 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): if support_metadata: if args.in_json is not None and (args.color_aug or args.random_crop): - print( + logger.info( f"latents in npz is ignored when color_aug or random_crop is True / color_augまたはrandom_cropを有効にした場合、npzファイルのlatentsは無視されます" ) def load_tokenizer(args: argparse.Namespace): - print("prepare tokenizer") + logger.info("prepare tokenizer") original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH tokenizer: CLIPTokenizer = None if args.tokenizer_cache_dir: local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_")) if os.path.exists(local_tokenizer_path): - print(f"load tokenizer from cache: {local_tokenizer_path}") + logger.info(f"load tokenizer from cache: {local_tokenizer_path}") tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2 if tokenizer is None: @@ -3769,10 +3770,10 @@ def load_tokenizer(args: argparse.Namespace): tokenizer = CLIPTokenizer.from_pretrained(original_path) if hasattr(args, "max_token_length") and args.max_token_length is not None: - print(f"update token length: {args.max_token_length}") + logger.info(f"update token length: {args.max_token_length}") if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): - print(f"save Tokenizer to cache: {local_tokenizer_path}") + logger.info(f"save Tokenizer to cache: {local_tokenizer_path}") tokenizer.save_pretrained(local_tokenizer_path) return tokenizer @@ -3839,17 +3840,17 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une name_or_path = os.path.realpath(name_or_path) if os.path.islink(name_or_path) else name_or_path load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers if load_stable_diffusion_format: - print(f"load StableDiffusion checkpoint: {name_or_path}") + logger.info(f"load StableDiffusion checkpoint: {name_or_path}") text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint( args.v2, name_or_path, device, unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2 ) else: # Diffusers model is loaded to CPU - print(f"load Diffusers pretrained models: {name_or_path}") + logger.info(f"load Diffusers pretrained models: {name_or_path}") try: pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None) except EnvironmentError as ex: - print( + logger.error( f"model is not found as a file or in Hugging Face, perhaps file name is wrong? 
/ 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}" ) raise ex @@ -3860,7 +3861,7 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une # Diffusers U-Net to original U-Net # TODO *.ckpt/*.safetensorsのv2と同じ形式にここで変換すると良さそう - # print(f"unet config: {unet.config}") + # logger.info(f"unet config: {unet.config}") original_unet = UNet2DConditionModel( unet.config.sample_size, unet.config.attention_head_dim, @@ -3870,12 +3871,12 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une ) original_unet.load_state_dict(unet.state_dict()) unet = original_unet - print("U-Net converted to original U-Net") + logger.info("U-Net converted to original U-Net") # VAEを読み込む if args.vae is not None: vae = model_util.load_vae(args.vae, weight_dtype) - print("additional VAE loaded") + logger.info("additional VAE loaded") return text_encoder, vae, unet, load_stable_diffusion_format @@ -3895,7 +3896,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio # load models for each process for pi in range(accelerator.state.num_processes): if pi == accelerator.state.local_process_index: - print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") + logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model( args, @@ -4204,7 +4205,7 @@ def save_sd_model_on_epoch_end_or_stepwise_common( ckpt_name = get_step_ckpt_name(args, ext, global_step) ckpt_file = os.path.join(args.output_dir, ckpt_name) - print(f"\nsaving checkpoint: {ckpt_file}") + logger.info(f"saving checkpoint: {ckpt_file}") sd_saver(ckpt_file, epoch_no, global_step) if args.huggingface_repo_id is not None: @@ -4219,7 +4220,7 @@ def save_sd_model_on_epoch_end_or_stepwise_common( remove_ckpt_file = os.path.join(args.output_dir, remove_ckpt_name) if os.path.exists(remove_ckpt_file): - print(f"removing old checkpoint: {remove_ckpt_file}") + logger.info(f"removing old checkpoint: {remove_ckpt_file}") os.remove(remove_ckpt_file) else: @@ -4228,7 +4229,7 @@ def save_sd_model_on_epoch_end_or_stepwise_common( else: out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, global_step)) - print(f"\nsaving model: {out_dir}") + logger.info(f"saving model: {out_dir}") diffusers_saver(out_dir) if args.huggingface_repo_id is not None: @@ -4242,7 +4243,7 @@ def save_sd_model_on_epoch_end_or_stepwise_common( remove_out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, remove_no)) if os.path.exists(remove_out_dir): - print(f"removing old model: {remove_out_dir}") + logger.info(f"removing old model: {remove_out_dir}") shutil.rmtree(remove_out_dir) if args.save_state: @@ -4255,13 +4256,13 @@ def save_sd_model_on_epoch_end_or_stepwise_common( def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator, epoch_no): model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME) - print(f"\nsaving state at epoch {epoch_no}") + logger.info(f"saving state at epoch {epoch_no}") os.makedirs(args.output_dir, exist_ok=True) state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)) accelerator.save_state(state_dir) if args.save_state_to_huggingface: - print("uploading state to huggingface.") + logger.info("uploading state to huggingface.") huggingface_util.upload(args, state_dir, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no)) last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs @@ -4269,20 +4270,20 @@ def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator, ep remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no)) if os.path.exists(state_dir_old): - print(f"removing old state: {state_dir_old}") + logger.info(f"removing old state: {state_dir_old}") shutil.rmtree(state_dir_old) def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_no): model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME) - print(f"\nsaving state at step {step_no}") + logger.info(f"saving state at step {step_no}") os.makedirs(args.output_dir, exist_ok=True) state_dir = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, step_no)) accelerator.save_state(state_dir) if args.save_state_to_huggingface: - print("uploading state to huggingface.") + logger.info("uploading state to huggingface.") huggingface_util.upload(args, state_dir, "/" + STEP_STATE_NAME.format(model_name, step_no)) last_n_steps = args.save_last_n_steps_state if args.save_last_n_steps_state else args.save_last_n_steps @@ -4294,21 +4295,21 @@ def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_n if remove_step_no > 0: state_dir_old = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, remove_step_no)) if os.path.exists(state_dir_old): - print(f"removing old state: {state_dir_old}") + logger.info(f"removing old state: {state_dir_old}") shutil.rmtree(state_dir_old) def save_state_on_train_end(args: argparse.Namespace, accelerator): model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME) - print("\nsaving last state.") + logger.info("saving last state.") os.makedirs(args.output_dir, exist_ok=True) state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)) accelerator.save_state(state_dir) if args.save_state_to_huggingface: - print("uploading last state to huggingface.") + logger.info("uploading last state to huggingface.") huggingface_util.upload(args, state_dir, "/" + LAST_STATE_NAME.format(model_name)) @@ -4357,7 +4358,7 @@ def save_sd_model_on_train_end_common( ckpt_name = model_name + (".safetensors" if use_safetensors else ".ckpt") ckpt_file = os.path.join(args.output_dir, ckpt_name) - print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") + logger.info(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") sd_saver(ckpt_file, epoch, global_step) if args.huggingface_repo_id is not None: @@ -4366,7 +4367,7 @@ def save_sd_model_on_train_end_common( out_dir = os.path.join(args.output_dir, model_name) os.makedirs(out_dir, exist_ok=True) - print(f"save trained model as Diffusers to {out_dir}") + logger.info(f"save trained model as Diffusers to {out_dir}") diffusers_saver(out_dir) if args.huggingface_repo_id is not None: @@ -4439,9 +4440,9 @@ def sample_images_common( if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch return - print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}") + logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") if not os.path.isfile(args.sample_prompts): - print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") + logger.error(f"No prompt file / プロンプトファイルがありません: 
{args.sample_prompts}") return org_vae_device = vae.device # CPUにいるはず @@ -4505,7 +4506,7 @@ def sample_images_common( # clip_sample=Trueにする if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: - # print("set clip_sample to True") + # logger.info("set clip_sample to True") scheduler.config.clip_sample = True pipeline = pipe_class( @@ -4594,8 +4595,8 @@ def sample_images_common( continue except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) + logger.error(f"Exception in parsing / 解析エラー: {parg}") + logger.error(ex) if seed is not None: torch.manual_seed(seed) @@ -4612,12 +4613,12 @@ def sample_images_common( height = max(64, height - height % 8) # round to divisible by 8 width = max(64, width - width % 8) # round to divisible by 8 - print(f"prompt: {prompt}") - print(f"negative_prompt: {negative_prompt}") - print(f"height: {height}") - print(f"width: {width}") - print(f"sample_steps: {sample_steps}") - print(f"scale: {scale}") + logger.info(f"prompt: {prompt}") + logger.info(f"negative_prompt: {negative_prompt}") + logger.info(f"height: {height}") + logger.info(f"width: {width}") + logger.info(f"sample_steps: {sample_steps}") + logger.info(f"scale: {scale}") with accelerator.autocast(): latents = pipeline( prompt=prompt, @@ -4683,7 +4684,7 @@ def __getitem__(self, idx): # convert to tensor temporarily so dataloader will accept it tensor_pil = transforms.functional.pil_to_tensor(image) except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") return None return (tensor_pil, img_path) diff --git a/networks/check_lora_weights.py b/networks/check_lora_weights.py index 51f581b29..10650d152 100644 --- a/networks/check_lora_weights.py +++ b/networks/check_lora_weights.py @@ -2,10 +2,11 @@ import os import torch from safetensors.torch import load_file - +from library.utils import get_my_logger +logger = get_my_logger(__name__) def main(file): - print(f"loading: {file}") + logger.info(f"loading: {file}") if os.path.splitext(file)[1] == ".safetensors": sd = load_file(file) else: @@ -17,16 +18,16 @@ def main(file): for key in keys: if "lora_up" in key or "lora_down" in key: values.append((key, sd[key])) - print(f"number of LoRA modules: {len(values)}") + logger.info(f"number of LoRA modules: {len(values)}") if args.show_all_keys: for key in [k for k in keys if k not in values]: values.append((key, sd[key])) - print(f"number of all modules: {len(values)}") + logger.info(f"number of all modules: {len(values)}") for key, value in values: value = value.to(torch.float32) - print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}") + logger.info(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}") def setup_parser() -> argparse.ArgumentParser: diff --git a/networks/control_net_lllite.py b/networks/control_net_lllite.py index 4ebfef7a4..4f9a82345 100644 --- a/networks/control_net_lllite.py +++ b/networks/control_net_lllite.py @@ -2,7 +2,8 @@ from typing import Optional, List, Type import torch from library import sdxl_original_unet - +from library.utils import get_my_logger +logger = get_my_logger(__name__) # input_blocksに適用するかどうか / if True, input_blocks are not applied SKIP_INPUT_BLOCKS = False @@ -125,7 +126,7 @@ def set_cond_image(self, cond_image): return # timestepごとに呼ばれないので、あらかじめ計算しておく / it is 
not called for each timestep, so calculate it in advance - # print(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}") + # logger.info(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}") cx = self.conditioning1(cond_image) if not self.is_conv2d: # reshape / b,c,h,w -> b,h*w,c @@ -155,7 +156,7 @@ def forward(self, x): cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1) if self.use_zeros_for_batch_uncond: cx[0::2] = 0.0 # uncond is zero - # print(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}") + # logger.info(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}") # downで入力の次元数を削減し、conditioning image embeddingと結合する # 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している @@ -286,7 +287,7 @@ def create_modules( # create module instances self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule) - print(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.") + logger.info(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.") def forward(self, x): return x # dummy @@ -319,7 +320,7 @@ def load_weights(self, file): return info def apply_to(self): - print("applying LLLite for U-Net...") + logger.info("applying LLLite for U-Net...") for module in self.unet_modules: module.apply_to() self.add_module(module.lllite_name, module) @@ -374,19 +375,19 @@ def save_weights(self, file, dtype, metadata): # sdxl_original_unet.USE_REENTRANT = False # test shape etc - print("create unet") + logger.info("create unet") unet = sdxl_original_unet.SdxlUNet2DConditionModel() unet.to("cuda").to(torch.float16) - print("create ControlNet-LLLite") + logger.info("create ControlNet-LLLite") control_net = ControlNetLLLite(unet, 32, 64) control_net.apply_to() control_net.to("cuda") - print(control_net) + logger.info(control_net) - # print number of parameters - print("number of parameters", sum(p.numel() for p in control_net.parameters() if p.requires_grad)) + # show number of parameters + logger.info(f"number of parameters {sum(p.numel() for p in control_net.parameters() if p.requires_grad)}") input() @@ -398,12 +399,12 @@ def save_weights(self, file, dtype, metadata): # # visualize # import torchviz - # print("run visualize") + # logger.info("run visualize") # controlnet.set_control(conditioning_image) # output = unet(x, t, ctx, y) - # print("make_dot") + # logger.info("make_dot") # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters())) - # print("render") + # logger.info("render") # image.format = "svg" # "png" # image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time # input() @@ -414,12 +415,12 @@ def save_weights(self, file, dtype, metadata): scaler = torch.cuda.amp.GradScaler(enabled=True) - print("start training") + logger.info("start training") steps = 10 sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0] for step in range(steps): - print(f"step {step}") + logger.info(f"step {step}") batch_size = 1 conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0 @@ -439,7 +440,7 @@ def save_weights(self, file, dtype, metadata): scaler.step(optimizer) scaler.update() optimizer.zero_grad(set_to_none=True) - print(sample_param) + logger.info(f"{sample_param}") # from safetensors.torch import save_file diff --git a/networks/control_net_lllite_for_train.py b/networks/control_net_lllite_for_train.py index 026880015..a202712ea 100644 --- a/networks/control_net_lllite_for_train.py +++ 
b/networks/control_net_lllite_for_train.py @@ -6,7 +6,8 @@ from typing import Optional, List, Type import torch from library import sdxl_original_unet - +from library.utils import get_my_logger +logger = get_my_logger(__name__) # input_blocksに適用するかどうか / if True, input_blocks are not applied SKIP_INPUT_BLOCKS = False @@ -270,7 +271,7 @@ def apply_to_modules( # create module instances self.lllite_modules = apply_to_modules(self, target_modules) - print(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.") + logger.info(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.") # def prepare_optimizer_params(self): def prepare_params(self): @@ -281,8 +282,8 @@ def prepare_params(self): train_params.append(p) else: non_train_params.append(p) - print(f"count of trainable parameters: {len(train_params)}") - print(f"count of non-trainable parameters: {len(non_train_params)}") + logger.info(f"count of trainable parameters: {len(train_params)}") + logger.info(f"count of non-trainable parameters: {len(non_train_params)}") for p in non_train_params: p.requires_grad_(False) @@ -388,7 +389,7 @@ def load_lllite_weights(self, file, non_lllite_unet_sd=None): matches = pattern.findall(module_name) if matches is not None: for m in matches: - print(module_name, m) + logger.info(f"{module_name} {m}") module_name = module_name.replace(m, m.replace("_", "@")) module_name = module_name.replace("_", ".") module_name = module_name.replace("@", "_") @@ -407,7 +408,7 @@ def forward(self, x, timesteps=None, context=None, y=None, cond_image=None, **kw def replace_unet_linear_and_conv2d(): - print("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net") + logger.info("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net") sdxl_original_unet.torch.nn.Linear = LLLiteLinear sdxl_original_unet.torch.nn.Conv2d = LLLiteConv2d @@ -419,10 +420,10 @@ def replace_unet_linear_and_conv2d(): replace_unet_linear_and_conv2d() # test shape etc - print("create unet") + logger.info("create unet") unet = SdxlUNet2DConditionModelControlNetLLLite() - print("enable ControlNet-LLLite") + logger.info("enable ControlNet-LLLite") unet.apply_lllite(32, 64, None, False, 1.0) unet.to("cuda") # .to(torch.float16) @@ -439,14 +440,14 @@ def replace_unet_linear_and_conv2d(): # unet_sd[converted_key] = model_sd[key] # info = unet.load_lllite_weights("r:/lllite_from_unet.safetensors", unet_sd) - # print(info) + # logger.info(info) - # print(unet) + # logger.info(unet) - # print number of parameters + # show number of parameters params = unet.prepare_params() - print("number of parameters", sum(p.numel() for p in params)) - # print("type any key to continue") + logger.info(f"number of parameters {sum(p.numel() for p in params)}") + # logger.info("type any key to continue") # input() unet.set_use_memory_efficient_attention(True, False) @@ -455,12 +456,12 @@ def replace_unet_linear_and_conv2d(): # # visualize # import torchviz - # print("run visualize") + # logger.info("run visualize") # controlnet.set_control(conditioning_image) # output = unet(x, t, ctx, y) - # print("make_dot") + # logger.info("make_dot") # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters())) - # print("render") + # logger.info("render") # image.format = "svg" # "png" # image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time # input() @@ -471,13 +472,13 @@ def replace_unet_linear_and_conv2d(): scaler = 
torch.cuda.amp.GradScaler(enabled=True) - print("start training") + logger.info("start training") steps = 10 batch_size = 1 sample_param = [p for p in unet.named_parameters() if ".lllite_up." in p[0]][0] for step in range(steps): - print(f"step {step}") + logger.info(f"step {step}") conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0 x = torch.randn(batch_size, 4, 128, 128).cuda() @@ -494,9 +495,9 @@ def replace_unet_linear_and_conv2d(): scaler.step(optimizer) scaler.update() optimizer.zero_grad(set_to_none=True) - print(sample_param) + logger.info(sample_param) # from safetensors.torch import save_file - # print("save weights") + # logger.info("save weights") # unet.save_lllite_weights("r:/lllite_from_unet.safetensors", torch.float16, None) diff --git a/networks/dylora.py b/networks/dylora.py index e5a55d198..5781cee51 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -15,7 +15,8 @@ from typing import List, Tuple, Union import torch from torch import nn - +from library.utils import get_my_logger +logger = get_my_logger(__name__) class DyLoRAModule(torch.nn.Module): """ @@ -223,7 +224,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh elif "lora_down" in key: dim = value.size()[0] modules_dim[lora_name] = dim - # print(lora_name, value.size(), dim) + # logger.info(f"{lora_name} {value.size()} {dim}") # support old LoRA without alpha for key in modules_dim.keys(): @@ -267,11 +268,11 @@ def __init__( self.apply_to_conv = apply_to_conv if modules_dim is not None: - print(f"create LoRA network from weights") + logger.info("create LoRA network from weights") else: - print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}") + logger.info(f"create LoRA network. 
base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}") if self.apply_to_conv: - print(f"apply LoRA to Conv2d with kernel size (3,3).") + logger.info("apply LoRA to Conv2d with kernel size (3,3).") # create module instances def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]: @@ -308,7 +309,7 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules return loras self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) - print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE @@ -316,7 +317,7 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules target_modules += DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 self.unet_loras = create_modules(True, unet, target_modules) - print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") def set_multiplier(self, multiplier): self.multiplier = multiplier @@ -336,12 +337,12 @@ def load_weights(self, file): def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -359,12 +360,12 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): apply_unet = True if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -375,7 +376,7 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] lora.merge_to(sd_for_lora, dtype, device) - print(f"weights are merged") + logger.info(f"weights are merged") """ def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): diff --git a/networks/extract_lora_from_dylora.py b/networks/extract_lora_from_dylora.py index 0abee9836..35747ebfc 100644 --- a/networks/extract_lora_from_dylora.py +++ b/networks/extract_lora_from_dylora.py @@ -10,7 +10,8 @@ from tqdm import tqdm from library import train_util, model_util import numpy as np - +from library.utils import get_my_logger +logger = get_my_logger(__name__) def load_state_dict(file_name): if model_util.is_safetensors(file_name): @@ -40,13 +41,13 @@ def split_lora_model(lora_sd, unit): rank = value.size()[0] if rank > max_rank: max_rank = rank - print(f"Max rank: {max_rank}") + logger.info(f"Max rank: {max_rank}") rank = unit split_models = [] new_alpha = None while rank < max_rank: - print(f"Splitting rank {rank}") + logger.info(f"Splitting rank {rank}") new_sd = {} for key, value in lora_sd.items(): if "lora_down" in key: @@ -57,7 +58,7 @@ def split_lora_model(lora_sd, unit): # なぜかscaleするとおかしくなる…… # this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0] # scale = math.sqrt(this_rank / rank) # rank is > unit - # print(key, value.size(), this_rank, rank, value, scale) + # logger.info(key, 
value.size(), this_rank, rank, value, scale) # new_alpha = value * scale # always same # new_sd[key] = new_alpha new_sd[key] = value @@ -69,10 +70,10 @@ def split_lora_model(lora_sd, unit): def split(args): - print("loading Model...") + logger.info("loading Model...") lora_sd, metadata = load_state_dict(args.model) - print("Splitting Model...") + logger.info("Splitting Model...") original_rank, split_models = split_lora_model(lora_sd, args.unit) comment = metadata.get("ss_training_comment", "") @@ -94,7 +95,7 @@ def split(args): filename, ext = os.path.splitext(args.save_to) model_file_name = filename + f"-{new_rank:04d}{ext}" - print(f"saving model to: {model_file_name}") + logger.info(f"saving model to: {model_file_name}") save_to_file(model_file_name, state_dict, new_metadata) diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index dba7cd4e2..792fa3bd8 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -11,7 +11,8 @@ from tqdm import tqdm from library import sai_model_spec, model_util, sdxl_model_util import lora - +from library.utils import get_my_logger +logger = get_my_logger(__name__) CLAMP_QUANTILE = 0.99 MIN_DIFF = 1e-1 @@ -49,20 +50,20 @@ def str_to_dtype(p): # load models if not args.sdxl: - print(f"loading original SD model : {args.model_org}") + logger.info(f"loading original SD model : {args.model_org}") text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org) text_encoders_o = [text_encoder_o] - print(f"loading tuned SD model : {args.model_tuned}") + logger.info(f"loading tuned SD model : {args.model_tuned}") text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned) text_encoders_t = [text_encoder_t] model_version = model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization) else: - print(f"loading original SDXL model : {args.model_org}") + logger.info(f"loading original SDXL model : {args.model_org}") text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_org, "cpu" ) text_encoders_o = [text_encoder_o1, text_encoder_o2] - print(f"loading original SDXL model : {args.model_tuned}") + logger.info(f"loading tuned SDXL model : {args.model_tuned}") text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_tuned, "cpu" ) @@ -93,13 +94,13 @@ def str_to_dtype(p): # Text Encoder might be same if not text_encoder_different and torch.max(torch.abs(diff)) > MIN_DIFF: text_encoder_different = True - print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {MIN_DIFF}") + logger.warning(f"Text encoder is different. {torch.max(torch.abs(diff))} > {MIN_DIFF}") diff = diff.float() diffs[lora_name] = diff if not text_encoder_different: - print("Text encoder is same. Extract U-Net only.") + logger.warning("Text encoder is the same. 
Extract U-Net only.") lora_network_o.text_encoder_loras = [] diffs = {} @@ -116,7 +117,7 @@ def str_to_dtype(p): diffs[lora_name] = diff # make LoRA with svd - print("calculating by svd") + logger.info("calculating by svd") lora_weights = {} with torch.no_grad(): for lora_name, mat in tqdm(list(diffs.items())): @@ -131,7 +132,7 @@ def str_to_dtype(p): if args.device: mat = mat.to(args.device) - # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim) + # logger.info(lora_name, mat.size(), mat.device, rank, in_dim, out_dim) rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim if conv2d: @@ -176,7 +177,7 @@ def str_to_dtype(p): lora_network_save.apply_to(text_encoders_o, unet_o) # create internal module references for state_dict info = lora_network_save.load_state_dict(lora_sd) - print(f"Loading extracted LoRA weights: {info}") + logger.info(f"Loading extracted LoRA weights: {info}") dir_name = os.path.dirname(args.save_to) if dir_name and not os.path.exists(dir_name): @@ -205,7 +206,7 @@ def str_to_dtype(p): metadata.update(sai_metadata) lora_network_save.save_weights(args.save_to, save_dtype, metadata) - print(f"LoRA weights are saved to: {args.save_to}") + logger.info(f"LoRA weights are saved to: {args.save_to}") def setup_parser() -> argparse.ArgumentParser: diff --git a/networks/lora.py b/networks/lora.py index 0c75cd428..a37475d10 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -11,7 +11,8 @@ import numpy as np import torch import re - +from library.utils import get_my_logger +logger = get_my_logger(__name__) RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") @@ -46,7 +47,7 @@ def __init__( # if limit_rank: # self.lora_dim = min(lora_dim, in_dim, out_dim) # if self.lora_dim != lora_dim: - # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") + # logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") # else: self.lora_dim = lora_dim @@ -177,7 +178,7 @@ def merge_to(self, sd, dtype, device): else: # conv2d 3x3 conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # print(conved.size(), weight.size(), module.stride, module.padding) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) weight = weight + self.multiplier * conved * self.scale # set weight to org_module @@ -216,7 +217,7 @@ def set_region(self, region): self.region_mask = None def default_forward(self, x): - # print("default_forward", self.lora_name, x.size()) + # logger.info(f"default_forward {self.lora_name} {x.size()}") return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale def forward(self, x): @@ -245,7 +246,7 @@ def get_mask_for_x(self, x): if mask is None: # raise ValueError(f"mask is None for resolution {area}") # emb_layers in SDXL doesn't have mask - # print(f"mask is None for resolution {area}, {x.size()}") + # logger.info(f"mask is None for resolution {area}, {x.size()}") mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1) return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts if len(x.size()) != 4: @@ -262,7 +263,7 @@ def regional_forward(self, x): # apply mask for LoRA result lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale mask = self.get_mask_for_x(lx) - # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size()) + # logger.info(f"regional {self.lora_name} 
{self.network.sub_prompt_index} {lx.size()} {mask.size()}") lx = lx * mask x = self.org_forward(x) @@ -291,7 +292,7 @@ def postp_to_q(self, x): if has_real_uncond: query[-self.network.batch_size :] = x[-self.network.batch_size :] - # print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts) + # logger.info(f"postp_to_q {self.lora_name} {x.size()} {query.size()} {self.network.num_sub_prompts}") return query def sub_prompt_forward(self, x): @@ -306,7 +307,7 @@ def sub_prompt_forward(self, x): lx = x[emb_idx :: self.network.num_sub_prompts] lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale - # print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx) + # logger.info(f"sub_prompt_forward {self.lora_name} {x.size()} {lx.size()} {emb_idx}") x = self.org_forward(x) x[emb_idx :: self.network.num_sub_prompts] += lx @@ -314,7 +315,7 @@ def sub_prompt_forward(self, x): return x def to_out_forward(self, x): - # print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network) + # logger.info(f"to_out_forward {self.lora_name} {x.size()} {self.network.is_last_network}") if self.network.is_last_network: masks = [None] * self.network.num_sub_prompts @@ -332,7 +333,7 @@ def to_out_forward(self, x): ) self.network.shared[self.lora_name] = (lx, masks) - # print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts) + # logger.info(f"to_out_forward {lx.size()} {lx1.size()} {self.network.sub_prompt_index} {self.network.num_sub_prompts}") lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1 masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1) @@ -351,7 +352,7 @@ def to_out_forward(self, x): if has_real_uncond: out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond - # print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts) + # logger.info(f"to_out_forward {self.lora_name} {self.network.sub_prompt_index} {self.network.num_sub_prompts}") # if num_sub_prompts > num of LoRAs, fill with zero for i in range(len(masks)): if masks[i] is None: @@ -374,7 +375,7 @@ def to_out_forward(self, x): x1 = x1 + lx1 out[self.network.batch_size + i] = x1 - # print("to_out_forward", x.size(), out.size(), has_real_uncond) + # logger.info(f"to_out_forward {x.size()} {out.size()} {has_real_uncond}") return out @@ -511,7 +512,7 @@ def parse_floats(s): len(block_dims) == num_total_blocks ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください" else: - print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") + logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") block_dims = [network_dim] * num_total_blocks if block_alphas is not None: @@ -520,7 +521,7 @@ def parse_floats(s): len(block_alphas) == num_total_blocks ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください" else: - print( + logger.warning( f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります" ) block_alphas = [network_alpha] * num_total_blocks @@ -540,13 +541,13 @@ def parse_floats(s): else: if conv_alpha is None: conv_alpha = 1.0 - print( + logger.warning( f"conv_block_alphas is not specified. 
all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります" ) conv_block_alphas = [conv_alpha] * num_total_blocks else: if conv_dim is not None: - print( + logger.warning( f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります" ) conv_block_dims = [conv_dim] * num_total_blocks @@ -586,7 +587,7 @@ def get_list(name_with_suffix) -> List[float]: elif name == "zeros": return [0.0 + base_lr] * max_len else: - print( + logger.error( "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros" % (name) ) @@ -598,14 +599,14 @@ def get_list(name_with_suffix) -> List[float]: up_lr_weight = get_list(up_lr_weight) if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len): - print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) - print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) + logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) + logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) up_lr_weight = up_lr_weight[:max_len] down_lr_weight = down_lr_weight[:max_len] if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len): - print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len) - print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) + logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len) + logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) if down_lr_weight != None and len(down_lr_weight) < max_len: down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight)) @@ -613,24 +614,24 @@ def get_list(name_with_suffix) -> List[float]: up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight)) if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None): - print("apply block learning rate / 階層別学習率を適用します。") + logger.info("apply block learning rate / 階層別学習率を適用します。") if down_lr_weight != None: down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight] - print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight) + logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}") else: - print("down_lr_weight: all 1.0, すべて1.0") + logger.info("down_lr_weight: all 1.0, すべて1.0") if mid_lr_weight != None: mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0 - print("mid_lr_weight:", mid_lr_weight) + logger.info(f"mid_lr_weight: {mid_lr_weight}") else: - print("mid_lr_weight: 1.0") + logger.info("mid_lr_weight: 1.0") if up_lr_weight != None: up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight] - print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight) + logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}") else: - print("up_lr_weight: all 1.0, すべて1.0") + logger.info("up_lr_weight: all 1.0, すべて1.0") return down_lr_weight, mid_lr_weight, up_lr_weight @@ -711,7 +712,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh elif "lora_down" in key: dim = value.size()[0] modules_dim[lora_name] = dim - # print(lora_name, value.size(), dim) + # 
logger.info(lora_name, value.size(), dim) # support old LoRA without alpha for key in modules_dim.keys(): @@ -786,20 +787,20 @@ def __init__( self.module_dropout = module_dropout if modules_dim is not None: - print(f"create LoRA network from weights") + logger.info(f"create LoRA network from weights") elif block_dims is not None: - print(f"create LoRA network from block_dims") - print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") - print(f"block_dims: {block_dims}") - print(f"block_alphas: {block_alphas}") + logger.info(f"create LoRA network from block_dims") + logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + logger.info(f"block_dims: {block_dims}") + logger.info(f"block_alphas: {block_alphas}") if conv_block_dims is not None: - print(f"conv_block_dims: {conv_block_dims}") - print(f"conv_block_alphas: {conv_block_alphas}") + logger.info(f"conv_block_dims: {conv_block_dims}") + logger.info(f"conv_block_alphas: {conv_block_alphas}") else: - print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") - print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") if self.conv_lora_dim is not None: - print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") # create module instances def create_modules( @@ -884,15 +885,15 @@ def create_modules( for i, text_encoder in enumerate(text_encoders): if len(text_encoders) > 1: index = i + 1 - print(f"create LoRA for Text Encoder {index}:") + logger.info(f"create LoRA for Text Encoder {index}:") else: index = None - print(f"create LoRA for Text Encoder:") + logger.info(f"create LoRA for Text Encoder:") text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) self.text_encoder_loras.extend(text_encoder_loras) skipped_te += skipped - print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE @@ -900,15 +901,15 @@ def create_modules( target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) - print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") skipped = skipped_te + skipped_un if varbose and len(skipped) > 0: - print( + logger.info( f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" ) for name in skipped: - print(f"\t{name}") + logger.info(f"\t{name}") self.up_lr_weight: List[float] = None self.down_lr_weight: List[float] = None @@ -939,12 +940,12 @@ def load_weights(self, file): def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): if 
apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -966,12 +967,12 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): apply_unet = True if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -982,7 +983,7 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] lora.merge_to(sd_for_lora, dtype, device) - print(f"weights are merged") + logger.info(f"weights are merged") # 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない def set_block_lr_weight( @@ -1128,7 +1129,7 @@ def set_current_generation(self, batch_size, num_sub_prompts, width, height, sha device = ref_weight.device def resize_add(mh, mw): - # print(mh, mw, mh * mw) + # logger.info(mh, mw, mh * mw) m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16 m = m.to(device, dtype=dtype) mask_dic[mh * mw] = m diff --git a/networks/lora_diffusers.py b/networks/lora_diffusers.py index 47d75ac4d..88f33c0e3 100644 --- a/networks/lora_diffusers.py +++ b/networks/lora_diffusers.py @@ -10,7 +10,8 @@ from tqdm import tqdm from transformers import CLIPTextModel import torch - +from library.utils import get_my_logger +logger = get_my_logger(__name__) def make_unet_conversion_map() -> Dict[str, str]: unet_conversion_map_layer = [] @@ -248,7 +249,7 @@ def create_network_from_weights( elif "lora_down" in key: dim = value.size()[0] modules_dim[lora_name] = dim - # print(lora_name, value.size(), dim) + # logger.info(f"{lora_name} {value.size()} {dim}") # support old LoRA without alpha for key in modules_dim.keys(): @@ -291,12 +292,12 @@ def __init__( super().__init__() self.multiplier = multiplier - print(f"create LoRA network from weights") + logger.info("create LoRA network from weights") # convert SDXL Stability AI's U-Net modules to Diffusers converted = self.convert_unet_modules(modules_dim, modules_alpha) if converted: - print(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)") + logger.info(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)") # create module instances def create_modules( @@ -331,7 +332,7 @@ def create_modules( lora_name = lora_name.replace(".", "_") if lora_name not in modules_dim: - # print(f"skipped {lora_name} (not found in modules_dim)") + # logger.info(f"skipped {lora_name} (not found in modules_dim)") skipped.append(lora_name) continue @@ -362,18 +363,18 @@ def create_modules( text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) self.text_encoder_loras.extend(text_encoder_loras) skipped_te += skipped - print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") if len(skipped_te) > 0: - print(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.") + logger.warning(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.") # extend U-Net target modules to include Conv2d 3x3 target_modules = 
LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 self.unet_loras: List[LoRAModule] self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) - print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") if len(skipped_un) > 0: - print(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.") + logger.warning(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.") # assertion names = set() @@ -420,11 +421,11 @@ def set_multiplier(self, multiplier): def apply_to(self, multiplier=1.0, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") for lora in self.text_encoder_loras: lora.apply_to(multiplier) if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") for lora in self.unet_loras: lora.apply_to(multiplier) @@ -433,16 +434,16 @@ def unapply_to(self): lora.unapply_to() def merge_to(self, multiplier=1.0): - print("merge LoRA weights to original weights") + logger.info("merge LoRA weights to original weights") for lora in tqdm(self.text_encoder_loras + self.unet_loras): lora.merge_to(multiplier) - print(f"weights are merged") + logger.info(f"weights are merged") def restore_from(self, multiplier=1.0): - print("restore LoRA weights from original weights") + logger.info("restore LoRA weights from original weights") for lora in tqdm(self.text_encoder_loras + self.unet_loras): lora.restore_from(multiplier) - print(f"weights are restored") + logger.info(f"weights are restored") def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): # convert SDXL Stability AI's state dict to Diffusers' based state dict @@ -463,7 +464,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): my_state_dict = self.state_dict() for key in state_dict.keys(): if state_dict[key].size() != my_state_dict[key].size(): - # print(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}") + # logger.info(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}") state_dict[key] = state_dict[key].view(my_state_dict[key].size()) return super().load_state_dict(state_dict, strict) @@ -490,7 +491,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): image_prefix = args.model_id.replace("/", "_") + "_" # load Diffusers model - print(f"load model from {args.model_id}") + logger.info(f"load model from {args.model_id}") pipe: Union[StableDiffusionPipeline, StableDiffusionXLPipeline] if args.sdxl: # use_safetensors=True does not work with 0.18.2 @@ -503,7 +504,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if args.sdxl else [pipe.text_encoder] # load LoRA weights - print(f"load LoRA weights from {args.lora_weights}") + logger.info(f"load LoRA weights from {args.lora_weights}") if os.path.splitext(args.lora_weights)[1] == ".safetensors": from safetensors.torch import load_file @@ -512,10 +513,10 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): lora_sd = torch.load(args.lora_weights) # create by LoRA weights and load weights - print(f"create LoRA network") + logger.info(f"create LoRA network") lora_network: LoRANetwork = create_network_from_weights(text_encoders, pipe.unet, lora_sd, multiplier=1.0) 
- print(f"load LoRA network weights") + logger.info(f"load LoRA network weights") lora_network.load_state_dict(lora_sd) lora_network.to(device, dtype=pipe.unet.dtype) # required to apply_to. merge_to works without this @@ -544,34 +545,34 @@ def seed_everything(seed): random.seed(seed) # create image with original weights - print(f"create image with original weights") + logger.info(f"create image with original weights") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "original.png") # apply LoRA network to the model: slower than merge_to, but can be reverted easily - print(f"apply LoRA network to the model") + logger.info(f"apply LoRA network to the model") lora_network.apply_to(multiplier=1.0) - print(f"create image with applied LoRA") + logger.info(f"create image with applied LoRA") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "applied_lora.png") # unapply LoRA network to the model - print(f"unapply LoRA network to the model") + logger.info(f"unapply LoRA network to the model") lora_network.unapply_to() - print(f"create image with unapplied LoRA") + logger.info(f"create image with unapplied LoRA") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "unapplied_lora.png") # merge LoRA network to the model: faster than apply_to, but requires back-up of original weights (or unmerge_to) - print(f"merge LoRA network to the model") + logger.info(f"merge LoRA network to the model") lora_network.merge_to(multiplier=1.0) - print(f"create image with LoRA") + logger.info(f"create image with LoRA") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "merged_lora.png") @@ -579,31 +580,31 @@ def seed_everything(seed): # restore (unmerge) LoRA weights: numerically unstable # マージされた重みを元に戻す。計算誤差のため、元の重みと完全に一致しないことがあるかもしれない # 保存したstate_dictから元の重みを復元するのが確実 - print(f"restore (unmerge) LoRA weights") + logger.info(f"restore (unmerge) LoRA weights") lora_network.restore_from(multiplier=1.0) - print(f"create image without LoRA") + logger.info(f"create image without LoRA") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "unmerged_lora.png") # restore original weights - print(f"restore original weights") + logger.info(f"restore original weights") pipe.unet.load_state_dict(org_unet_sd) pipe.text_encoder.load_state_dict(org_text_encoder_sd) if args.sdxl: pipe.text_encoder_2.load_state_dict(org_text_encoder_2_sd) - print(f"create image with restored original weights") + logger.info(f"create image with restored original weights") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "restore_original.png") # use convenience function to merge LoRA weights - print(f"merge LoRA weights with convenience function") + logger.info(f"merge LoRA weights with convenience function") merge_lora_weights(pipe, lora_sd, multiplier=1.0) - print(f"create image with merged LoRA weights") + logger.info(f"create image with merged LoRA weights") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "convenience_merged_lora.png") diff --git a/networks/lora_fa.py b/networks/lora_fa.py index a357d7f7f..141f8780a 
100644 --- a/networks/lora_fa.py +++ b/networks/lora_fa.py @@ -14,7 +14,8 @@ import numpy as np import torch import re - +from library.utils import get_my_logger +logger = get_my_logger(__name__) RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") @@ -49,7 +50,7 @@ def __init__( # if limit_rank: # self.lora_dim = min(lora_dim, in_dim, out_dim) # if self.lora_dim != lora_dim: - # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") + # logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") # else: self.lora_dim = lora_dim @@ -197,7 +198,7 @@ def merge_to(self, sd, dtype, device): else: # conv2d 3x3 conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # print(conved.size(), weight.size(), module.stride, module.padding) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) weight = weight + self.multiplier * conved * self.scale # set weight to org_module @@ -236,7 +237,7 @@ def set_region(self, region): self.region_mask = None def default_forward(self, x): - # print("default_forward", self.lora_name, x.size()) + # logger.info("default_forward", self.lora_name, x.size()) return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale def forward(self, x): @@ -278,7 +279,7 @@ def regional_forward(self, x): # apply mask for LoRA result lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale mask = self.get_mask_for_x(lx) - # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size()) + # logger.info("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size()) lx = lx * mask x = self.org_forward(x) @@ -307,7 +308,7 @@ def postp_to_q(self, x): if has_real_uncond: query[-self.network.batch_size :] = x[-self.network.batch_size :] - # print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts) + # logger.info("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts) return query def sub_prompt_forward(self, x): @@ -322,7 +323,7 @@ def sub_prompt_forward(self, x): lx = x[emb_idx :: self.network.num_sub_prompts] lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale - # print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx) + # logger.info("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx) x = self.org_forward(x) x[emb_idx :: self.network.num_sub_prompts] += lx @@ -330,7 +331,7 @@ def sub_prompt_forward(self, x): return x def to_out_forward(self, x): - # print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network) + # logger.info("to_out_forward", self.lora_name, x.size(), self.network.is_last_network) if self.network.is_last_network: masks = [None] * self.network.num_sub_prompts @@ -348,7 +349,7 @@ def to_out_forward(self, x): ) self.network.shared[self.lora_name] = (lx, masks) - # print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts) + # logger.info("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts) lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1 masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1) @@ -367,7 +368,7 @@ def to_out_forward(self, x): if has_real_uncond: out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond - # print("to_out_forward", self.lora_name, 
self.network.sub_prompt_index, self.network.num_sub_prompts) + # logger.info("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts) # for i in range(len(masks)): # if masks[i] is None: # masks[i] = torch.zeros_like(masks[-1]) @@ -389,7 +390,7 @@ def to_out_forward(self, x): x1 = x1 + lx1 out[self.network.batch_size + i] = x1 - # print("to_out_forward", x.size(), out.size(), has_real_uncond) + # logger.info("to_out_forward", x.size(), out.size(), has_real_uncond) return out @@ -526,7 +527,7 @@ def parse_floats(s): len(block_dims) == num_total_blocks ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください" else: - print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") + logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") block_dims = [network_dim] * num_total_blocks if block_alphas is not None: @@ -535,7 +536,7 @@ def parse_floats(s): len(block_alphas) == num_total_blocks ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください" else: - print( + logger.warning( f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります" ) block_alphas = [network_alpha] * num_total_blocks @@ -555,13 +556,13 @@ def parse_floats(s): else: if conv_alpha is None: conv_alpha = 1.0 - print( + logger.warning( f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります" ) conv_block_alphas = [conv_alpha] * num_total_blocks else: if conv_dim is not None: - print( + logger.warning( f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります" ) conv_block_dims = [conv_dim] * num_total_blocks @@ -601,7 +602,7 @@ def get_list(name_with_suffix) -> List[float]: elif name == "zeros": return [0.0 + base_lr] * max_len else: - print( + logger.error( "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros" % (name) ) @@ -613,14 +614,14 @@ def get_list(name_with_suffix) -> List[float]: up_lr_weight = get_list(up_lr_weight) if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len): - print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) - print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) + logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) + logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) up_lr_weight = up_lr_weight[:max_len] down_lr_weight = down_lr_weight[:max_len] if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len): - print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len) - print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) + logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." 
% max_len) + logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) if down_lr_weight != None and len(down_lr_weight) < max_len: down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight)) @@ -628,24 +629,24 @@ def get_list(name_with_suffix) -> List[float]: up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight)) if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None): - print("apply block learning rate / 階層別学習率を適用します。") + logger.info("apply block learning rate / 階層別学習率を適用します。") if down_lr_weight != None: down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight] - print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight) + logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}") else: - print("down_lr_weight: all 1.0, すべて1.0") + logger.info("down_lr_weight: all 1.0, すべて1.0") if mid_lr_weight != None: mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0 - print("mid_lr_weight:", mid_lr_weight) + logger.info(f"mid_lr_weight: {mid_lr_weight}") else: - print("mid_lr_weight: 1.0") + logger.info("mid_lr_weight: 1.0") if up_lr_weight != None: up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight] - print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight) + logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}") else: - print("up_lr_weight: all 1.0, すべて1.0") + logger.info("up_lr_weight: all 1.0, すべて1.0") return down_lr_weight, mid_lr_weight, up_lr_weight @@ -726,7 +727,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh elif "lora_down" in key: dim = value.size()[0] modules_dim[lora_name] = dim - # print(lora_name, value.size(), dim) + # logger.info(lora_name, value.size(), dim) # support old LoRA without alpha for key in modules_dim.keys(): @@ -801,20 +802,20 @@ def __init__( self.module_dropout = module_dropout if modules_dim is not None: - print(f"create LoRA network from weights") + logger.info(f"create LoRA network from weights") elif block_dims is not None: - print(f"create LoRA network from block_dims") - print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") - print(f"block_dims: {block_dims}") - print(f"block_alphas: {block_alphas}") + logger.info(f"create LoRA network from block_dims") + logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + logger.info(f"block_dims: {block_dims}") + logger.info(f"block_alphas: {block_alphas}") if conv_block_dims is not None: - print(f"conv_block_dims: {conv_block_dims}") - print(f"conv_block_alphas: {conv_block_alphas}") + logger.info(f"conv_block_dims: {conv_block_dims}") + logger.info(f"conv_block_alphas: {conv_block_alphas}") else: - print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") - print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") if self.conv_lora_dim is not None: - print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + logger.info(f"apply LoRA to Conv2d with kernel size (3,3). 
dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") # create module instances def create_modules( @@ -899,15 +900,15 @@ def create_modules( for i, text_encoder in enumerate(text_encoders): if len(text_encoders) > 1: index = i + 1 - print(f"create LoRA for Text Encoder {index}:") + logger.info(f"create LoRA for Text Encoder {index}:") else: index = None - print(f"create LoRA for Text Encoder:") + logger.info(f"create LoRA for Text Encoder:") text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) self.text_encoder_loras.extend(text_encoder_loras) skipped_te += skipped - print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE @@ -915,15 +916,15 @@ def create_modules( target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) - print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") skipped = skipped_te + skipped_un if varbose and len(skipped) > 0: - print( + logger.info( f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" ) for name in skipped: - print(f"\t{name}") + logger.info(f"\t{name}") self.up_lr_weight: List[float] = None self.down_lr_weight: List[float] = None @@ -954,12 +955,12 @@ def load_weights(self, file): def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -981,12 +982,12 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): apply_unet = True if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -997,7 +998,7 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] lora.merge_to(sd_for_lora, dtype, device) - print(f"weights are merged") + logger.info(f"weights are merged") # 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない def set_block_lr_weight( @@ -1144,7 +1145,7 @@ def set_current_generation(self, batch_size, num_sub_prompts, width, height, sha device = ref_weight.device def resize_add(mh, mw): - # print(mh, mw, mh * mw) + # logger.info(mh, mw, mh * mw) m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16 m = m.to(device, dtype=dtype) mask_dic[mh * mw] = m diff --git a/networks/lora_interrogator.py b/networks/lora_interrogator.py index 0dc066fd1..ccdcaddb9 100644 --- a/networks/lora_interrogator.py +++ b/networks/lora_interrogator.py @@ -9,6 +9,8 @@ import library.model_util as model_util import lora +from library.utils import get_my_logger +logger = get_my_logger(__name__) TOKENIZER_PATH = "openai/clip-vit-large-patch14" V2_STABLE_DIFFUSION_PATH = 
"stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う @@ -20,12 +22,12 @@ def interrogate(args): weights_dtype = torch.float16 # いろいろ準備する - print(f"loading SD model: {args.sd_model}") + logger.info(f"loading SD model: {args.sd_model}") args.pretrained_model_name_or_path = args.sd_model args.vae = None text_encoder, vae, unet, _ = train_util._load_target_model(args,weights_dtype, DEVICE) - print(f"loading LoRA: {args.model}") + logger.info(f"loading LoRA: {args.model}") network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet) # text encoder向けの重みがあるかチェックする:本当はlora側でやるのがいい @@ -35,11 +37,11 @@ def interrogate(args): has_te_weight = True break if not has_te_weight: - print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません") + logger.info("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません") return del vae - print("loading tokenizer") + logger.info("loading tokenizer") if args.v2: tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") else: @@ -53,7 +55,7 @@ def interrogate(args): # トークンをひとつひとつ当たっていく token_id_start = 0 token_id_end = max(tokenizer.all_special_ids) - print(f"interrogate tokens are: {token_id_start} to {token_id_end}") + logger.info(f"interrogate tokens are: {token_id_start} to {token_id_end}") def get_all_embeddings(text_encoder): embs = [] @@ -79,24 +81,24 @@ def get_all_embeddings(text_encoder): embs.extend(encoder_hidden_states) return torch.stack(embs) - print("get original text encoder embeddings.") + logger.info("get original text encoder embeddings.") orig_embs = get_all_embeddings(text_encoder) network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0) info = network.load_state_dict(weights_sd, strict=False) - print(f"Loading LoRA weights: {info}") + logger.info(f"Loading LoRA weights: {info}") network.to(DEVICE, dtype=weights_dtype) network.eval() del unet - print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)") - print("get text encoder embeddings with lora.") + logger.info("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)") + logger.info("get text encoder embeddings with lora.") lora_embs = get_all_embeddings(text_encoder) # 比べる:とりあえず単純に差分の絶対値で - print("comparing...") + logger.info("comparing...") diffs = {} for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))): diff = torch.mean(torch.abs(orig_emb - lora_emb)) diff --git a/networks/merge_lora.py b/networks/merge_lora.py index 71492621e..a6ae053bd 100644 --- a/networks/merge_lora.py +++ b/networks/merge_lora.py @@ -7,7 +7,8 @@ from library import sai_model_spec, train_util import library.model_util as model_util import lora - +from library.utils import get_my_logger +logger = get_my_logger(__name__) def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": @@ -61,10 +62,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): name_to_module[lora_name] = child_module for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd, _ = load_state_dict(model, merge_dtype) - 
print(f"merging...") + logger.info(f"merging...") for key in lora_sd.keys(): if "lora_down" in key: up_key = key.replace("lora_down", "lora_up") @@ -73,10 +74,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): # find original module for this lora module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" if module_name not in name_to_module: - print(f"no module found for LoRA weight: {key}") + logger.info(f"no module found for LoRA weight: {key}") continue module = name_to_module[module_name] - # print(f"apply {key} to {module}") + # logger.info(f"apply {key} to {module}") down_weight = lora_sd[key] up_weight = lora_sd[up_key] @@ -104,7 +105,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): else: # conv2d 3x3 conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # print(conved.size(), weight.size(), module.stride, module.padding) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) weight = weight + ratio * conved * scale module.weight = torch.nn.Parameter(weight) @@ -118,7 +119,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): v2 = None base_model = None for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd, lora_metadata = load_state_dict(model, merge_dtype) if lora_metadata is not None: @@ -151,10 +152,10 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): if lora_module_name not in base_alphas: base_alphas[lora_module_name] = alpha - print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") + logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") # merge - print(f"merging...") + logger.info(f"merging...") for key in lora_sd.keys(): if "alpha" in key: continue @@ -196,8 +197,8 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): merged_sd[key_down] = merged_sd[key_down][perm] merged_sd[key_up] = merged_sd[key_up][:,perm] - print("merged model") - print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") + logger.info("merged model") + logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") # check all dims are same dims_list = list(set(base_dims.values())) @@ -239,7 +240,7 @@ def str_to_dtype(p): save_dtype = merge_dtype if args.sd_model is not None: - print(f"loading SD model: {args.sd_model}") + logger.info(f"loading SD model: {args.sd_model}") text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) @@ -264,18 +265,18 @@ def str_to_dtype(p): ) if args.v2: # TODO read sai modelspec - print( + logger.info( "Cannot determine if model is for v-prediction, so save metadata as v-prediction / modelがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" ) - print(f"saving SD model to: {args.save_to}") + logger.info(f"saving SD model to: {args.save_to}") model_util.save_stable_diffusion_checkpoint( args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae ) else: state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) - print(f"calculating hashes and creating metadata...") + logger.info(f"calculating hashes and creating metadata...") model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) 
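Note: the merge loop above folds each LoRA pair into the base module weight as W <- W + ratio * (up @ down) * (alpha / rank), where rank is the first dimension of lora_down. A minimal, self-contained sketch of the linear branch (toy shapes; function and variable names here are illustrative, not the script's own API):

    import torch

    def merge_linear_lora(weight, lora_down, lora_up, alpha, ratio=1.0):
        # scale = alpha / rank; rank is the first dimension of lora_down
        scale = alpha / lora_down.size(0)
        # W <- W + ratio * (U @ D) * scale, as in the linear case of merge_to_sd_model
        return weight + ratio * (lora_up @ lora_down) * scale

    # toy example: out_dim=8, in_dim=16, rank=4
    w = torch.zeros(8, 16)
    down = torch.randn(4, 16)
    up = torch.randn(8, 4)
    merged = merge_linear_lora(w, down, up, alpha=4.0, ratio=0.7)
    print(merged.shape)  # torch.Size([8, 16])

The conv2d 1x1 and 3x3 branches follow the same rule, with the matrix product replaced by the squeezed multiplication and torch.nn.functional.conv2d shown in the hunks above.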
metadata["sshs_model_hash"] = model_hash @@ -289,12 +290,12 @@ def str_to_dtype(p): ) if v2: # TODO read sai modelspec - print( + logger.info( "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" ) metadata.update(sai_metadata) - print(f"saving model to: {args.save_to}") + logger.info(f"saving model to: {args.save_to}") save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) diff --git a/networks/merge_lora_old.py b/networks/merge_lora_old.py index ffd6b2b40..e62531301 100644 --- a/networks/merge_lora_old.py +++ b/networks/merge_lora_old.py @@ -6,7 +6,8 @@ from safetensors.torch import load_file, save_file import library.model_util as model_util import lora - +from library.utils import get_my_logger +logger = get_my_logger(__name__) def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == '.safetensors': @@ -54,10 +55,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): name_to_module[lora_name] = child_module for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd = load_state_dict(model, merge_dtype) - print(f"merging...") + logger.info(f"merging...") for key in lora_sd.keys(): if "lora_down" in key: up_key = key.replace("lora_down", "lora_up") @@ -66,10 +67,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): # find original module for this lora module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight" if module_name not in name_to_module: - print(f"no module found for LoRA weight: {key}") + logger.info(f"no module found for LoRA weight: {key}") continue module = name_to_module[module_name] - # print(f"apply {key} to {module}") + # logger.info(f"apply {key} to {module}") down_weight = lora_sd[key] up_weight = lora_sd[up_key] @@ -96,10 +97,10 @@ def merge_lora_models(models, ratios, merge_dtype): alpha = None dim = None for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd = load_state_dict(model, merge_dtype) - print(f"merging...") + logger.info(f"merging...") for key in lora_sd.keys(): if 'alpha' in key: if key in merged_sd: @@ -117,7 +118,7 @@ def merge_lora_models(models, ratios, merge_dtype): dim = lora_sd[key].size()[0] merged_sd[key] = lora_sd[key] * ratio - print(f"dim (rank): {dim}, alpha: {alpha}") + logger.info(f"dim (rank): {dim}, alpha: {alpha}") if alpha is None: alpha = dim @@ -142,19 +143,19 @@ def str_to_dtype(p): save_dtype = merge_dtype if args.sd_model is not None: - print(f"loading SD model: {args.sd_model}") + logger.info(f"loading SD model: {args.sd_model}") text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) - print(f"\nsaving SD model to: {args.save_to}") + logger.info(f"\nsaving SD model to: {args.save_to}") model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, save_dtype, vae) else: state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype) - print(f"\nsaving model to: {args.save_to}") + logger.info(f"\nsaving model to: {args.save_to}") save_to_file(args.save_to, state_dict, state_dict, save_dtype) diff --git a/networks/oft.py b/networks/oft.py index 1d088f877..68843bce9 100644 --- a/networks/oft.py +++ b/networks/oft.py @@ -8,7 +8,8 @@ import numpy as 
np import torch import re - +from library.utils import get_my_logger +logger = get_my_logger(__name__) RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") @@ -237,7 +238,7 @@ def __init__( self.dim = dim self.alpha = alpha - print( + logger.info( f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}" ) @@ -258,7 +259,7 @@ def create_modules( if is_linear or is_conv2d_1x1 or (is_conv2d and enable_conv): oft_name = prefix + "." + name + "." + child_name oft_name = oft_name.replace(".", "_") - # print(oft_name) + # logger.info(oft_name) oft = module_class( oft_name, @@ -279,7 +280,7 @@ def create_modules( target_modules += OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules) - print(f"create OFT for U-Net: {len(self.unet_ofts)} modules.") + logger.info(f"create OFT for U-Net: {len(self.unet_ofts)} modules.") # assertion names = set() @@ -316,7 +317,7 @@ def is_mergeable(self): # TODO refactor to common function with apply_to def merge_to(self, text_encoder, unet, weights_sd, dtype, device): - print("enable OFT for U-Net") + logger.info("enable OFT for U-Net") for oft in self.unet_ofts: sd_for_lora = {} @@ -326,7 +327,7 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): oft.load_state_dict(sd_for_lora, False) oft.merge_to() - print(f"weights are merged") + logger.info(f"weights are merged") # 二つのText Encoderに別々の学習率を設定できるようにするといいかも def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): @@ -338,11 +339,11 @@ def enumerate_params(ofts): for oft in ofts: params.extend(oft.parameters()) - # print num of params + # logger.info num of params num_params = 0 for p in params: num_params += p.numel() - print(f"OFT params: {num_params}") + logger.info(f"OFT params: {num_params}") return params param_data = {"params": enumerate_params(self.unet_ofts)} diff --git a/networks/resize_lora.py b/networks/resize_lora.py index 03fc545e7..60cdc008f 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -8,6 +8,8 @@ from tqdm import tqdm from library import train_util, model_util import numpy as np +from library.utils import get_my_logger +logger = get_my_logger(__name__) MIN_SV = 1e-6 @@ -206,7 +208,7 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn scale = network_alpha/network_dim if dynamic_method: - print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}") + logger.info(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}") lora_down_weight = None lora_up_weight = None @@ -275,10 +277,10 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn del param_dict if verbose: - print(verbose_str) + logger.info(verbose_str) - print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}") - print("resizing complete") + logger.info(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}") + logger.info("resizing complete") return o_lora_sd, network_dim, new_alpha @@ -304,10 +306,10 @@ def str_to_dtype(p): if save_dtype is None: save_dtype = merge_dtype - print("loading Model...") + logger.info("loading Model...") lora_sd, metadata = load_state_dict(args.model, merge_dtype) - print("Resizing Lora...") + logger.info("Resizing 
Lora...") state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose) # update metadata @@ -329,7 +331,7 @@ def str_to_dtype(p): metadata["sshs_model_hash"] = model_hash metadata["sshs_legacy_hash"] = legacy_hash - print(f"saving model to: {args.save_to}") + logger.info(f"saving model to: {args.save_to}") save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py index c513eb59f..f4960ed8b 100644 --- a/networks/sdxl_merge_lora.py +++ b/networks/sdxl_merge_lora.py @@ -8,7 +8,8 @@ from library import sai_model_spec, sdxl_model_util, train_util import library.model_util as model_util import lora - +from library.utils import get_my_logger +logger = get_my_logger(__name__) def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": @@ -66,10 +67,10 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ name_to_module[lora_name] = child_module for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd, _ = load_state_dict(model, merge_dtype) - print(f"merging...") + logger.info(f"merging...") for key in tqdm(lora_sd.keys()): if "lora_down" in key: up_key = key.replace("lora_down", "lora_up") @@ -78,10 +79,10 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ # find original module for this lora module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" if module_name not in name_to_module: - print(f"no module found for LoRA weight: {key}") + logger.info(f"no module found for LoRA weight: {key}") continue module = name_to_module[module_name] - # print(f"apply {key} to {module}") + # logger.info(f"apply {key} to {module}") down_weight = lora_sd[key] up_weight = lora_sd[up_key] @@ -92,7 +93,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ # W <- W + U * D weight = module.weight - # print(module_name, down_weight.size(), up_weight.size()) + # logger.info(module_name, down_weight.size(), up_weight.size()) if len(weight.size()) == 2: # linear weight = weight + ratio * (up_weight @ down_weight) * scale @@ -107,7 +108,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ else: # conv2d 3x3 conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # print(conved.size(), weight.size(), module.stride, module.padding) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) weight = weight + ratio * conved * scale module.weight = torch.nn.Parameter(weight) @@ -121,7 +122,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): v2 = None base_model = None for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd, lora_metadata = load_state_dict(model, merge_dtype) if lora_metadata is not None: @@ -154,10 +155,10 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): if lora_module_name not in base_alphas: base_alphas[lora_module_name] = alpha - print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") + logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") # merge - print(f"merging...") + logger.info(f"merging...") for key in tqdm(lora_sd.keys()): if "alpha" in key: 
continue @@ -200,8 +201,8 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): merged_sd[key_down] = merged_sd[key_down][perm] merged_sd[key_up] = merged_sd[key_up][:,perm] - print("merged model") - print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") + logger.info("merged model") + logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") # check all dims are same dims_list = list(set(base_dims.values())) @@ -243,7 +244,7 @@ def str_to_dtype(p): save_dtype = merge_dtype if args.sd_model is not None: - print(f"loading SD model: {args.sd_model}") + logger.info(f"loading SD model: {args.sd_model}") ( text_model1, @@ -265,14 +266,14 @@ def str_to_dtype(p): None, False, False, True, False, False, time.time(), title=title, merged_from=merged_from ) - print(f"saving SD model to: {args.save_to}") + logger.info(f"saving SD model to: {args.save_to}") sdxl_model_util.save_stable_diffusion_checkpoint( args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype ) else: state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) - print(f"calculating hashes and creating metadata...") + logger.info(f"calculating hashes and creating metadata...") model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) metadata["sshs_model_hash"] = model_hash @@ -286,7 +287,7 @@ def str_to_dtype(p): ) metadata.update(sai_metadata) - print(f"saving model to: {args.save_to}") + logger.info(f"saving model to: {args.save_to}") save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index 16e813b36..5d32bf0a2 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -8,7 +8,8 @@ from library import sai_model_spec, train_util import library.model_util as model_util import lora - +from library.utils import get_my_logger +logger = get_my_logger(__name__) CLAMP_QUANTILE = 0.99 @@ -41,12 +42,12 @@ def save_to_file(file_name, state_dict, dtype, metadata): def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype): - print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") + logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") merged_sd = {} v2 = None base_model = None for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd, lora_metadata = load_state_dict(model, merge_dtype) if lora_metadata is not None: @@ -56,7 +57,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) # merge - print(f"merging...") + logger.info(f"merging...") for key in tqdm(list(lora_sd.keys())): if "lora_down" not in key: continue @@ -73,7 +74,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty out_dim = up_weight.size()[0] conv2d = len(down_weight.size()) == 4 kernel_size = None if not conv2d else down_weight.size()[2:4] - # print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size) + # logger.info(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size) # make original weight if not exist if lora_module_name not in merged_sd: @@ -110,7 +111,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty merged_sd[lora_module_name] = 
weight # extract from merged weights - print("extract new lora...") + logger.info("extract new lora...") merged_lora_sd = {} with torch.no_grad(): for lora_module_name, mat in tqdm(list(merged_sd.items())): @@ -188,7 +189,7 @@ def str_to_dtype(p): args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype ) - print(f"calculating hashes and creating metadata...") + logger.info(f"calculating hashes and creating metadata...") model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) metadata["sshs_model_hash"] = model_hash @@ -203,12 +204,12 @@ def str_to_dtype(p): ) if v2: # TODO read sai modelspec - print( + logger.info( "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" ) metadata.update(sai_metadata) - print(f"saving model to: {args.save_to}") + logger.info(f"saving model to: {args.save_to}") save_to_file(args.save_to, state_dict, save_dtype, metadata) diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index c31ae0072..8569d9076 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -60,6 +60,8 @@ from library.sdxl_original_unet import SdxlUNet2DConditionModel from library.original_unet import FlashAttentionFunction from networks.control_net_lllite import ControlNetLLLite +from library.utils import get_my_logger +logger = get_my_logger(__name__) # scheduler: SCHEDULER_LINEAR_START = 0.00085 @@ -81,12 +83,12 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): if mem_eff_attn: - print("Enable memory efficient attention for U-Net") + logger.info("Enable memory efficient attention for U-Net") # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い unet.set_use_memory_efficient_attention(False, True) elif xformers: - print("Enable xformers for U-Net") + logger.info("Enable xformers for U-Net") try: import xformers.ops except ImportError: @@ -94,7 +96,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio unet.set_use_memory_efficient_attention(True, False) elif sdpa: - print("Enable SDPA for U-Net") + logger.info("Enable SDPA for U-Net") unet.set_use_memory_efficient_attention(False, False) unet.set_use_sdpa(True) @@ -111,7 +113,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform def replace_vae_attn_to_memory_efficient(): - print("VAE Attention.forward has been replaced to FlashAttention (not xformers)") + logger.info("VAE Attention.forward has been replaced to FlashAttention (not xformers)") flash_func = FlashAttentionFunction def forward_flash_attn(self, hidden_states, **kwargs): @@ -167,7 +169,7 @@ def forward_flash_attn_0_14(self, hidden_states, **kwargs): def replace_vae_attn_to_xformers(): - print("VAE: Attention.forward has been replaced to xformers") + logger.info("VAE: Attention.forward has been replaced to xformers") import xformers.ops def forward_xformers(self, hidden_states, **kwargs): @@ -223,7 +225,7 @@ def forward_xformers_0_14(self, hidden_states, **kwargs): def replace_vae_attn_to_sdpa(): - print("VAE: Attention.forward has been replaced to sdpa") + logger.info("VAE: Attention.forward has been replaced to sdpa") def forward_sdpa(self, hidden_states, **kwargs): residual = hidden_states @@ -357,7 +359,7 @@ def get_token_replacer(self, tokenizer): token_replacements = self.token_replacements_list[tokenizer_index] def replace_tokens(tokens): - # print("replace_tokens", tokens, "=>", token_replacements) + # 
logger.info("replace_tokens", tokens, "=>", token_replacements)
             if isinstance(tokens, torch.Tensor):
                 tokens = tokens.tolist()
 
@@ -449,7 +451,7 @@ def __call__(
         do_classifier_free_guidance = guidance_scale > 1.0
 
         if not do_classifier_free_guidance and negative_scale is not None:
-            print(f"negative_scale is ignored if guidance scalle <= 1.0")
+            logger.info(f"negative_scale is ignored if guidance scale <= 1.0")
             negative_scale = None
 
         # get unconditional embeddings for classifier free guidance
@@ -552,7 +554,7 @@
                     text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts]  # last subprompt
 
         if init_image is not None and self.clip_vision_model is not None:
-            print(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}")
+            logger.info(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}")
             vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device)
             pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype)
 
@@ -717,7 +719,7 @@ def __call__(
                         if not enabled or ratio >= 1.0:
                             continue
                         if ratio < i / len(timesteps):
-                            print(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})")
+                            logger.info(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})")
                             control_net.set_cond_image(None)
                             each_control_net_enabled[j] = False
 
@@ -937,7 +939,7 @@ def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: L
         if word.strip() == "BREAK":
             # pad until next multiple of tokenizer's max token length
             pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length)
-            print(f"BREAK pad_len: {pad_len}")
+            logger.info(f"BREAK pad_len: {pad_len}")
 
             for i in range(pad_len):
                 # v2のときEOSをつけるべきかどうかわからないぜ
                 # if i == 0:
@@ -967,7 +969,7 @@
     tokens.append(text_token)
     weights.append(text_weight)
     if truncated:
-        print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
+        logger.warning("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
 
     return tokens, weights
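Note: the BREAK handling above pads the accumulated token list up to the next multiple of the tokenizer's window so that the following chunk starts on a clean boundary. A minimal sketch of that arithmetic, assuming a 77-token window and a pad id of 0 (hypothetical numbers, not taken from a real run):

    model_max_length = 77          # CLIP tokenizer window size
    text_token = list(range(30))   # 30 token ids accumulated so far
    pad_len = model_max_length - (len(text_token) % model_max_length)  # 47
    text_token += [0] * pad_len    # pad_token_id assumed to be 0 here
    assert len(text_token) % model_max_length == 0  # next chunk starts clean
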
@@ -1240,7 +1242,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count):
         elif len(count_range) == 2:
             count_range = [int(count_range[0]), int(count_range[1])]
         else:
-            print(f"invalid count range: {count_range}")
+            logger.warning(f"invalid count range: {count_range}")
             count_range = [1, 1]
         if count_range[0] > count_range[1]:
             count_range = [count_range[1], count_range[0]]
@@ -1310,7 +1312,7 @@ def replacer():
 
 
 # def load_clip_l14_336(dtype):
-#     print(f"loading CLIP: {CLIP_ID_L14_336}")
+#     logger.info(f"loading CLIP: {CLIP_ID_L14_336}")
 #     text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype)
 #     return text_encoder
@@ -1379,7 +1381,7 @@ def main(args):
     replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa)
 
     # tokenizerを読み込む
-    print("loading tokenizer")
+    logger.info("loading tokenizer")
     tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
 
     # schedulerを用意する
@@ -1453,7 +1455,7 @@ def reset_sampler_noises(self, noises):
            self.sampler_noises = noises
 
        def randn(self, shape, device=None, dtype=None, layout=None, generator=None):
-            # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index)
+            # logger.info("replacing", shape, len(self.sampler_noises), self.sampler_noise_index)
            if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises):
                noise = self.sampler_noises[self.sampler_noise_index]
                if shape != noise.shape:
@@ -1462,7 +1464,7 @@ def randn(self, shape, device=None, dtype=None, layout=None, generator=None):
                noise = None
 
            if noise == None:
-                print(f"unexpected noise request: {self.sampler_noise_index}, {shape}")
+                logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}")
                noise = torch.randn(shape, dtype=dtype, device=device, generator=generator)
 
            self.sampler_noise_index += 1
@@ -1494,7 +1496,7 @@ def __getattr__(self, item):
    # ↓以下は結局PipeでFalseに設定されるので意味がなかった
    # # clip_sample=Trueにする
    # if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
-    #     print("set clip_sample to True")
+    #     logger.info("set clip_sample to True")
    #     scheduler.config.clip_sample = True
 
    # deviceを決定する
@@ -1523,7 +1525,7 @@ def __getattr__(self, item):
    vae_dtype = dtype
    if args.no_half_vae:
-        print("set vae_dtype to float32")
+        logger.info("set vae_dtype to float32")
        vae_dtype = torch.float32
    vae.to(vae_dtype).to(device)
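Note: the randn wrapper patched above replays a list of pre-generated sampler noises in order, falling back to a fresh torch.randn when the list is exhausted or a request's shape does not match. A standalone sketch of the same pattern (class name and shapes are illustrative, not the script's own):

    import torch

    class NoiseReplayer:
        def __init__(self, noises):
            self.noises = noises  # list of pre-generated tensors, or None
            self.index = 0

        def randn(self, shape, device=None, dtype=None, generator=None):
            noise = None
            if self.noises is not None and self.index < len(self.noises):
                noise = self.noises[self.index]
                if noise.shape != torch.Size(shape):
                    noise = None  # unexpected request: regenerate below
            if noise is None:
                noise = torch.randn(shape, dtype=dtype, device=device, generator=generator)
            self.index += 1
            return noise

    replayer = NoiseReplayer([torch.zeros(1, 4, 8, 8)])
    a = replayer.randn((1, 4, 8, 8))  # replayed: the prepared zeros tensor
    b = replayer.randn((1, 4, 8, 8))  # fresh torch.randn, list is exhausted
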
@@ -1544,10 +1546,10 @@ def __getattr__(self, item):
        network_merge = args.network_merge_n_models
    else:
        network_merge = 0
-    print(f"network_merge: {network_merge}")
+    logger.info(f"network_merge: {network_merge}")
 
    for i, network_module in enumerate(args.network_module):
-        print("import network module:", network_module)
+        logger.info(f"import network module: {network_module}")
        imported_module = importlib.import_module(network_module)
 
        network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
@@ -1565,7 +1567,7 @@
                raise ValueError("No weight. Weight is required.")
 
            network_weight = args.network_weights[i]
-            print("load network weights from:", network_weight)
+            logger.info(f"load network weights from: {network_weight}")
 
            if model_util.is_safetensors(network_weight) and args.network_show_meta:
                from safetensors.torch import safe_open
@@ -1573,7 +1575,7 @@
                with safe_open(network_weight, framework="pt") as f:
                    metadata = f.metadata()
                if metadata is not None:
-                    print(f"metadata for: {network_weight}: {metadata}")
+                    logger.info(f"metadata for: {network_weight}: {metadata}")
 
            network, weights_sd = imported_module.create_network_from_weights(
                network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs
@@ -1583,20 +1585,20 @@
            mergeable = network.is_mergeable()
            if network_merge and not mergeable:
-                print("network is not mergiable. ignore merge option.")
+                logger.warning("network is not mergeable. ignore merge option.")
 
            if not mergeable or i >= network_merge:
                # not merging
                network.apply_to([text_encoder1, text_encoder2], unet)
                info = network.load_state_dict(weights_sd, False)  # network.load_weightsを使うようにするとよい
-                print(f"weights are loaded: {info}")
+                logger.info(f"weights are loaded: {info}")
 
                if args.opt_channels_last:
                    network.to(memory_format=torch.channels_last)
                network.to(dtype).to(device)
 
                if network_pre_calc:
-                    print("backup original weights")
+                    logger.info("backup original weights")
                    network.backup_weights()
 
                networks.append(network)
@@ -1610,7 +1612,7 @@
    # upscalerの指定があれば取得する
    upscaler = None
    if args.highres_fix_upscaler:
-        print("import upscaler module:", args.highres_fix_upscaler)
+        logger.info(f"import upscaler module: {args.highres_fix_upscaler}")
        imported_module = importlib.import_module(args.highres_fix_upscaler)
 
        us_kwargs = {}
@@ -1619,7 +1621,7 @@
            key, value = net_arg.split("=")
            us_kwargs[key] = value
 
-        print("create upscaler")
+        logger.info("create upscaler")
        upscaler = imported_module.create_upscaler(**us_kwargs)
        upscaler.to(dtype).to(device)
 
@@ -1636,7 +1638,7 @@
    #         control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
    if args.control_net_lllite_models:
        for i, model_file in enumerate(args.control_net_lllite_models):
-            print(f"loading ControlNet-LLLite: {model_file}")
+            logger.info(f"loading ControlNet-LLLite: {model_file}")
 
            from safetensors.torch import load_file
 
@@ -1667,7 +1669,7 @@
            control_nets.append((control_net, ratio))
 
    if args.opt_channels_last:
-        print(f"set optimizing: channels last")
+        logger.info(f"set optimizing: channels last")
        text_encoder1.to(memory_format=torch.channels_last)
        text_encoder2.to(memory_format=torch.channels_last)
        vae.to(memory_format=torch.channels_last)
@@ -1691,7 +1693,7 @@
        args.clip_skip,
    )
    pipe.set_control_nets(control_nets)
-    print("pipeline is ready.")
+    logger.info("pipeline is ready.")
 
    if args.diffusers_xformers:
        pipe.enable_xformers_memory_efficient_attention()
@@ -1729,7 +1731,7 @@
 
        token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings)
        token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings)
-        print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}")
+        logger.info(f"Textual Inversion embeddings `{token_string}` loaded. 
Tokens are added: {token_ids1} and {token_ids2}") assert ( min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1 ), f"token ids1 is not ordered" @@ -1759,7 +1761,7 @@ def __getattr__(self, item): # promptを取得する if args.from_file is not None: - print(f"reading prompts from {args.from_file}") + logger.info(f"reading prompts from {args.from_file}") with open(args.from_file, "r", encoding="utf-8") as f: prompt_list = f.read().splitlines() prompt_list = [d for d in prompt_list if len(d.strip()) > 0] @@ -1788,7 +1790,7 @@ def load_images(path): for p in paths: image = Image.open(p) if image.mode != "RGB": - print(f"convert image to RGB from {image.mode}: {p}") + logger.info(f"convert image to RGB from {image.mode}: {p}") image = image.convert("RGB") images.append(image) @@ -1804,14 +1806,14 @@ def resize_images(imgs, size): return resized if args.image_path is not None: - print(f"load image for img2img: {args.image_path}") + logger.info(f"load image for img2img: {args.image_path}") init_images = load_images(args.image_path) assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" - print(f"loaded {len(init_images)} images for img2img") + logger.info(f"loaded {len(init_images)} images for img2img") # CLIP Vision if args.clip_vision_strength is not None: - print(f"load CLIP Vision model: {CLIP_VISION_MODEL}") + logger.info(f"load CLIP Vision model: {CLIP_VISION_MODEL}") vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280) vision_model.to(device, dtype) processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL) @@ -1819,22 +1821,22 @@ def resize_images(imgs, size): pipe.clip_vision_model = vision_model pipe.clip_vision_processor = processor pipe.clip_vision_strength = args.clip_vision_strength - print(f"CLIP Vision model loaded.") + logger.info(f"CLIP Vision model loaded.") else: init_images = None if args.mask_path is not None: - print(f"load mask for inpainting: {args.mask_path}") + logger.info(f"load mask for inpainting: {args.mask_path}") mask_images = load_images(args.mask_path) assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" - print(f"loaded {len(mask_images)} mask images for inpainting") + logger.info(f"loaded {len(mask_images)} mask images for inpainting") else: mask_images = None # promptがないとき、画像のPngInfoから取得する if init_images is not None and len(prompt_list) == 0 and not args.interactive: - print("get prompts from images' metadata") + logger.info("get prompts from images' metadata") for img in init_images: if "prompt" in img.text: prompt = img.text["prompt"] @@ -1863,17 +1865,17 @@ def resize_images(imgs, size): h = int(h * args.highres_fix_scale + 0.5) if init_images is not None: - print(f"resize img2img source images to {w}*{h}") + logger.info(f"resize img2img source images to {w}*{h}") init_images = resize_images(init_images, (w, h)) if mask_images is not None: - print(f"resize img2img mask images to {w}*{h}") + logger.info(f"resize img2img mask images to {w}*{h}") mask_images = resize_images(mask_images, (w, h)) regional_network = False if networks and mask_images: # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応 regional_network = True - print("use mask as region") + logger.info("use mask as region") size = None for i, network in enumerate(networks): @@ -1898,14 +1900,14 @@ def resize_images(imgs, size): prev_image = None # for VGG16 guided if args.guide_image_path is not None: - print(f"load image for ControlNet guidance: {args.guide_image_path}") + 
logger.info(f"load image for ControlNet guidance: {args.guide_image_path}") guide_images = [] for p in args.guide_image_path: guide_images.extend(load_images(p)) - print(f"loaded {len(guide_images)} guide images for guidance") + logger.info(f"loaded {len(guide_images)} guide images for guidance") if len(guide_images) == 0: - print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}") + logger.warning(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}") guide_images = None else: guide_images = None @@ -1931,7 +1933,7 @@ def resize_images(imgs, size): max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples for gen_iter in range(args.n_iter): - print(f"iteration {gen_iter+1}/{args.n_iter}") + logger.info(f"iteration {gen_iter+1}/{args.n_iter}") iter_seed = random.randint(0, 0x7FFFFFFF) # バッチ処理の関数 @@ -1943,7 +1945,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling - print("process 1st stage") + logger.info("process 1st stage") batch_1st = [] for _, base, ext in batch: @@ -1988,7 +1990,7 @@ def scale_and_round(x): images_1st = process_batch(batch_1st, True, True) # 2nd stageのバッチを作成して以下処理する - print("process 2nd stage") + logger.info("process 2nd stage") width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height if upscaler: @@ -2154,7 +2156,7 @@ def scale_and_round(x): n.restore_weights() for n in networks: n.pre_calculation() - print("pre-calculation... done") + logger.info("pre-calculation... done") images = pipe( prompts, @@ -2233,7 +2235,7 @@ def scale_and_round(x): cv2.waitKey() cv2.destroyAllWindows() except ImportError: - print("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません") + logger.error("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません") return images @@ -2246,7 +2248,7 @@ def scale_and_round(x): # interactive valid = False while not valid: - print("\nType prompt:") + logger.info("\nType prompt:") try: raw_prompt = input() except EOFError: @@ -2288,74 +2290,74 @@ def scale_and_round(x): prompt_args = raw_prompt.strip().split(" --") prompt = prompt_args[0] - print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") for parg in prompt_args[1:]: try: m = re.match(r"w (\d+)", parg, re.IGNORECASE) if m: width = int(m.group(1)) - print(f"width: {width}") + logger.info(f"width: {width}") continue m = re.match(r"h (\d+)", parg, re.IGNORECASE) if m: height = int(m.group(1)) - print(f"height: {height}") + logger.info(f"height: {height}") continue m = re.match(r"ow (\d+)", parg, re.IGNORECASE) if m: original_width = int(m.group(1)) - print(f"original width: {original_width}") + logger.info(f"original width: {original_width}") continue m = re.match(r"oh (\d+)", parg, re.IGNORECASE) if m: original_height = int(m.group(1)) - print(f"original height: {original_height}") + logger.info(f"original height: {original_height}") continue m = re.match(r"nw (\d+)", parg, re.IGNORECASE) if m: original_width_negative = int(m.group(1)) - print(f"original width negative: {original_width_negative}") + logger.info(f"original width negative: {original_width_negative}") continue m = re.match(r"nh (\d+)", parg, re.IGNORECASE) if m: original_height_negative = 
int(m.group(1)) - print(f"original height negative: {original_height_negative}") + logger.info(f"original height negative: {original_height_negative}") continue m = re.match(r"ct (\d+)", parg, re.IGNORECASE) if m: crop_top = int(m.group(1)) - print(f"crop top: {crop_top}") + logger.info(f"crop top: {crop_top}") continue m = re.match(r"cl (\d+)", parg, re.IGNORECASE) if m: crop_left = int(m.group(1)) - print(f"crop left: {crop_left}") + logger.info(f"crop left: {crop_left}") continue m = re.match(r"s (\d+)", parg, re.IGNORECASE) if m: # steps steps = max(1, min(1000, int(m.group(1)))) - print(f"steps: {steps}") + logger.info(f"steps: {steps}") continue m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) if m: # seed seeds = [int(d) for d in m.group(1).split(",")] - print(f"seeds: {seeds}") + logger.info(f"seeds: {seeds}") continue m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) if m: # scale scale = float(m.group(1)) - print(f"scale: {scale}") + logger.info(f"scale: {scale}") continue m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) @@ -2364,25 +2366,25 @@ def scale_and_round(x): negative_scale = None else: negative_scale = float(m.group(1)) - print(f"negative scale: {negative_scale}") + logger.info(f"negative scale: {negative_scale}") continue m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) if m: # strength strength = float(m.group(1)) - print(f"strength: {strength}") + logger.info(f"strength: {strength}") continue m = re.match(r"n (.+)", parg, re.IGNORECASE) if m: # negative prompt negative_prompt = m.group(1) - print(f"negative prompt: {negative_prompt}") + logger.info(f"negative prompt: {negative_prompt}") continue m = re.match(r"c (.+)", parg, re.IGNORECASE) if m: # clip prompt clip_prompt = m.group(1) - print(f"clip prompt: {clip_prompt}") + logger.info(f"clip prompt: {clip_prompt}") continue m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) @@ -2390,12 +2392,12 @@ def scale_and_round(x): network_muls = [float(v) for v in m.group(1).split(",")] while len(network_muls) < len(networks): network_muls.append(network_muls[-1]) - print(f"network mul: {network_muls}") + logger.info(f"network mul: {network_muls}") continue except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) + logger.error(f"Exception in parsing / 解析エラー: {parg}") + logger.error(f"{ex}") # prepare seed if seeds is not None: # given in prompt @@ -2407,7 +2409,7 @@ def scale_and_round(x): if len(predefined_seeds) > 0: seed = predefined_seeds.pop(0) else: - print("predefined seeds are exhausted") + logger.error("predefined seeds are exhausted") seed = None elif args.iter_same_seed: seeds = iter_seed @@ -2417,7 +2419,7 @@ def scale_and_round(x): if seed is None: seed = random.randint(0, 0x7FFFFFFF) if args.interactive: - print(f"seed: {seed}") + logger.info(f"seed: {seed}") # prepare init image, guide image and mask init_image = mask_image = guide_image = None @@ -2433,7 +2435,7 @@ def scale_and_round(x): width = width - width % 32 height = height - height % 32 if width != init_image.size[0] or height != init_image.size[1]: - print( + logger.warning( f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" ) @@ -2493,7 +2495,7 @@ def scale_and_round(x): process_batch(batch_data, highres_fix) batch_data.clear() - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py index 45b9edd65..13a5be402 100644 --- 
a/sdxl_minimal_inference.py +++ b/sdxl_minimal_inference.py @@ -25,6 +25,8 @@ from library import model_util, sdxl_model_util import networks.lora as lora +from library.utils import get_my_logger +logger = get_my_logger(__name__) # scheduler: このあたりの設定はSD1/2と同じでいいらしい # scheduler: The settings around here seem to be the same as SD1/2 @@ -142,7 +144,7 @@ def get_timestep_embedding(x, outdim): vae_dtype = DTYPE if DTYPE == torch.float16: - print("use float32 for vae") + logger.info("use float32 for vae") vae_dtype = torch.float32 vae.to(DEVICE, dtype=vae_dtype) vae.eval() @@ -189,7 +191,7 @@ def generate_image(prompt, prompt2, negative_prompt, seed=None): emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256) emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256) emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256) - # print("emb1", emb1.shape) + # logger.info("emb1", emb1.shape) c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE) uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE) # ちょっとここ正しいかどうかわからない I'm not sure if this is right @@ -219,7 +221,7 @@ def call_text_encoder(text, text2): enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True) text_embedding2_penu = enc_out["hidden_states"][-2] - # print("hidden_states2", text_embedding2_penu.shape) + # logger.info("hidden_states2", text_embedding2_penu.shape) text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion # 連結して終了 concat and finish @@ -228,7 +230,7 @@ def call_text_encoder(text, text2): # cond c_ctx, c_ctx_pool = call_text_encoder(prompt, prompt2) - # print(c_ctx.shape, c_ctx_p.shape, c_vector.shape) + # logger.info(c_ctx.shape, c_ctx_p.shape, c_vector.shape) c_vector = torch.cat([c_ctx_pool, c_vector], dim=1) # uncond @@ -325,4 +327,4 @@ def call_text_encoder(text, text2): seed = int(seed) generate_image(prompt, prompt2, negative_prompt, seed) - print("Done!") + logger.info("Done!") diff --git a/sdxl_train.py b/sdxl_train.py index f067acd59..9764a7c46 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -22,6 +22,8 @@ from library import sdxl_model_util import library.train_util as train_util +from library.utils import get_my_logger +logger = get_my_logger(__name__) import library.config_util as config_util import library.sdxl_train_util as sdxl_train_util from library.config_util import ( @@ -130,18 +132,18 @@ def train(args): if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) else: if use_dreambooth_method: - print("Using DreamBooth method.") + logger.info("Using DreamBooth method.") user_config = { "datasets": [ { @@ -152,7 +154,7 @@ def train(args): ] } else: - print("Training with captions.") + logger.info("Training with captions.") user_config = { "datasets": [ { @@ -182,7 +184,7 @@ def train(args): train_util.debug_dataset(train_dataset_group, True) return if len(train_dataset_group) == 0: - print( + 
logger.error( "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" ) return @@ -198,7 +200,7 @@ def train(args): ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする @@ -523,7 +525,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # print("text encoder outputs verified") + # logger.info("text encoder outputs verified") # get size embeddings orig_size = batch["original_sizes_hw"] @@ -711,7 +713,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): logit_scale, ckpt_info, ) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 54abf697c..43d325f14 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -47,7 +47,8 @@ apply_debiased_estimation, ) import networks.control_net_lllite_for_train as control_net_lllite_for_train - +from library.utils import get_my_logger +logger = get_my_logger(__name__) # TODO 他のスクリプトと共通化する def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): @@ -80,11 +81,11 @@ def train(args): # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) if use_user_config: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "conditioning_data_dir"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) @@ -116,7 +117,7 @@ def train(args): train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: - print( + logger.error( "No data found. 
Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" ) return @@ -126,7 +127,7 @@ def train(args): train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" else: - print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません") + logger.warning("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません") if args.cache_text_encoder_outputs: assert ( @@ -134,7 +135,7 @@ def train(args): ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) is_main_process = accelerator.is_main_process @@ -233,8 +234,8 @@ def train(args): accelerator.print("prepare optimizer, data loader etc.") trainable_params = list(unet.prepare_params()) - print(f"trainable params count: {len(trainable_params)}") - print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") + logger.info(f"trainable params count: {len(trainable_params)}") + logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") _, _, optimizer = train_util.get_optimizer(args, trainable_params) @@ -329,7 +330,7 @@ def train(args): accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") - # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") @@ -551,7 +552,7 @@ def remove_model(old_ckpt_name): ckpt_name = train_util.get_last_ckpt_name(args, "." 
+ args.save_model_as) save_model(ckpt_name, unet, global_step, num_train_epochs, force_sync_upload=True) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index f00f10eaa..83d6def4e 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -43,7 +43,8 @@ apply_debiased_estimation, ) import networks.control_net_lllite as control_net_lllite - +from library.utils import get_my_logger +logger = get_my_logger(__name__) # TODO 他のスクリプトと共通化する def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): @@ -76,11 +77,11 @@ def train(args): # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) if use_user_config: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "conditioning_data_dir"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) @@ -112,7 +113,7 @@ def train(args): train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: - print( + logger.error( "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" ) return @@ -122,7 +123,7 @@ def train(args): train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" else: - print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません") + logger.warning("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません") if args.cache_text_encoder_outputs: assert ( @@ -130,7 +131,7 @@ def train(args): ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) is_main_process = accelerator.is_main_process @@ -201,8 +202,8 @@ def train(args): accelerator.print("prepare optimizer, data loader etc.") trainable_params = list(network.prepare_optimizer_params()) - print(f"trainable params count: {len(trainable_params)}") - print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") + logger.info(f"trainable params count: {len(trainable_params)}") + logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") _, _, optimizer = train_util.get_optimizer(args, trainable_params) @@ -302,7 +303,7 @@ def train(args): accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") - # print(f" total train batch size 
(with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") @@ -521,7 +522,7 @@ def remove_model(old_ckpt_name): ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 199c4e032..34a6401e0 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -9,7 +9,8 @@ pass from library import sdxl_model_util, sdxl_train_util, train_util import train_network - +from library.utils import get_my_logger +logger = get_my_logger(__name__) class SdxlNetworkTrainer(train_network.NetworkTrainer): def __init__(self): @@ -62,7 +63,7 @@ def cache_text_encoder_outputs_if_needed( if args.cache_text_encoder_outputs: if not args.lowram: # メモリ消費を減らす - print("move vae and unet to cpu to save memory") + logger.info("move vae and unet to cpu to save memory") org_vae_device = vae.device org_unet_device = unet.device vae.to("cpu") @@ -87,7 +88,7 @@ def cache_text_encoder_outputs_if_needed( torch.cuda.empty_cache() if not args.lowram: - print("move vae and unet back to original device") + logger.info("move vae and unet back to original device") vae.to(org_vae_device) unet.to(org_unet_device) else: @@ -144,7 +145,7 @@ def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, wei # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # print("text encoder outputs verified") + # logger.info("text encoder outputs verified") return encoder_hidden_states1, encoder_hidden_states2, pool2 diff --git a/tools/cache_latents.py b/tools/cache_latents.py index 17916ef70..b2faf5b5b 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -16,7 +16,8 @@ ConfigSanitizer, BlueprintGenerator, ) - +from library.utils import get_my_logger +logger = get_my_logger(__name__) def cache_to_disk(args: argparse.Namespace) -> None: train_util.prepare_dataset_args(args, True) @@ -41,18 +42,18 @@ def cache_to_disk(args: argparse.Namespace) -> None: if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) else: if use_dreambooth_method: - print("Using DreamBooth method.") + logger.info("Using DreamBooth method.") user_config = { "datasets": [ { @@ -63,7 +64,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: 
] } else: - print("Training with captions.") + logger.info("Training with captions.") user_config = { "datasets": [ { @@ -90,7 +91,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする @@ -98,7 +99,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: vae_dtype = torch.float32 if args.no_half_vae else weight_dtype # モデルを読み込む - print("load model") + logger.info("load model") if args.sdxl: (_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) else: @@ -152,7 +153,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: if args.skip_existing: if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug): - print(f"Skipping {image_info.latents_npz} because it already exists.") + logger.warning(f"Skipping {image_info.latents_npz} because it already exists.") continue image_infos.append(image_info) diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py index 7d9b13d68..0064b8c96 100644 --- a/tools/cache_text_encoder_outputs.py +++ b/tools/cache_text_encoder_outputs.py @@ -16,7 +16,8 @@ ConfigSanitizer, BlueprintGenerator, ) - +from library.utils import get_my_logger +logger = get_my_logger(__name__) def cache_to_disk(args: argparse.Namespace) -> None: train_util.prepare_dataset_args(args, True) @@ -48,18 +49,18 @@ def cache_to_disk(args: argparse.Namespace) -> None: if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) else: if use_dreambooth_method: - print("Using DreamBooth method.") + logger.info("Using DreamBooth method.") user_config = { "datasets": [ { @@ -70,7 +71,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: ] } else: - print("Training with captions.") + logger.info("Training with captions.") user_config = { "datasets": [ { @@ -95,14 +96,14 @@ def cache_to_disk(args: argparse.Namespace) -> None: collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, _ = train_util.prepare_dtype(args) # モデルを読み込む - print("load model") + logger.info("load model") if args.sdxl: (_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) text_encoders = [text_encoder1, text_encoder2] @@ -147,7 +148,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: if args.skip_existing: if os.path.exists(image_info.text_encoder_outputs_npz): - print(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.") + logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already 
exists.") continue image_info.input_ids1 = input_ids1 diff --git a/tools/convert_diffusers20_original_sd.py b/tools/convert_diffusers20_original_sd.py index b9365b519..d545babe1 100644 --- a/tools/convert_diffusers20_original_sd.py +++ b/tools/convert_diffusers20_original_sd.py @@ -6,7 +6,8 @@ from diffusers import StableDiffusionPipeline import library.model_util as model_util - +from library.utils import get_my_logger +logger = get_my_logger(__name__) def convert(args): # 引数を確認する @@ -30,7 +31,7 @@ def convert(args): # モデルを読み込む msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else "")) - print(f"loading {msg}: {args.model_to_load}") + logger.info(f"loading {msg}: {args.model_to_load}") if is_load_ckpt: v2_model = args.v2 @@ -46,26 +47,26 @@ def convert(args): if args.v1 == args.v2: # 自動判定する v2_model = unet.config.cross_attention_dim == 1024 - print("checking model version: model is " + ("v2" if v2_model else "v1")) + logger.info("checking model version: model is " + ("v2" if v2_model else "v1")) else: v2_model = not args.v1 # 変換して保存する msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers" - print(f"converting and saving as {msg}: {args.model_to_save}") + logger.info(f"converting and saving as {msg}: {args.model_to_save}") if is_save_ckpt: original_model = args.model_to_load if is_load_ckpt else None key_count = model_util.save_stable_diffusion_checkpoint( v2_model, args.model_to_save, text_encoder, unet, original_model, args.epoch, args.global_step, save_dtype, vae ) - print(f"model saved. total converted state_dict keys: {key_count}") + logger.info(f"model saved. total converted state_dict keys: {key_count}") else: - print(f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}") + logger.info(f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}") model_util.save_diffusers_checkpoint( v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors ) - print(f"model saved.") + logger.info(f"model saved.") def setup_parser() -> argparse.ArgumentParser: diff --git a/tools/detect_face_rotate.py b/tools/detect_face_rotate.py index 68dec6cae..ad1f786a4 100644 --- a/tools/detect_face_rotate.py +++ b/tools/detect_face_rotate.py @@ -15,6 +15,8 @@ from anime_face_detector import create_detector from tqdm import tqdm import numpy as np +from library.utils import get_my_logger +logger = get_my_logger(__name__) KP_REYE = 11 KP_LEYE = 19 @@ -24,7 +26,7 @@ def detect_faces(detector, image, min_size): preds = detector(image) # bgr - # print(len(preds)) + # logger.info(len(preds)) faces = [] for pred in preds: @@ -78,7 +80,7 @@ def process(args): assert args.crop_ratio is None or args.resize_face_size is None, f"crop_ratio指定時はresize_face_sizeは指定できません" # アニメ顔検出モデルを読み込む - print("loading face detector.") + logger.info("loading face detector.") detector = create_detector('yolov3') # cropの引数を解析する @@ -97,7 +99,7 @@ def process(args): crop_h_ratio, crop_v_ratio = [float(t) for t in tokens] # 画像を処理する - print("processing.") + logger.info("processing.") output_extension = ".png" os.makedirs(args.dst_dir, exist_ok=True) @@ -111,7 +113,7 @@ def process(args): if len(image.shape) == 2: image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) if image.shape[2] == 4: - print(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}") + logger.warning(f"image has alpha. 
ignore / 画像の透明度が設定されているため無視します: {path}") image = image[:, :, :3].copy() # copyをしないと内部的に透明度情報が付いたままになるらしい h, w = image.shape[:2] @@ -144,11 +146,11 @@ def process(args): # 顔サイズを基準にリサイズする scale = args.resize_face_size / face_size if scale < cur_crop_width / w: - print( + logger.warning( f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}") scale = cur_crop_width / w if scale < cur_crop_height / h: - print( + logger.warning( f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}") scale = cur_crop_height / h elif crop_h_ratio is not None: @@ -157,10 +159,10 @@ def process(args): else: # 切り出しサイズ指定あり if w < cur_crop_width: - print(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}") + logger.warning(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}") scale = cur_crop_width / w if h < cur_crop_height: - print(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}") + logger.warning(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}") scale = cur_crop_height / h if args.resize_fit: scale = max(cur_crop_width / w, cur_crop_height / h) @@ -198,7 +200,7 @@ def process(args): face_img = face_img[y:y + cur_crop_height] # # debug - # print(path, cx, cy, angle) + # logger.info(path, cx, cy, angle) # crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8)) # cv2.imshow("image", crp) # if cv2.waitKey() == 27: diff --git a/tools/latent_upscaler.py b/tools/latent_upscaler.py index ab1fa3390..8b09f02e1 100644 --- a/tools/latent_upscaler.py +++ b/tools/latent_upscaler.py @@ -14,7 +14,8 @@ from torch import nn from tqdm import tqdm from PIL import Image - +from library.utils import get_my_logger +logger = get_my_logger(__name__) class ResidualBlock(nn.Module): def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1): @@ -216,7 +217,7 @@ def upscale( upsampled_images = upsampled_images / 127.5 - 1.0 # convert upsample images to latents with batch size - # print("Encoding upsampled (LANCZOS4) images...") + # logger.info("Encoding upsampled (LANCZOS4) images...") upsampled_latents = [] for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)): batch = upsampled_images[i : i + vae_batch_size].to(vae.device) @@ -227,7 +228,7 @@ def upscale( upsampled_latents = torch.cat(upsampled_latents, dim=0) # upscale (refine) latents with this model with batch size - print("Upscaling latents...") + logger.info("Upscaling latents...") upscaled_latents = [] for i in range(0, upsampled_latents.shape[0], batch_size): with torch.no_grad(): @@ -242,7 +243,7 @@ def create_upscaler(**kwargs): weights = kwargs["weights"] model = Upscaler() - print(f"Loading weights from {weights}...") + logger.info(f"Loading weights from {weights}...") if os.path.splitext(weights)[1] == ".safetensors": from safetensors.torch import load_file @@ -261,14 +262,14 @@ def upscale_images(args: argparse.Namespace): # load VAE with Diffusers assert args.vae_path is not None, "VAE path is required" - print(f"Loading VAE from {args.vae_path}...") + logger.info(f"Loading VAE from {args.vae_path}...") vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae") vae.to(DEVICE, dtype=us_dtype) # prepare model - print("Preparing model...") + logger.info("Preparing model...") upscaler: Upscaler = create_upscaler(weights=args.weights) - # print("Loading weights from", args.weights) + # logger.info("Loading weights from", args.weights) # 
upscaler.load_state_dict(torch.load(args.weights)) upscaler.eval() upscaler.to(DEVICE, dtype=us_dtype) @@ -303,14 +304,14 @@ def upscale_images(args: argparse.Namespace): image_debug.save(dest_file_name) # upscale - print("Upscaling...") + logger.info("Upscaling...") upscaled_latents = upscaler.upscale( vae, images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size ) upscaled_latents /= 0.18215 # decode with batch - print("Decoding...") + logger.info("Decoding...") upscaled_images = [] for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)): with torch.no_grad(): diff --git a/tools/merge_models.py b/tools/merge_models.py index 391bfe677..975f0413d 100644 --- a/tools/merge_models.py +++ b/tools/merge_models.py @@ -5,7 +5,8 @@ from safetensors import safe_open from safetensors.torch import load_file, save_file from tqdm import tqdm - +from library.utils import get_my_logger +logger = get_my_logger(__name__) def is_unet_key(key): # VAE or TextEncoder, the last one is for SDXL @@ -45,10 +46,10 @@ def merge(args): # check if all models are safetensors for model in args.models: if not model.endswith("safetensors"): - print(f"Model {model} is not a safetensors model") + logger.error(f"Model {model} is not a safetensors model") exit() if not os.path.isfile(model): - print(f"Model {model} does not exist") + logger.error(f"Model {model} does not exist") exit() assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models" @@ -65,7 +66,7 @@ def merge(args): if merged_sd is None: # load first model - print(f"Loading model {model}, ratio = {ratio}...") + logger.info(f"Loading model {model}, ratio = {ratio}...") merged_sd = {} with safe_open(model, framework="pt", device=args.device) as f: for key in tqdm(f.keys()): @@ -81,11 +82,11 @@ def merge(args): value = ratio * value.to(dtype) # first model's value * ratio merged_sd[key] = value - print(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else "")) + logger.info(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else "")) continue # load other models - print(f"Loading model {model}, ratio = {ratio}...") + logger.info(f"Loading model {model}, ratio = {ratio}...") with safe_open(model, framework="pt", device=args.device) as f: model_keys = f.keys() @@ -93,7 +94,7 @@ def merge(args): _, new_key = replace_text_encoder_key(key) if new_key not in merged_sd: if args.show_skipped and new_key not in first_model_keys: - print(f"Skip: {new_key}") + logger.info(f"Skip: {new_key}") continue value = f.get_tensor(key) @@ -104,7 +105,7 @@ def merge(args): for key in merged_sd.keys(): if key in model_keys: continue - print(f"Key {key} not in model {model}, use first model's value") + logger.warning(f"Key {key} not in model {model}, use first model's value") if key in supplementary_key_ratios: supplementary_key_ratios[key] += ratio else: @@ -112,7 +113,7 @@ def merge(args): # add supplementary keys' value (including VAE and TextEncoder) if len(supplementary_key_ratios) > 0: - print("add first model's value") + logger.info("add first model's value") with safe_open(args.models[0], framework="pt", device=args.device) as f: for key in tqdm(f.keys()): _, new_key = replace_text_encoder_key(key) @@ -120,7 +121,7 @@ def merge(args): continue if is_unet_key(new_key): # not VAE or TextEncoder - print(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}") + logger.warning(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}") value = f.get_tensor(key) # original key @@ -134,7 +135,7 @@ def merge(args): if not output_file.endswith(".safetensors"): output_file = output_file + ".safetensors" - print(f"Saving to {output_file}...") + logger.info(f"Saving to {output_file}...") # convert to save_dtype for k in merged_sd.keys(): @@ -142,7 +143,7 @@ def merge(args): save_file(merged_sd, output_file) - print("Done!") + logger.info("Done!") if __name__ == "__main__": diff --git a/tools/original_control_net.py b/tools/original_control_net.py index cd47bd76a..2aa0a2d90 100644 --- a/tools/original_control_net.py +++ b/tools/original_control_net.py @@ -7,7 +7,8 @@ from library.original_unet import UNet2DConditionModel, SampleOutput import library.model_util as model_util - +from library.utils import get_my_logger +logger = get_my_logger(__name__) class ControlNetInfo(NamedTuple): unet: Any @@ -51,7 +52,7 @@ def load_control_net(v2, unet, model): # control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む # state dictを読み込む - print(f"ControlNet: loading control SD model : {model}") + logger.info(f"ControlNet: loading control SD model : {model}") if model_util.is_safetensors(model): ctrl_sd_sd = load_file(model) @@ -61,7 +62,7 @@ def load_control_net(v2, unet, model): # 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む is_difference = "difference" in ctrl_sd_sd - print("ControlNet: loading difference:", is_difference) + logger.info(f"ControlNet: loading difference: {is_difference}") # ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく # またTransfer Controlの元weightとなる @@ -89,13 +90,13 @@ def load_control_net(v2, unet, model): # ControlNetのU-Netを作成する ctrl_unet = UNet2DConditionModel(**unet_config) info = ctrl_unet.load_state_dict(ctrl_unet_du_sd) - print("ControlNet: loading Control U-Net:", info) + logger.info(f"ControlNet: loading Control U-Net: {info}") # U-Net以外のControlNetを作成する # TODO support middle only ctrl_net = ControlNet() info = ctrl_net.load_state_dict(zero_conv_sd) - print("ControlNet: loading ControlNet:", info) + logger.info(f"ControlNet: loading ControlNet: {info}") ctrl_unet.to(unet.device, dtype=unet.dtype) ctrl_net.to(unet.device, dtype=unet.dtype) @@ -117,7 +118,7 @@ def canny(img): return canny - print("Unsupported prep type:", prep_type) + logger.info(f"Unsupported prep type: {prep_type}") return None @@ -174,7 +175,7 @@ def call_unet_and_control_net( cnet_idx = step % cnet_cnt cnet_info = control_nets[cnet_idx] - # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio) + # logger.info(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio) if cnet_info.ratio < current_ratio: return original_unet(sample, timestep, encoder_hidden_states) @@ -192,7 +193,7 @@ def call_unet_and_control_net( # ControlNet cnet_outs_list = [] for i, cnet_info in enumerate(control_nets): - # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio) + # logger.info(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio) if cnet_info.ratio < current_ratio: continue guided_hint = guided_hints[i] @@ -232,7 +233,7 @@ def unet_forward( upsample_size = None if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): - print("Forward upsample size to force interpolation output size.") + logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True # 1.
time diff --git a/tools/resize_images_to_resolution.py b/tools/resize_images_to_resolution.py index 2d3224c4e..979f7f141 100644 --- a/tools/resize_images_to_resolution.py +++ b/tools/resize_images_to_resolution.py @@ -6,7 +6,8 @@ import math from PIL import Image import numpy as np - +from library.utils import get_my_logger +logger = get_my_logger(__name__) def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False): # Split the max_resolution string by "," and strip any whitespaces @@ -83,7 +84,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi image.save(os.path.join(dst_img_folder, new_filename), quality=100) proc = "Resized" if current_pixels > max_pixels else "Saved" - print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}") + logger.info(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}") # If other files with same basename, copy them with resolution suffix if copy_associated_files: @@ -94,7 +95,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi continue for max_resolution in max_resolutions: new_asoc_file = base + '+' + max_resolution + ext - print(f"Copy {asoc_file} as {new_asoc_file}") + logger.info(f"Copy {asoc_file} as {new_asoc_file}") shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file)) diff --git a/tools/show_metadata.py b/tools/show_metadata.py index 92ca7b1c8..784cdf3ad 100644 --- a/tools/show_metadata.py +++ b/tools/show_metadata.py @@ -1,6 +1,8 @@ import json import argparse from safetensors import safe_open +from library.utils import get_my_logger +logger = get_my_logger(__name__) parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, required=True) @@ -10,10 +12,10 @@ metadata = f.metadata() if metadata is None: - print("No metadata found") + logger.error("No metadata found") else: # metadata is json dict, but not pretty printed # sort by key and pretty print print(json.dumps(metadata, indent=4, sort_keys=True)) - \ No newline at end of file + diff --git a/train_controlnet.py b/train_controlnet.py index bbd915cb3..f0195c1f8 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -37,7 +37,8 @@ pyramid_noise_like, apply_noise_offset, ) - +from library.utils import get_my_logger +logger = get_my_logger(__name__) # TODO 他のスクリプトと共通化する def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): @@ -71,11 +72,11 @@ def train(args): # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) if use_user_config: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "conditioning_data_dir"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) @@ -105,7 +106,7 @@ def train(args): train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: - print( + logger.error( "No data found. 
Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" ) return @@ -116,7 +117,7 @@ def train(args): ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) is_main_process = accelerator.is_main_process @@ -312,7 +313,7 @@ def train(args): accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") - # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") @@ -560,7 +561,7 @@ def remove_model(old_ckpt_name): ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) save_model(ckpt_name, controlnet, force_sync_upload=True) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: diff --git a/train_db.py b/train_db.py index 59a124a26..cca5e4176 100644 --- a/train_db.py +++ b/train_db.py @@ -37,6 +37,8 @@ scale_v_prediction_loss_like_noise_prediction, apply_debiased_estimation, ) +from library.utils import get_my_logger +logger = get_my_logger(__name__) # perlin_noise, @@ -56,11 +58,11 @@ def train(args): if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, False, True)) if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "reg_data_dir"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) @@ -95,13 +97,13 @@ def train(args): ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") if args.gradient_accumulation_steps > 1: - print( + logger.warning( f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. 
accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong" ) - print( + logger.warning( f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です" ) @@ -447,7 +449,7 @@ def train(args): train_util.save_sd_model_on_train_end( args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae ) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: diff --git a/train_network.py b/train_network.py index d50916b74..5585bfe77 100644 --- a/train_network.py +++ b/train_network.py @@ -45,7 +45,8 @@ add_v_prediction_like_loss, apply_debiased_estimation, ) - +from library.utils import get_my_logger +logger = get_my_logger(__name__) class NetworkTrainer: def __init__(self): @@ -152,18 +153,18 @@ def train(self, args): if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) if use_user_config: - print(f"Loading dataset config from {args.dataset_config}") + logger.info(f"Loading dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "reg_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) else: if use_dreambooth_method: - print("Using DreamBooth method.") + logger.info("Using DreamBooth method.") user_config = { "datasets": [ { @@ -174,7 +175,7 @@ def train(self, args): ] } else: - print("Training with captions.") + logger.info("Training with captions.") user_config = { "datasets": [ { @@ -203,7 +204,7 @@ def train(self, args): train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: - print( + logger.error( "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" ) return @@ -216,7 +217,7 @@ def train(self, args): self.assert_extra_args(args, train_dataset_group) # acceleratorを準備する - print("preparing accelerator") + logger.info("preparing accelerator") accelerator = train_util.prepare_accelerator(args) is_main_process = accelerator.is_main_process @@ -307,7 +308,7 @@ def train(self, args): if hasattr(network, "prepare_network"): network.prepare_network(args) if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"): - print( + logger.warning( "warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません" ) args.scale_weight_norms = False @@ -917,7 +918,7 @@ def remove_model(old_ckpt_name): ckpt_name = train_util.get_last_ckpt_name(args, "." 
+ args.save_model_as) save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 6b6e7f5a0..36b6121de 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -34,6 +34,8 @@ add_v_prediction_like_loss, apply_debiased_estimation, ) +from library.utils import get_my_logger +logger = get_my_logger(__name__) imagenet_templates_small = [ "a photo of a {}", @@ -180,7 +182,7 @@ def train(self, args): tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list] # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする @@ -290,7 +292,7 @@ def train(self, args): ] } else: - print("Train with captions.") + logger.info("Train with captions.") user_config = { "datasets": [ { @@ -724,7 +726,7 @@ def remove_model(old_ckpt_name): ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) save_model(ckpt_name, updated_embs_list, global_step, num_train_epochs, force_sync_upload=True) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 8dd5c672f..d9faddfdb 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -38,6 +38,8 @@ ) import library.original_unet as original_unet from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI +from library.utils import get_my_logger +logger = get_my_logger(__name__) imagenet_templates_small = [ "a photo of a {}", @@ -101,7 +103,7 @@ def train(args): train_util.prepare_dataset_args(args, True) if args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None: - print( + logger.warning( "sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません" ) assert ( @@ -116,7 +118,7 @@ def train(args): tokenizer = train_util.load_tokenizer(args) # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする @@ -129,7 +131,7 @@ def train(args): if args.init_word is not None: init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False) if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token: - print( + logger.warning( f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}" ) else: @@ -143,7 +145,7 @@ def train(args): ), f"tokenizer has same word to token string. 
please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}" token_ids = tokenizer.convert_tokens_to_ids(token_strings) - print(f"tokens are added: {token_ids}") + logger.info(f"tokens are added: {token_ids}") assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered" assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}" @@ -171,7 +173,7 @@ def train(args): tokenizer.add_tokens(token_strings_XTI) token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI) - print(f"tokens are added (XTI): {token_ids_XTI}") + logger.info(f"tokens are added (XTI): {token_ids_XTI}") # Resize the token embeddings as we are adding new special tokens to the tokenizer text_encoder.resize_token_embeddings(len(tokenizer)) @@ -180,7 +182,7 @@ def train(args): if init_token_ids is not None: for i, token_id in enumerate(token_ids_XTI): token_embeds[token_id] = token_embeds[init_token_ids[(i // 16) % len(init_token_ids)]] - # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) + # logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) # load weights if args.weights is not None: @@ -188,22 +190,22 @@ def train(args): assert len(token_ids) == len( embeddings ), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}" - # print(token_ids, embeddings.size()) + # logger.info(token_ids, embeddings.size()) for token_id, embedding in zip(token_ids_XTI, embeddings): token_embeds[token_id] = embedding - # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) + # logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) - print(f"weighs loaded") + logger.info(f"weights loaded") - print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") + logger.info(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False)) if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "reg_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) else: use_dreambooth_method = args.in_json is None if use_dreambooth_method: - print("Use DreamBooth method.") + logger.info("Use DreamBooth method.") user_config = { "datasets": [ {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} ] } else: - print("Train with captions.") + logger.info("Train with captions.") user_config = { "datasets": [ { @@ -242,7 +244,7 @@ def train(args): # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 if use_template: - print(f"use template for training captions. is object: {args.use_object_template}") + logger.info(f"use template for training captions. is object: {args.use_object_template}") templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small replace_to = " ".join(token_strings) captions = [] @@ -266,7 +268,7 @@ def train(args): train_util.debug_dataset(train_dataset_group, show_input_ids=True) return if len(train_dataset_group) == 0: - print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") + logger.error("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") return if cache_latents: @@ -299,7 +301,7 @@ def train(args): text_encoder.gradient_checkpointing_enable() # 学習に必要なクラスを準備する - print("prepare optimizer, data loader etc.") + logger.info("prepare optimizer, data loader etc.") trainable_params = text_encoder.get_input_embeddings().parameters() _, _, optimizer = train_util.get_optimizer(args, trainable_params) @@ -320,7 +322,7 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + logger.info(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -337,7 +339,7 @@ def train(args): text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet) index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0] - # print(len(index_no_updates), torch.sum(index_no_updates)) + # logger.info(len(index_no_updates), torch.sum(index_no_updates)) orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() # Freeze all parameters except for the token embeddings in text encoder @@ -375,15 +377,15 @@ def train(args): # 学習する total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - print("running training / 学習開始") - print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - print(f" num epochs / epoch数: {num_train_epochs}") - print(f" batch size per device / バッチサイズ: {args.train_batch_size}") - print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + logger.info("running training / 学習開始") + logger.info(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + logger.info(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + logger.info(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + logger.info(f" num epochs / epoch数: {num_train_epochs}") + logger.info(f" batch size per device / バッチサイズ: {args.train_batch_size}") + logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + logger.info(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + logger.info(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
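# Note on multi-process runs: the ControlNet-LLLite scripts earlier in this series
# route the same statistics through accelerator.print(), which emits only on the
# main process, while this script calls logger.info() on every process. A minimal,
# hedged sketch of one way to avoid duplicated lines under multi-GPU training,
# reusing the accelerator object that train_util.prepare_accelerator() already
# returns in the surrounding code:
#
#   if accelerator.is_main_process:
#       logger.info(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")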
global_step = 0 @@ -406,7 +408,7 @@ def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False): os.makedirs(args.output_dir, exist_ok=True) ckpt_file = os.path.join(args.output_dir, ckpt_name) - print(f"\nsaving checkpoint: {ckpt_file}") + logger.info(f"\nsaving checkpoint: {ckpt_file}") save_weights(ckpt_file, embs, save_dtype) if args.huggingface_repo_id is not None: huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) @@ -414,12 +416,12 @@ def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False): def remove_model(old_ckpt_name): old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) if os.path.exists(old_ckpt_file): - print(f"removing old checkpoint: {old_ckpt_file}") + logger.info(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) # training loop for epoch in range(num_train_epochs): - print(f"\nepoch {epoch+1}/{num_train_epochs}") + logger.info(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 text_encoder.train() @@ -589,7 +591,7 @@ def remove_model(old_ckpt_name): ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True) - print("model saved.") + logger.info("model saved.") def save_weights(file, updated_embs, save_dtype): From e50f67f1bee633271c2a56a01bf335cd0d314355 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 29 Oct 2023 17:26:29 +0900 Subject: [PATCH 3/8] Fix log level --- finetune/merge_captions_to_metadata.py | 2 +- finetune/tag_images_by_wd14_tagger.py | 2 +- library/train_util.py | 46 +++++++++++++------------- networks/lora.py | 2 +- networks/lora_fa.py | 2 +- networks/lora_interrogator.py | 2 +- networks/merge_lora.py | 4 +-- networks/svd_merge_lora.py | 2 +- 8 files changed, 31 insertions(+), 31 deletions(-) diff --git a/finetune/merge_captions_to_metadata.py b/finetune/merge_captions_to_metadata.py index 07c929a9b..42e915309 100644 --- a/finetune/merge_captions_to_metadata.py +++ b/finetune/merge_captions_to_metadata.py @@ -21,7 +21,7 @@ def main(args): if args.in_json is not None: logger.info(f"loading existing metadata: {args.in_json}") metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8')) - logger.info("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます") + logger.warning("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます") else: logger.info("new metadata will be created / 新しいメタデータファイルが作成されます") metadata = {} diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index fd98c296c..9c71fa16e 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -123,7 +123,7 @@ def main(args): if args.batch_size != batch_size and type(batch_size) != str: # some rebatch model may use 'N' as dynamic axes - logger.info( + logger.warning( f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}" ) args.batch_size = batch_size diff --git a/library/train_util.py b/library/train_util.py index 37eba9d23..6a67ce73f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1379,7 +1379,7 @@ def read_caption(img_path, caption_extension): def load_dreambooth_dir(subset: DreamBoothSubset): if not os.path.isdir(subset.image_dir): - logger.error(f"not directory: {subset.image_dir}") + logger.warning(f"not directory: {subset.image_dir}") return [], [] img_paths = 
glob_images(subset.image_dir, "*") @@ -1462,10 +1462,10 @@ def load_dreambooth_dir(subset: DreamBoothSubset): logger.info(f"{num_reg_images} reg images.") if num_train_images < num_reg_images: - logger.info("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") + logger.warning("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") if num_reg_images == 0: - logger.info("no regularization images / 正則化画像が見つかりませんでした") + logger.warning("no regularization images / 正則化画像が見つかりませんでした") else: # num_repeatsを計算する:どうせ大した数ではないのでループで処理する n = 0 @@ -1509,13 +1509,13 @@ def __init__( for subset in subsets: if subset.num_repeats < 1: - logger.info( + logger.warning( f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}" ) continue if subset in self.subsets: - logger.info( + logger.warning( f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します" ) continue @@ -1529,7 +1529,7 @@ def __init__( raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}") if len(metadata) < 1: - logger.info(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します") + logger.warning(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します") continue tags_list = [] @@ -1607,12 +1607,12 @@ def __init__( if not npz_any: use_npz_latents = False - logger.info(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します") + logger.warning(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します") elif not npz_all: use_npz_latents = False - logger.info(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します") + logger.warning(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します") if flip_aug_in_subset: - logger.info("maybe no flipped files / 反転されたnpzファイルがないのかもしれません") + logger.warning("maybe no flipped files / 反転されたnpzファイルがないのかもしれません") # else: # logger.info("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません") @@ -1630,7 +1630,7 @@ def __init__( if sizes is None: if use_npz_latents: use_npz_latents = False - logger.info(f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します") + logger.warning(f"npz files exist, but no bucket info in metadata. 
ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します") assert ( resolution is not None @@ -1764,7 +1764,7 @@ def __init__( assert subset is not None, "internal error: subset not found" if not os.path.isdir(subset.conditioning_data_dir): - logger.info(f"not directory: {subset.conditioning_data_dir}") + logger.warning(f"not directory: {subset.conditioning_data_dir}") continue img_basename = os.path.basename(info.absolute_path) @@ -3054,13 +3054,13 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: def verify_training_args(args: argparse.Namespace): if args.v_parameterization and not args.v2: - logger.info("v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません") + logger.warning("v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません") if args.v2 and args.clip_skip is not None: - logger.info("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") if args.cache_latents_to_disk and not args.cache_latents: args.cache_latents = True - logger.info( + logger.warning( "cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします" ) @@ -3091,7 +3091,7 @@ def verify_training_args(args: argparse.Namespace): ) if args.zero_terminal_snr and not args.v_parameterization: - logger.info( + logger.warning( f"zero_terminal_snr is enabled, but v_parameterization is not enabled. training will be unexpected" + " / zero_terminal_snrが有効ですが、v_parameterizationが有効ではありません。学習結果は想定外になる可能性があります" ) @@ -3255,7 +3255,7 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar if args.output_config: # check if config file exists if os.path.exists(config_path): - logger.info(f"Config file already exists. Aborting... / 出力先の設定ファイルが既に存在します: {config_path}") + logger.error(f"Config file already exists. Aborting... / 出力先の設定ファイルが既に存在します: {config_path}") exit(1) # convert args to dictionary @@ -3440,7 +3440,7 @@ def get_optimizer(args, trainable_params): elif optimizer_type == "SGDNesterov8bit".lower(): logger.info(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}") if "momentum" not in optimizer_kwargs: - logger.info( + logger.warning( f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します" ) optimizer_kwargs["momentum"] = 0.9 @@ -3510,12 +3510,12 @@ def get_optimizer(args, trainable_params): lr_count = len(lrs) if actual_lr <= 0.1: - logger.info( + logger.warning( f"learning rate is too low. If using D-Adaptation or Prodigy, set learning rate around 1.0 / 学習率が低すぎるようです。D-AdaptationまたはProdigyの使用時は1.0前後の値を指定してください: lr={actual_lr}" ) - logger.info("recommend option: lr=1.0 / 推奨は1.0です") + logger.warning("recommend option: lr=1.0 / 推奨は1.0です") if lr_count > 1: - logger.info( + logger.warning( f"when multiple learning rates are specified with dadaptation (e.g. 
for Text Encoder and U-Net), only the first one will take effect / D-AdaptationまたはProdigyで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}" ) @@ -3578,7 +3578,7 @@ def get_optimizer(args, trainable_params): if optimizer_kwargs["relative_step"]: logger.info(f"relative_step is true / relative_stepがtrueです") if lr != 0.0: - logger.info(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます") + logger.warning(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます") args.learning_rate = None # trainable_paramsがgroupだった時の処理:lrを削除する @@ -3747,7 +3747,7 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): if support_metadata: if args.in_json is not None and (args.color_aug or args.random_crop): - logger.info( + logger.warning( f"latents in npz is ignored when color_aug or random_crop is True / color_augまたはrandom_cropを有効にした場合、npzファイルのlatentsは無視されます" ) @@ -4596,7 +4596,7 @@ def sample_images_common( except ValueError as ex: logger.error(f"Exception in parsing / 解析エラー: {parg}") - logger.error(ex) + logger.error(f"{ex}") if seed is not None: torch.manual_seed(seed) diff --git a/networks/lora.py b/networks/lora.py index a37475d10..64d98425d 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -905,7 +905,7 @@ def create_modules( skipped = skipped_te + skipped_un if varbose and len(skipped) > 0: - logger.info( + logger.warning( f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" ) for name in skipped: diff --git a/networks/lora_fa.py b/networks/lora_fa.py index 141f8780a..fdf68648f 100644 --- a/networks/lora_fa.py +++ b/networks/lora_fa.py @@ -920,7 +920,7 @@ def create_modules( skipped = skipped_te + skipped_un if varbose and len(skipped) > 0: - logger.info( + logger.warning( f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" ) for name in skipped: diff --git a/networks/lora_interrogator.py b/networks/lora_interrogator.py index ccdcaddb9..d457accec 100644 --- a/networks/lora_interrogator.py +++ b/networks/lora_interrogator.py @@ -37,7 +37,7 @@ def interrogate(args): has_te_weight = True break if not has_te_weight: - logger.info("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません") + logger.error("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません") return del vae diff --git a/networks/merge_lora.py b/networks/merge_lora.py index a6ae053bd..fba504ee5 100644 --- a/networks/merge_lora.py +++ b/networks/merge_lora.py @@ -265,7 +265,7 @@ def str_to_dtype(p): ) if args.v2: # TODO read sai modelspec - logger.info( + logger.warning( "Cannot determine if model is for v-prediction, so save metadata as v-prediction / modelがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" ) @@ -290,7 +290,7 @@ def str_to_dtype(p): ) if v2: # TODO read sai modelspec - logger.info( + logger.warning( "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" ) metadata.update(sai_metadata) diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index 5d32bf0a2..60c195087 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -204,7 +204,7 @@ def str_to_dtype(p): ) if v2: # TODO 
read sai modelspec - logger.info( + logger.warning( "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" ) metadata.update(sai_metadata)
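This patch pins each message to the level its condition warrants: combinations that keep running but are probably unintended (v_parameterization without v2, clip_skip with v2, a missing momentum default) become logger.warning, conditions that abort the script are logged with logger.error right before exit(1), and routine progress stays at logger.info. A minimal sketch of the convention; load_config and its arguments are illustrative, not code from this repository:

    import logging
    import os

    logger = logging.getLogger(__name__)

    def load_config(path: str, v2: bool = False, clip_skip=None):
        logger.info(f"loading config: {path}")  # routine progress

        if v2 and clip_skip is not None:
            # the run continues, but the combination is probably unintended
            logger.warning("v2 with clip_skip will be unexpected")

        if not os.path.exists(path):
            # fatal: report at error level, then abort, as read_config_from_file does
            logger.error(f"config file not found: {path}")
            exit(1)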
From 799e59972833c3121b5347844c6df765636eb478 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 29 Oct 2023 19:25:30 +0900 Subject: [PATCH 4/8] Removed line-breaks for readability --- finetune/tag_images_by_wd14_tagger.py | 5 ++++- gen_img_diffusers.py | 3 ++- library/config_util.py | 9 ++++++--- library/train_util.py | 21 ++++++++++++++------- networks/merge_lora_old.py | 6 ++++-- sdxl_gen_img.py | 3 ++- train_textual_inversion_XTI.py | 6 ++++-- 7 files changed, 36 insertions(+), 17 deletions(-)
diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 9c71fa16e..540a0cc16 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -237,7 +237,10 @@ def run_batch(path_imgs): with open(caption_file, "wt", encoding="utf-8") as f: f.write(tag_text + "\n") if args.debug: - logger.info(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}") + logger.info("") + logger.info(f"{image_path}:") + logger.info(f"\tCharacter tags: {character_tag_text}") + logger.info(f"\tGeneral tags: {general_tag_text}") # 読み込みの高速化のためにDataLoaderを使うオプション if args.max_data_loader_n_workers is not None:
diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index fb4d87ca4..344a5ec83 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -3053,7 +3053,8 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): # interactive valid = False while not valid: - logger.info("\nType prompt:") + logger.info("") + logger.info("Type prompt:") try: raw_prompt = input() except EOFError:
diff --git a/library/config_util.py b/library/config_util.py index 382906666..deecc1e57 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -603,16 +603,19 @@ def load_user_config(file: str) -> dict: user_config = load_user_config(config_args.dataset_config) - logger.info("\n[user_config]") + logger.info("") + logger.info("[user_config]") logger.info(f'{user_config}') sanitizer = ConfigSanitizer(config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout) sanitized_user_config = sanitizer.sanitize_user_config(user_config) - logger.info("\n[sanitized_user_config]") + logger.info("") + logger.info("[sanitized_user_config]") logger.info(f'{sanitized_user_config}') blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) - logger.info("\n[blueprint]") + logger.info("") + logger.info("[blueprint]") logger.info(f'{blueprint}')
diff --git a/library/train_util.py b/library/train_util.py index 6a67ce73f..719d46b68 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1980,7 +1980,8 @@ def debug_dataset(train_dataset, show_input_ids=False): epoch = 1 while True: - logger.info(f"\nepoch: {epoch}") + logger.info(f"") + logger.info(f"epoch: {epoch}") steps = (epoch - 1) * len(train_dataset) + 1 indices = list(range(len(train_dataset))) @@ -4205,7 +4206,8 @@ def save_sd_model_on_epoch_end_or_stepwise_common( ckpt_name = get_step_ckpt_name(args, ext, global_step) ckpt_file = os.path.join(args.output_dir, ckpt_name) - logger.info(f"\nsaving checkpoint: {ckpt_file}") + logger.info("") + logger.info(f"saving checkpoint: {ckpt_file}") sd_saver(ckpt_file, epoch_no, global_step)
if args.huggingface_repo_id is not None: @@ -4229,7 +4231,8 @@ save_sd_model_on_epoch_end_or_stepwise_common( else: out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, global_step)) - logger.info(f"\nsaving model: {out_dir}") + logger.info("") + logger.info(f"saving model: {out_dir}") diffusers_saver(out_dir) if args.huggingface_repo_id is not None: @@ -4256,7 +4259,8 @@ def save_sd_model_on_epoch_end_or_stepwise_common( def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator, epoch_no): model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME) - logger.info(f"\nsaving state at epoch {epoch_no}") + logger.info("") + logger.info(f"saving state at epoch {epoch_no}") os.makedirs(args.output_dir, exist_ok=True) state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)) @@ -4277,7 +4281,8 @@ def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator, ep def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_no): model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME) - logger.info(f"\nsaving state at step {step_no}") + logger.info("") + logger.info(f"saving state at step {step_no}") os.makedirs(args.output_dir, exist_ok=True) state_dir = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, step_no)) @@ -4302,7 +4307,8 @@ def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_n def save_state_on_train_end(args: argparse.Namespace, accelerator): model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME) - logger.info("\nsaving last state.") + logger.info("") + logger.info("saving last state.") os.makedirs(args.output_dir, exist_ok=True) state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)) @@ -4440,7 +4446,8 @@ def sample_images_common( if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch return - logger.info(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}") + logger.info("") + logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") if not os.path.isfile(args.sample_prompts): logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") return
diff --git a/networks/merge_lora_old.py b/networks/merge_lora_old.py index e62531301..1a1ab84ce 100644 --- a/networks/merge_lora_old.py +++ b/networks/merge_lora_old.py @@ -149,13 +149,15 @@ def str_to_dtype(p): merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) - logger.info(f"\nsaving SD model to: {args.save_to}") + logger.info("") + logger.info(f"saving SD model to: {args.save_to}") model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, save_dtype, vae) else: state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype) - logger.info(f"\nsaving model to: {args.save_to}") + logger.info(f"") + logger.info(f"saving model to: {args.save_to}") save_to_file(args.save_to, state_dict, state_dict, save_dtype)
diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 8569d9076..816144985 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -2248,7 +2248,8 @@ def scale_and_round(x): # interactive valid = False while not valid: - logger.info("\nType prompt:") + logger.info("") + logger.info("Type prompt:") try: raw_prompt = input() except EOFError:
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index d9faddfdb..b38afd565 100644 ---
a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -408,7 +408,8 @@ def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False): os.makedirs(args.output_dir, exist_ok=True) ckpt_file = os.path.join(args.output_dir, ckpt_name) - logger.info(f"\nsaving checkpoint: {ckpt_file}") + logger.info("") + logger.info(f"saving checkpoint: {ckpt_file}") save_weights(ckpt_file, embs, save_dtype) if args.huggingface_repo_id is not None: huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) @@ -421,7 +422,8 @@ def remove_model(old_ckpt_name): # training loop for epoch in range(num_train_epochs): - logger.info(f"\nepoch {epoch+1}/{num_train_epochs}") + logger.info("") + logger.info(f"epoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 text_encoder.train()
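The pattern in this patch, replacing one logger.info(f"\n...") with logger.info("") followed by the message, exists because each logging call produces one record and the handler prefixes every record with its own columns; a newline embedded in the message would leave the second line without a prefix. A short illustration, where the format string merely stands in for any prefixing handler configuration:

    import logging

    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
    logger = logging.getLogger(__name__)

    # one record with an embedded newline: the prefix lands before the blank
    # line, and "epoch: 1" continues on a line without its own prefix
    logger.info("\nepoch: 1")

    # two records: the blank separator and the message each get a full prefix
    logger.info("")
    logger.info("epoch: 1")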
From ad942d3d965938abb678d49f96acb88d42dd9d30 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Wed, 17 Jan 2024 21:10:41 +0900 Subject: [PATCH 5/8] Use setup_logging() --- fine_tune.py | 6 ++++-- finetune/blip/blip.py | 6 ++++-- finetune/clean_captions_and_tags.py | 6 ++++-- finetune/make_captions.py | 6 ++++-- finetune/make_captions_by_git.py | 6 ++++-- finetune/merge_captions_to_metadata.py | 6 ++++-- finetune/merge_dd_tags_to_metadata.py | 6 ++++-- finetune/prepare_buckets_latents.py | 6 ++++-- finetune/tag_images_by_wd14_tagger.py | 6 ++++-- gen_img_diffusers.py | 6 ++++-- library/config_util.py | 6 ++++-- library/custom_train_functions.py | 6 ++++-- library/huggingface_util.py | 6 ++++-- library/ipex/hijacks.py | 6 ++++-- library/model_util.py | 6 ++++-- library/original_unet.py | 6 ++++-- library/sai_model_spec.py | 6 ++++-- library/sdxl_model_util.py | 6 ++++-- library/sdxl_original_unet.py | 6 ++++-- library/sdxl_train_util.py | 6 ++++-- library/slicing_vae.py | 6 ++++-- library/train_util.py | 6 ++++-- library/utils.py | 27 +++++++++++++----------- networks/check_lora_weights.py | 6 ++++-- networks/control_net_lllite.py | 6 ++++-- networks/control_net_lllite_for_train.py | 6 ++++-- networks/dylora.py | 6 ++++-- networks/extract_lora_from_dylora.py | 6 ++++-- networks/extract_lora_from_models.py | 6 ++++-- networks/lora.py | 6 ++++-- networks/lora_diffusers.py | 6 ++++-- networks/lora_fa.py | 6 ++++-- networks/lora_interrogator.py | 6 ++++-- networks/merge_lora.py | 6 ++++-- networks/merge_lora_old.py | 6 ++++-- networks/oft.py | 6 ++++-- networks/resize_lora.py | 6 ++++-- networks/sdxl_merge_lora.py | 6 ++++-- networks/svd_merge_lora.py | 6 ++++-- sdxl_gen_img.py | 6 ++++-- sdxl_minimal_inference.py | 6 ++++-- sdxl_train.py | 6 ++++-- sdxl_train_control_net_lllite.py | 6 ++++-- sdxl_train_control_net_lllite_old.py | 6 ++++-- sdxl_train_network.py | 6 ++++-- tools/cache_latents.py | 6 ++++-- tools/cache_text_encoder_outputs.py | 6 ++++-- tools/convert_diffusers20_original_sd.py | 6 ++++-- tools/detect_face_rotate.py | 6 ++++-- tools/latent_upscaler.py | 6 ++++-- tools/merge_models.py | 6 ++++-- tools/original_control_net.py | 6 ++++-- tools/resize_images_to_resolution.py | 6 ++++-- tools/show_metadata.py | 6 ++++-- train_controlnet.py | 6 ++++-- train_db.py | 6 ++++-- train_network.py | 6 ++++-- train_textual_inversion.py | 6 ++++-- train_textual_inversion_XTI.py | 6 ++++-- 59 files changed, 247 insertions(+), 128 deletions(-)
diff --git a/fine_tune.py b/fine_tune.py index d6c2dcf9e..28a272d94 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -23,8 +23,10 @@ from accelerate.utils import set_seed from diffusers import DDPMScheduler -from library.utils
import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) import library.train_util as train_util import library.config_util as config_util from library.config_util import (
diff --git a/finetune/blip/blip.py b/finetune/blip/blip.py index 606b8cc9d..7d192cb26 100644 --- a/finetune/blip/blip.py +++ b/finetune/blip/blip.py @@ -21,8 +21,10 @@ import os from urllib.parse import urlparse from timm.models.hub import download_cached_file -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) class BLIP_Base(nn.Module): def __init__(self,
diff --git a/finetune/clean_captions_and_tags.py b/finetune/clean_captions_and_tags.py index fa441ba55..5aeb17425 100644 --- a/finetune/clean_captions_and_tags.py +++ b/finetune/clean_captions_and_tags.py @@ -8,8 +8,10 @@ import re from tqdm import tqdm -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ') PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ')
diff --git a/finetune/make_captions.py b/finetune/make_captions.py index 3bc916dfe..5d6d48047 100644 --- a/finetune/make_captions.py +++ b/finetune/make_captions.py @@ -15,8 +15,10 @@ sys.path.append(os.path.dirname(__file__)) from blip.blip import blip_decoder, is_url import library.train_util as train_util -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
diff --git a/finetune/make_captions_by_git.py b/finetune/make_captions_by_git.py index 777c08dc8..e50ed8ca8 100644 --- a/finetune/make_captions_by_git.py +++ b/finetune/make_captions_by_git.py @@ -10,8 +10,10 @@ from transformers.generation.utils import GenerationMixin import library.train_util as train_util -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
diff --git a/finetune/merge_captions_to_metadata.py b/finetune/merge_captions_to_metadata.py index 42e915309..60765b863 100644 --- a/finetune/merge_captions_to_metadata.py +++ b/finetune/merge_captions_to_metadata.py @@ -5,8 +5,10 @@ from tqdm import tqdm import library.train_util as train_util import os -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def main(args): assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
diff --git a/finetune/merge_dd_tags_to_metadata.py b/finetune/merge_dd_tags_to_metadata.py index df9bbecea..9ef8f14b0 100644 --- a/finetune/merge_dd_tags_to_metadata.py +++ b/finetune/merge_dd_tags_to_metadata.py @@ -5,8 +5,10 @@ from tqdm import tqdm import library.train_util as train_util import os -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import
logging +logger = logging.getLogger(__name__) def main(args): assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index 8cc1449e7..4292ac133 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -13,8 +13,10 @@ import library.model_util as model_util import library.train_util as train_util -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 77d0d1adb..7a751b177 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -11,8 +11,10 @@ from tqdm import tqdm import library.train_util as train_util -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # from wd14 tagger IMAGE_SIZE = 448 diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 6ff1019c1..54688e3c2 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -109,8 +109,10 @@ from library.original_unet import FlashAttentionFunction from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # scheduler: SCHEDULER_LINEAR_START = 0.00085 diff --git a/library/config_util.py b/library/config_util.py index d7f83adbc..4318539fb 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -39,8 +39,10 @@ ControlNetDataset, DatasetGroup, ) -from .utils import get_my_logger -logger = get_my_logger(__name__) +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def add_config_arguments(parser: argparse.ArgumentParser): parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル") diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 33dd3033d..a56474622 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -3,8 +3,10 @@ import random import re from typing import List, Optional, Union -from .utils import get_my_logger -logger = get_my_logger(__name__) +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def prepare_scheduler_for_custom_training(noise_scheduler, device): if hasattr(noise_scheduler, "all_snr"): diff --git a/library/huggingface_util.py b/library/huggingface_util.py index dd3d03bbb..57b19d982 100644 --- a/library/huggingface_util.py +++ b/library/huggingface_util.py @@ -4,8 +4,10 @@ import argparse import os from library.utils import fire_in_thread -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None): api = HfApi( diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py index 
93db1eefa..afacfbf97 100644 --- a/library/ipex/hijacks.py +++ b/library/ipex/hijacks.py @@ -1,8 +1,10 @@ import contextlib import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return diff --git a/library/model_util.py b/library/model_util.py index 790bc631b..06e1a9f8a 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -19,8 +19,10 @@ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel from safetensors.torch import load_file, save_file from library.original_unet import UNet2DConditionModel -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # DiffUsers版StableDiffusionのモデルパラメータ NUM_TRAIN_TIMESTEPS = 1000 diff --git a/library/original_unet.py b/library/original_unet.py index 9aecc966f..a88efda1e 100644 --- a/library/original_unet.py +++ b/library/original_unet.py @@ -113,8 +113,10 @@ from torch import nn from torch.nn import functional as F from einops import rearrange -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280) TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0] diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index d53a7959c..a63bd82ec 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -5,8 +5,10 @@ import os from typing import List, Optional, Tuple, Union import safetensors -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) r""" # Metadata Example diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index 853a59c58..fc765b625 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -7,8 +7,10 @@ from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel from library import model_util from library import sdxl_original_unet -from .utils import get_my_logger -logger = get_my_logger(__name__) +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) VAE_SCALE_FACTOR = 0.13025 MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0" diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py index c8beefed4..de01edb99 100644 --- a/library/sdxl_original_unet.py +++ b/library/sdxl_original_unet.py @@ -30,8 +30,10 @@ from torch import nn from torch.nn import functional as F from einops import rearrange -from .utils import get_my_logger -logger = get_my_logger(__name__) +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) IN_CHANNELS: int = 4 OUT_CHANNELS: int = 4 diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 3847a89fa..aedab6c76 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -9,8 +9,10 @@ from transformers import CLIPTokenizer from library import model_util, sdxl_model_util, train_util, 
sdxl_original_unet from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline -from .utils import get_my_logger -logger = get_my_logger(__name__) +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) TOKENIZER1_PATH = "openai/clip-vit-large-patch14" TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" diff --git a/library/slicing_vae.py b/library/slicing_vae.py index 0e07d0bed..ea7653429 100644 --- a/library/slicing_vae.py +++ b/library/slicing_vae.py @@ -26,8 +26,10 @@ from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution from diffusers.models.autoencoder_kl import AutoencoderKLOutput -from .utils import get_my_logger -logger = get_my_logger(__name__) +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def slice_h(x, num_slices): # slice with pad 1 both sides: to eliminate side effect of padding of conv2d diff --git a/library/train_util.py b/library/train_util.py index ddc77740a..8aa2e8833 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -65,8 +65,10 @@ import library.model_util as model_util import library.huggingface_util as huggingface_util import library.sai_model_spec as sai_model_spec -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # from library.attention_processors import FlashAttnProcessor # from library.hypernetwork import replace_attentions_for_hypernetwork from library.original_unet import UNet2DConditionModel diff --git a/library/utils.py b/library/utils.py index b491cd761..7e65ae306 100644 --- a/library/utils.py +++ b/library/utils.py @@ -5,17 +5,20 @@ def fire_in_thread(f, *args, **kwargs): threading.Thread(target=f, args=args, kwargs=kwargs).start() -def get_my_logger(name: str): - logger = logging.getLogger(name) - logger.setLevel(logging.INFO) +def setup_logging(log_level=logging.INFO): + if logging.root.handlers: # Already configured + return + try: + from rich.logging import RichHandler - stream_handler = logging.StreamHandler() - stream_handler.setLevel(logging.INFO) + handler = RichHandler() + except ImportError: + handler = logging.StreamHandler() - myformat = '%(asctime)s\t[%(levelname)s]\t%(filename)s:%(lineno)d\t%(message)s' - date_format = '%Y-%m-%d %H:%M:%S' - formatter = logging.Formatter(myformat, date_format) - stream_handler.setFormatter(formatter) - logger.addHandler(stream_handler) - - return logger + formatter = logging.Formatter( + fmt="%(asctime)s %(levelname)s [%(name)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + handler.setFormatter(formatter) + logging.root.setLevel(log_level) + logging.root.addHandler(handler) diff --git a/networks/check_lora_weights.py b/networks/check_lora_weights.py index 10650d152..6ec60a89b 100644 --- a/networks/check_lora_weights.py +++ b/networks/check_lora_weights.py @@ -2,8 +2,10 @@ import os import torch from safetensors.torch import load_file -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def main(file): logger.info(f"loading: {file}") diff --git a/networks/control_net_lllite.py b/networks/control_net_lllite.py index 4f9a82345..c9377bee8 100644 --- 
a/networks/control_net_lllite.py +++ b/networks/control_net_lllite.py @@ -2,8 +2,10 @@ from typing import Optional, List, Type import torch from library import sdxl_original_unet -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # input_blocksに適用するかどうか / if True, input_blocks are not applied SKIP_INPUT_BLOCKS = False diff --git a/networks/control_net_lllite_for_train.py b/networks/control_net_lllite_for_train.py index a202712ea..65b3520cf 100644 --- a/networks/control_net_lllite_for_train.py +++ b/networks/control_net_lllite_for_train.py @@ -6,8 +6,10 @@ from typing import Optional, List, Type import torch from library import sdxl_original_unet -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # input_blocksに適用するかどうか / if True, input_blocks are not applied SKIP_INPUT_BLOCKS = False diff --git a/networks/dylora.py b/networks/dylora.py index 5781cee51..262699014 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -15,8 +15,10 @@ from typing import List, Tuple, Union import torch from torch import nn -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) class DyLoRAModule(torch.nn.Module): """ diff --git a/networks/extract_lora_from_dylora.py b/networks/extract_lora_from_dylora.py index 35747ebfc..1184cd8a5 100644 --- a/networks/extract_lora_from_dylora.py +++ b/networks/extract_lora_from_dylora.py @@ -10,8 +10,10 @@ from tqdm import tqdm from library import train_util, model_util import numpy as np -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def load_state_dict(file_name): if model_util.is_safetensors(file_name): diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index 99c9ca827..41a5db2da 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -11,8 +11,10 @@ from tqdm import tqdm from library import sai_model_spec, model_util, sdxl_model_util import lora -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # CLAMP_QUANTILE = 0.99 # MIN_DIFF = 1e-1 diff --git a/networks/lora.py b/networks/lora.py index 64d98425d..eaf656ac8 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -11,8 +11,10 @@ import numpy as np import torch import re -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") diff --git a/networks/lora_diffusers.py b/networks/lora_diffusers.py index 88f33c0e3..31ca6feb3 100644 --- a/networks/lora_diffusers.py +++ b/networks/lora_diffusers.py @@ -10,8 +10,10 @@ from tqdm import tqdm from transformers import CLIPTextModel import torch -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = 
logging.getLogger(__name__) def make_unet_conversion_map() -> Dict[str, str]: unet_conversion_map_layer = [] diff --git a/networks/lora_fa.py b/networks/lora_fa.py index fdf68648f..919222ce8 100644 --- a/networks/lora_fa.py +++ b/networks/lora_fa.py @@ -14,8 +14,10 @@ import numpy as np import torch import re -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") diff --git a/networks/lora_interrogator.py b/networks/lora_interrogator.py index d457accec..d00aaaf08 100644 --- a/networks/lora_interrogator.py +++ b/networks/lora_interrogator.py @@ -9,8 +9,10 @@ import library.model_util as model_util import lora -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) TOKENIZER_PATH = "openai/clip-vit-large-patch14" V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う diff --git a/networks/merge_lora.py b/networks/merge_lora.py index fba504ee5..fea8a3f32 100644 --- a/networks/merge_lora.py +++ b/networks/merge_lora.py @@ -7,8 +7,10 @@ from library import sai_model_spec, train_util import library.model_util as model_util import lora -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": diff --git a/networks/merge_lora_old.py b/networks/merge_lora_old.py index 1a1ab84ce..334d127b7 100644 --- a/networks/merge_lora_old.py +++ b/networks/merge_lora_old.py @@ -6,8 +6,10 @@ from safetensors.torch import load_file, save_file import library.model_util as model_util import lora -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == '.safetensors': diff --git a/networks/oft.py b/networks/oft.py index 68843bce9..461a98698 100644 --- a/networks/oft.py +++ b/networks/oft.py @@ -8,8 +8,10 @@ import numpy as np import torch import re -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") diff --git a/networks/resize_lora.py b/networks/resize_lora.py index 60cdc008f..c5932a893 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -8,8 +8,10 @@ from tqdm import tqdm from library import train_util, model_util import numpy as np -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) MIN_SV = 1e-6 diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py index f4960ed8b..3383a80de 100644 --- a/networks/sdxl_merge_lora.py +++ b/networks/sdxl_merge_lora.py @@ -8,8 +8,10 @@ from library import sai_model_spec, sdxl_model_util, train_util import library.model_util as model_util import lora -from library.utils import 
get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index 60c195087..4d8f72546 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -8,8 +8,10 @@ from library import sai_model_spec, train_util import library.model_util as model_util import lora -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) CLAMP_QUANTILE = 0.99 diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index edaf804ec..3e6961326 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -60,8 +60,10 @@ from library.sdxl_original_unet import InferSdxlUNet2DConditionModel from library.original_unet import FlashAttentionFunction from networks.control_net_lllite import ControlNetLLLite -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # scheduler: SCHEDULER_LINEAR_START = 0.00085 diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py index 13a5be402..e06d9a3e7 100644 --- a/sdxl_minimal_inference.py +++ b/sdxl_minimal_inference.py @@ -25,8 +25,10 @@ from library import model_util, sdxl_model_util import networks.lora as lora -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # scheduler: このあたりの設定はSD1/2と同じでいいらしい # scheduler: The settings around here seem to be the same as SD1/2 diff --git a/sdxl_train.py b/sdxl_train.py index 1d4c72b3b..af73d6684 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -25,8 +25,10 @@ from library import sdxl_model_util import library.train_util as train_util -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) import library.config_util as config_util import library.sdxl_train_util as sdxl_train_util from library.config_util import ( diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 555aa7d4c..cf8d5ff28 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -47,8 +47,10 @@ apply_debiased_estimation, ) import networks.control_net_lllite_for_train as control_net_lllite_for_train -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # TODO 他のスクリプトと共通化する def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index d6e7e4c01..b30ccbc58 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -43,8 +43,10 @@ apply_debiased_estimation, ) import networks.control_net_lllite as control_net_lllite -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # TODO 他のスクリプトと共通化する def generate_step_logs(args: 
argparse.Namespace, current_loss, avr_loss, lr_scheduler): diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 22b2bb7d1..34e0302b8 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -12,8 +12,10 @@ pass from library import sdxl_model_util, sdxl_train_util, train_util import train_network -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) class SdxlNetworkTrainer(train_network.NetworkTrainer): def __init__(self): diff --git a/tools/cache_latents.py b/tools/cache_latents.py index b2faf5b5b..e25506e41 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -16,8 +16,10 @@ ConfigSanitizer, BlueprintGenerator, ) -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def cache_to_disk(args: argparse.Namespace) -> None: train_util.prepare_dataset_args(args, True) diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py index 0064b8c96..46bffc4e0 100644 --- a/tools/cache_text_encoder_outputs.py +++ b/tools/cache_text_encoder_outputs.py @@ -16,8 +16,10 @@ ConfigSanitizer, BlueprintGenerator, ) -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def cache_to_disk(args: argparse.Namespace) -> None: train_util.prepare_dataset_args(args, True) diff --git a/tools/convert_diffusers20_original_sd.py b/tools/convert_diffusers20_original_sd.py index 4b07cfdbb..572ee2f0c 100644 --- a/tools/convert_diffusers20_original_sd.py +++ b/tools/convert_diffusers20_original_sd.py @@ -6,8 +6,10 @@ from diffusers import StableDiffusionPipeline import library.model_util as model_util -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def convert(args): # 引数を確認する diff --git a/tools/detect_face_rotate.py b/tools/detect_face_rotate.py index ad1f786a4..bbc643edc 100644 --- a/tools/detect_face_rotate.py +++ b/tools/detect_face_rotate.py @@ -15,8 +15,10 @@ from anime_face_detector import create_detector from tqdm import tqdm import numpy as np -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) KP_REYE = 11 KP_LEYE = 19 diff --git a/tools/latent_upscaler.py b/tools/latent_upscaler.py index 8b09f02e1..4fa599e29 100644 --- a/tools/latent_upscaler.py +++ b/tools/latent_upscaler.py @@ -14,8 +14,10 @@ from torch import nn from tqdm import tqdm from PIL import Image -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) class ResidualBlock(nn.Module): def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1): diff --git a/tools/merge_models.py b/tools/merge_models.py index 975f0413d..8f1fbf2f8 100644 --- a/tools/merge_models.py +++ b/tools/merge_models.py @@ -5,8 +5,10 @@ from safetensors import safe_open from safetensors.torch import load_file, save_file from tqdm import tqdm -from library.utils import get_my_logger -logger = 
get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def is_unet_key(key): # VAE or TextEncoder, the last one is for SDXL diff --git a/tools/original_control_net.py b/tools/original_control_net.py index 2aa0a2d90..21392dc47 100644 --- a/tools/original_control_net.py +++ b/tools/original_control_net.py @@ -7,8 +7,10 @@ from library.original_unet import UNet2DConditionModel, SampleOutput import library.model_util as model_util -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) class ControlNetInfo(NamedTuple): unet: Any diff --git a/tools/resize_images_to_resolution.py b/tools/resize_images_to_resolution.py index 979f7f141..b8069fc1d 100644 --- a/tools/resize_images_to_resolution.py +++ b/tools/resize_images_to_resolution.py @@ -6,8 +6,10 @@ import math from PIL import Image import numpy as np -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False): # Split the max_resolution string by "," and strip any whitespaces diff --git a/tools/show_metadata.py b/tools/show_metadata.py index 784cdf3ad..05bfbe0a4 100644 --- a/tools/show_metadata.py +++ b/tools/show_metadata.py @@ -1,8 +1,10 @@ import json import argparse from safetensors import safe_open -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, required=True) diff --git a/train_controlnet.py b/train_controlnet.py index 5970acd75..1548df94a 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -40,8 +40,10 @@ pyramid_noise_like, apply_noise_offset, ) -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # TODO 他のスクリプトと共通化する def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): diff --git a/train_db.py b/train_db.py index a80b83ef2..89a75b0c1 100644 --- a/train_db.py +++ b/train_db.py @@ -40,8 +40,10 @@ scale_v_prediction_loss_like_noise_prediction, apply_debiased_estimation, ) -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # perlin_noise, diff --git a/train_network.py b/train_network.py index b8fb82e6b..1c105f04d 100644 --- a/train_network.py +++ b/train_network.py @@ -46,8 +46,10 @@ add_v_prediction_like_loss, apply_debiased_estimation, ) -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) class NetworkTrainer: def __init__(self): diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 9bfbb5e18..ebeb1263f 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -37,8 +37,10 @@ add_v_prediction_like_loss, apply_debiased_estimation, ) -from library.utils 
import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) imagenet_templates_small = [ "a photo of a {}",
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 045136e3f..a4ca9fe85 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -38,8 +38,10 @@ ) import library.original_unet as original_unet from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI -from library.utils import get_my_logger -logger = get_my_logger(__name__) +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) imagenet_templates_small = [ "a photo of a {}",
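From this patch on, every file shares one root configuration: setup_logging() attaches a single handler to logging.root, and the logging.root.handlers guard makes repeated calls a no-op, so entry points and library modules can all call it safely, while logging.getLogger(__name__) stamps each record with the originating module's name. A minimal sketch of how a new script would adopt the same three-line bootstrap; the script body here is hypothetical:

    from library.utils import setup_logging

    setup_logging()  # configures logging.root once; no-op if already configured

    import logging

    logger = logging.getLogger(__name__)  # module name appears in each record

    def main():
        logger.info("started")  # emitted through the shared root handler

    if __name__ == "__main__":
        main()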
{ds_timesteps_1}") continue m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink depth 2 ds_depth_2 = int(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink depth 2: {ds_depth_2}") + logger.info(f"deep shrink depth 2: {ds_depth_2}") continue m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink timesteps 2 ds_timesteps_2 = int(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink timesteps 2: {ds_timesteps_2}") + logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") continue m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink ratio ds_ratio = float(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink ratio: {ds_ratio}") + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink ratio: {ds_ratio}") continue except ValueError as ex: diff --git a/library/original_unet.py b/library/original_unet.py index a88efda1e..d4f978010 100644 --- a/library/original_unet.py +++ b/library/original_unet.py @@ -1713,14 +1713,14 @@ def __call__(self, *args, **kwargs): def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5): if ds_depth_1 is None: - print("Deep Shrink is disabled.") + logger.info("Deep Shrink is disabled.") self.ds_depth_1 = None self.ds_timesteps_1 = None self.ds_depth_2 = None self.ds_timesteps_2 = None self.ds_ratio = None else: - print( + logger.info( f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]" ) self.ds_depth_1 = ds_depth_1 diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index fc765b625..f03f1bae5 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -134,7 +134,7 @@ def convert_key(key): # temporary workaround for text_projection.weight.weight for Playground-v2 if "text_projection.weight.weight" in new_sd: - print(f"convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight") + logger.info("convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight") new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"] del new_sd["text_projection.weight.weight"] diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py index de01edb99..673cf9f65 100644 --- a/library/sdxl_original_unet.py +++ b/library/sdxl_original_unet.py @@ -1138,14 +1138,14 @@ def __call__(self, *args, **kwargs): def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5): if ds_depth_1 is None: - print("Deep Shrink is disabled.") + logger.info("Deep Shrink is disabled.") self.ds_depth_1 = None self.ds_timesteps_1 = None self.ds_depth_2 = None self.ds_timesteps_2 = None self.ds_ratio = None else: - print( + logger.info( f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]" ) self.ds_depth_1 = ds_depth_1 diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 3e6961326..afeac78a6 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -2421,35 +2421,35 @@ def scale_and_round(x): m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink depth 1 ds_depth_1 = int(m.group(1)) - print(f"deep shrink depth 1: {ds_depth_1}") + 
logger.info(f"deep shrink depth 1: {ds_depth_1}") continue m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink timesteps 1 ds_timesteps_1 = int(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink timesteps 1: {ds_timesteps_1}") + logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}") continue m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink depth 2 ds_depth_2 = int(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink depth 2: {ds_depth_2}") + logger.info(f"deep shrink depth 2: {ds_depth_2}") continue m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink timesteps 2 ds_timesteps_2 = int(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink timesteps 2: {ds_timesteps_2}") + logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") continue m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink ratio ds_ratio = float(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink ratio: {ds_ratio}") + logger.info(f"deep shrink ratio: {ds_ratio}") continue except ValueError as ex: diff --git a/tools/canny.py b/tools/canny.py index 5e0806898..f2190975c 100644 --- a/tools/canny.py +++ b/tools/canny.py @@ -1,6 +1,10 @@ import argparse import cv2 +import logging +from library.utils import setup_logging +setup_logging() +logger = logging.getLogger(__name__) def canny(args): img = cv2.imread(args.input) @@ -10,7 +14,7 @@ def canny(args): # canny_img = 255 - canny_img cv2.imwrite(args.output, canny_img) - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: