v22.6.2 #1995

Merged 91 commits on Feb 24, 2024

Changes from 1 commit
Commits (91)
6849546
add gradual latent
kohya-ss Nov 23, 2023
610566f
Update README.md
kohya-ss Nov 23, 2023
2897a89
Merge branch 'dev' into gradual_latent_hires_fix
kohya-ss Nov 26, 2023
298c6c2
fix gradual latent cannot be disabled
kohya-ss Nov 26, 2023
2c50ea0
apply unsharp mask
kohya-ss Nov 27, 2023
29b6fa6
add unsharp mask
kohya-ss Nov 28, 2023
2952bca
fix strength error
kohya-ss Dec 1, 2023
7a4e507
add target_x flag (not sure this impl is correct)
kohya-ss Dec 3, 2023
e8c3a02
Merge branch 'dev' into gradual_latent_hires_fix
kohya-ss Dec 7, 2023
9278031
Merge branch 'dev' into gradual_latent_hires_fix
kohya-ss Dec 11, 2023
07ef03d
fix controlnet to work with gradual latent
kohya-ss Dec 11, 2023
d61ecb2
enable comment in prompt file, record raw prompt to metadata
kohya-ss Dec 11, 2023
da9b34f
Merge branch 'dev' into gradual_latent_hires_fix
kohya-ss Jan 4, 2024
2e4bee6
Log accelerator device
akx Jan 16, 2024
afc3870
Refactor memory cleaning into a single function
akx Jan 16, 2024
478156b
Refactor device determination to function; add MPS fallback
akx Jan 16, 2024
8f6f734
Merge branch 'dev' into gradual_latent_hires_fix
kohya-ss Jan 27, 2024
ccc3a48
Update IPEX Libs
Disty0 Jan 28, 2024
d4b9568
fix broken import in svd_merge_lora script
mgz-dev Jan 28, 2024
988dee0
IPEX torch.tensor FP64 workaround
Disty0 Jan 29, 2024
9d7729c
Merge pull request #1086 from Disty0/dev
kohya-ss Jan 31, 2024
7f948db
Merge pull request #1087 from mgz-dev/fix-imports-on-svd_merge_lora
kohya-ss Jan 31, 2024
2ca4d0c
Merge pull request #1054 from akx/mps
kohya-ss Jan 31, 2024
a6a2b5a
Fix IPEX support and add XPU device to device_utils
Disty0 Jan 31, 2024
9f0f0d5
Merge pull request #1092 from Disty0/dev_device_support
kohya-ss Feb 1, 2024
5cca1fd
add highvram option and do not clear cache in caching latents
kohya-ss Feb 1, 2024
1567ce1
Enable distributed sample image generation on multi-GPU environment (#…
DKnight54 Feb 3, 2024
11aced3
simplify multi-GPU sample generation
kohya-ss Feb 3, 2024
2f9a344
fix typo
kohya-ss Feb 3, 2024
6269682
unification of gen scripts for SD and SDXL, work in progress
kohya-ss Feb 3, 2024
bf2de56
fix formatting in resize_lora.py
mgz-dev Feb 4, 2024
1492bcb
add --new_conv_rank option
mgz-dev Feb 4, 2024
e793d77
reduce peak VRAM in sample gen
kohya-ss Feb 4, 2024
5f6bf29
Replace print with logger if they are logs (#905)
shirayu Feb 4, 2024
6279b33
fallback to basic logging if rich is not installed
kohya-ss Feb 4, 2024
efd3b58
Add logging arguments and update logging setup
kohya-ss Feb 4, 2024
74fe045
add comment for get_preferred_device
kohya-ss Feb 8, 2024
9b8ea12
update log initialization without rich
kohya-ss Feb 8, 2024
055f02e
add logging args for training scripts
kohya-ss Feb 8, 2024
5d9e287
make rich to output to stderr instead of stdout
kohya-ss Feb 8, 2024
7202596
log to print tag frequencies
kohya-ss Feb 10, 2024
f897d55
Merge pull request #1113 from kohya-ss/dev_multi_gpu_sample_gen
kohya-ss Feb 11, 2024
75ecb04
Merge branch 'dev' into dev_device_support
kohya-ss Feb 11, 2024
e24d960
add clean_memory_on_device and use it from training
kohya-ss Feb 12, 2024
e579648
fix help for highvram arg
kohya-ss Feb 12, 2024
672851e
Merge branch 'dev' into dev_improve_log
kohya-ss Feb 12, 2024
20ae603
Merge branch 'dev' into gradual_latent_hires_fix
kohya-ss Feb 12, 2024
35c6053
Merge pull request #1104 from kohya-ss/dev_improve_log
kohya-ss Feb 12, 2024
98f42d3
Merge branch 'dev' into gradual_latent_hires_fix
kohya-ss Feb 12, 2024
c748719
fix indent
kohya-ss Feb 12, 2024
358ca20
Merge branch 'dev' into dev_device_support
kohya-ss Feb 12, 2024
d3745db
add args for logging
kohya-ss Feb 12, 2024
cbe9c5d
support deep shrink with regional lora, add prompter module
kohya-ss Feb 12, 2024
41d32c0
Merge pull request #1117 from kohya-ss/gradual_latent_hires_fix
kohya-ss Feb 12, 2024
93bed60
fix to work `--console_log_xxx` options
kohya-ss Feb 12, 2024
71ebcc5
update readme and gradual latent doc
kohya-ss Feb 12, 2024
baa0e97
Merge branch 'dev' into dev_device_support
kohya-ss Feb 17, 2024
42f3318
Merge pull request #1116 from kohya-ss/dev_device_support
kohya-ss Feb 17, 2024
75e4a95
update readme
kohya-ss Feb 17, 2024
d1fb480
format by black
kohya-ss Feb 18, 2024
6a9d9be
Fix lora_network_weights
bmaltais Feb 18, 2024
83bf2e0
chore(docker): rewrite Dockerfile
jim60105 Feb 17, 2024
dc94512
ci(docker): Add docker CI
jim60105 Feb 17, 2024
5dc9db6
Revert "ci(docker): Add docker CI"
jim60105 Feb 17, 2024
8330597
chore(docker): Add EXPOSE ports and change final base image to python…
jim60105 Feb 17, 2024
543d12f
chore(docker): fix dependencies for slim image
jim60105 Feb 17, 2024
d7add28
chore(docker): Add label
jim60105 Feb 18, 2024
a6f1ed2
fix dylora create_modules error
tamlog06 Feb 18, 2024
07116dc
Update options.md
mikeboensel Feb 18, 2024
a63c49c
Update options.md
mikeboensel Feb 18, 2024
5b19748
Update options.md
mikeboensel Feb 18, 2024
39e3a4b
Label clarifications
mikeboensel Feb 18, 2024
f71b3cf
Merge pull request #1978 from mikeboensel/patch-1
bmaltais Feb 18, 2024
7a49955
Merge pull request #1979 from mikeboensel/patch-2
bmaltais Feb 18, 2024
6a6c932
Merge pull request #1980 from mikeboensel/patch-3
bmaltais Feb 18, 2024
78e2df1
Merge pull request #1976 from jim60105/master
bmaltais Feb 18, 2024
2d0ed8e
Merge pull request #1981 from mikeboensel/patch-4
bmaltais Feb 18, 2024
86279c8
Merge branch 'dev' into DyLoRA-xl
kohya-ss Feb 24, 2024
488d187
Merge pull request #1126 from tamlog06/DyLoRA-xl
kohya-ss Feb 24, 2024
f413201
fix to work with cpu_count() == 1 closes #1134
kohya-ss Feb 24, 2024
24092e6
update einops to 0.7.0 #1122
kohya-ss Feb 24, 2024
fb9110b
format by black
kohya-ss Feb 24, 2024
0e70360
Merge branch 'dev' into resize_lora-add-rank-for-conv
kohya-ss Feb 24, 2024
738c397
Merge pull request #1102 from mgz-dev/resize_lora-add-rank-for-conv
kohya-ss Feb 24, 2024
52b3799
fix format, add new conv rank to metadata comment
kohya-ss Feb 24, 2024
8b7c142
some log output to print
kohya-ss Feb 24, 2024
81e8af6
fix ipex init
kohya-ss Feb 24, 2024
a21218b
update readme
kohya-ss Feb 24, 2024
e69d341
Merge pull request #1136 from kohya-ss/dev
kohya-ss Feb 24, 2024
a20c2bd
Merge branch 'main' of https://github.com/kohya-ss/sd-scripts into dev
bmaltais Feb 24, 2024
822d94c
Merge branch 'dev' of https://github.com/bmaltais/kohya_ss into dev
bmaltais Feb 24, 2024
support deep shrink with regional lora, add prompter module
kohya-ss committed Feb 12, 2024

Verified: this commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit cbe9c5dc068cd81c5f7d53c7aeba601d5241e7f9
168 changes: 130 additions & 38 deletions gen_img.py
@@ -3,6 +3,8 @@
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
import glob
import importlib
import importlib.util
import sys
import inspect
import time
import zipfile
@@ -333,6 +335,10 @@ def __init__(
self.scheduler = scheduler
self.safety_checker = None

self.clip_vision_model: CLIPVisionModelWithProjection = None
self.clip_vision_processor: CLIPImageProcessor = None
self.clip_vision_strength = 0.0

# Textual Inversion
self.token_replacements_list = []
for _ in range(len(self.text_encoders)):
@@ -419,6 +425,7 @@ def __call__(
callback_steps: Optional[int] = 1,
img2img_noise=None,
clip_guide_images=None,
emb_normalize_mode: str = "original",
**kwargs,
):
# TODO support secondary prompt
@@ -493,6 +500,7 @@ def __call__(
clip_skip=self.clip_skip,
token_replacer=token_replacer,
device=self.device,
emb_normalize_mode=emb_normalize_mode,
**kwargs,
)
tes_text_embs.append(text_embeddings)
@@ -508,6 +516,7 @@ def __call__(
clip_skip=self.clip_skip,
token_replacer=token_replacer,
device=self.device,
emb_normalize_mode=emb_normalize_mode,
**kwargs,
)
tes_real_uncond_embs.append(real_uncond_embeddings)
@@ -1099,7 +1108,7 @@ def get_unweighted_text_embeddings(
# in sdxl, value of clip_skip is same for Text Encoder 1 and 2
enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
text_embedding = enc_out["hidden_states"][-clip_skip]
if not is_sdxl:  # SD 1.5 requires final_layer_norm
text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
if pool is None:
pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided
@@ -1122,7 +1131,7 @@ def get_unweighted_text_embeddings(
else:
enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
text_embeddings = enc_out["hidden_states"][-clip_skip]
if not is_sdxl:  # SD 1.5 requires final_layer_norm
text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this
if pool is not None:
@@ -1143,6 +1152,7 @@ def get_weighted_text_embeddings(
clip_skip: int = 1,
token_replacer=None,
device=None,
emb_normalize_mode: Optional[str] = "original", # "original", "abs", "none"
**kwargs,
):
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
@@ -1239,16 +1249,34 @@ def get_weighted_text_embeddings(
# assign weights to the prompts and normalize in the sense of mean
# TODO: should we normalize by chunk or in a whole (current implementation)?
# → the whole is probably fine

if (not skip_parsing) and (not skip_weighting):
if emb_normalize_mode == "abs":
previous_mean = text_embeddings.float().abs().mean(axis=[-2, -1]).to(text_embeddings.dtype)
text_embeddings *= prompt_weights.unsqueeze(-1)
current_mean = text_embeddings.float().abs().mean(axis=[-2, -1]).to(text_embeddings.dtype)
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
if uncond_prompt is not None:
previous_mean = uncond_embeddings.float().abs().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
uncond_embeddings *= uncond_weights.unsqueeze(-1)
current_mean = uncond_embeddings.float().abs().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

elif emb_normalize_mode == "none":
text_embeddings *= prompt_weights.unsqueeze(-1)
if uncond_prompt is not None:
uncond_embeddings *= uncond_weights.unsqueeze(-1)

else: # "original"
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
text_embeddings *= prompt_weights.unsqueeze(-1)
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
if uncond_prompt is not None:
previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
uncond_embeddings *= uncond_weights.unsqueeze(-1)
current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

if uncond_prompt is not None:
return text_embeddings, text_pool, uncond_embeddings, uncond_pool, prompt_tokens
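To see the three emb_normalize_mode branches side by side, here is a condensed standalone restatement of the weighting-and-renormalization logic above (the function name and tensor shapes are our assumptions, not part of the PR):

import torch

def weight_embeddings(emb: torch.Tensor, weights: torch.Tensor, mode: str = "original") -> torch.Tensor:
    # emb: (batch, tokens, dim) text embeddings; weights: (batch, tokens) per-token weights
    if mode == "none":
        return emb * weights.unsqueeze(-1)
    # "original" preserves the signed mean, "abs" the mean magnitude, so that
    # applying per-token weights does not change the embeddings' overall scale
    stat = emb.float().abs() if mode == "abs" else emb.float()
    previous_mean = stat.mean(axis=[-2, -1]).to(emb.dtype)
    emb = emb * weights.unsqueeze(-1)
    stat = emb.float().abs() if mode == "abs" else emb.float()
    current_mean = stat.mean(axis=[-2, -1]).to(emb.dtype)
    return emb * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)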
@@ -1427,6 +1455,27 @@ class BatchData(NamedTuple):
ext: BatchDataExt


class ListPrompter:
def __init__(self, prompts: List[str]):
self.prompts = prompts
self.index = 0

def shuffle(self):
random.shuffle(self.prompts)

def __len__(self):
return len(self.prompts)

def __call__(self, *args, **kwargs):
if self.index >= len(self.prompts):
self.index = 0 # reset
return None

prompt = self.prompts[self.index]
self.index += 1
return prompt


def main(args):
if args.fp16:
dtype = torch.float16
@@ -1951,15 +2000,35 @@ def __getattr__(self, item):
token_embeds2[token_id] = embed

# get the prompts
prompt_list = None
if args.from_file is not None:
print(f"reading prompts from {args.from_file}")
with open(args.from_file, "r", encoding="utf-8") as f:
prompt_list = f.read().splitlines()
prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"]
prompter = ListPrompter(prompt_list)

elif args.from_module is not None:

def load_module_from_path(module_name, file_path):
spec = importlib.util.spec_from_file_location(module_name, file_path)
if spec is None:
raise ImportError(f"Module '{module_name}' cannot be loaded from '{file_path}'")
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module

print(f"reading prompts from module: {args.from_module}")
prompt_module = load_module_from_path("prompt_module", args.from_module)

prompter = prompt_module.get_prompter(args, pipe, networks)

elif args.prompt is not None:
prompter = ListPrompter([args.prompt])

else:
prompter = None  # interactive mode

if args.interactive:
args.n_iter = 1
@@ -2026,14 +2095,16 @@ def resize_images(imgs, size):
mask_images = None

# if no prompt is given, get it from the images' PngInfo
if init_images is not None and prompter is None and not args.interactive:
print("get prompts from images' metadata")
prompt_list = []
for img in init_images:
if "prompt" in img.text:
prompt = img.text["prompt"]
if "negative-prompt" in img.text:
prompt += " --n " + img.text["negative-prompt"]
prompt_list.append(prompt)
prompter = ListPrompter(prompt_list)

# repeat the specified number of times to match prompts with images (amplifies the images)
l = []
@@ -2105,15 +2176,18 @@ def resize_images(imgs, size):
else:
guide_images = None

# create a new random number generator
if args.seed is not None:
if prompt_list and len(prompt_list) == 1 and args.images_per_prompt == 1:
# use the seed argument as-is
def fixed_seed(*args, **kwargs):
return args.seed

seed_random = SimpleNamespace(randint=fixed_seed)
else:
seed_random = random.Random(args.seed)
else:
seed_random = random.Random()

# set the default image size: these values are ignored for img2img (or the image is already resized to W*H)
if args.W is None:
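One subtlety in the fixed-seed branch above: inside fixed_seed, the *args parameter shadows the outer argparse namespace, so args.seed there resolves against the call's positional tuple rather than the CLI seed. A shadow-free sketch of the same idea (the helper name is ours, not part of the PR):

from types import SimpleNamespace

def make_fixed_seed_random(cli_seed: int) -> SimpleNamespace:
    # mimics random.Random just enough for the call sites here: randint()
    # ignores its bounds and always returns the command-line seed
    def fixed_seed(*_args, **_kwargs):
        return cli_seed

    return SimpleNamespace(randint=fixed_seed)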
@@ -2127,11 +2201,14 @@ def resize_images(imgs, size):

for gen_iter in range(args.n_iter):
print(f"iteration {gen_iter+1}/{args.n_iter}")
if args.iter_same_seed:
iter_seed = seed_random.randint(0, 2**32 - 1)
else:
iter_seed = None

# shuffle prompt list
if args.shuffle_prompts:
prompter.shuffle()

# function for batch processing
def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
@@ -2352,7 +2429,8 @@ def scale_and_round(x):
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
n.set_multiplier(m)
if regional_network:
# TODO: ds_ratio should be taken from the batch
n.set_current_generation(batch_size, num_sub_prompts, width, height, shared, unet.ds_ratio)

if not regional_network and network_pre_calc:
for n in networks:
@@ -2386,6 +2464,7 @@ def scale_and_round(x):
return_latents=return_latents,
clip_prompts=clip_prompts,
clip_guide_images=guide_images,
emb_normalize_mode=args.emb_normalize_mode,
)
if highres_1st and not args.highres_fix_save_1st: # return images or latents
return images
@@ -2451,8 +2530,8 @@ def scale_and_round(x):
prompt_index = 0
global_step = 0
batch_data = []
while True:
if args.interactive:
# interactive
valid = False
while not valid:
@@ -2466,7 +2545,9 @@ def scale_and_round(x):
if not valid: # EOF, end app
break
else:
raw_prompt = prompter(args, pipe, seed_random, iter_seed, prompt_index, global_step)
if raw_prompt is None:
break

# sd-dynamic-prompts like variants:
# count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration)
@@ -2513,7 +2594,8 @@ def scale_and_round(x):

prompt_args = raw_prompt.strip().split(" --")
prompt = prompt_args[0]
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
length = len(prompter) if hasattr(prompter, "__len__") else 0
print(f"prompt {prompt_index+1}/{length}: {prompt}")

for parg in prompt_args[1:]:
try:
@@ -2731,23 +2813,17 @@ def scale_and_round(x):

# prepare seed
if seeds is not None: # given in prompt
# if num_images_per_prompt is large, seeds may run short, so use the previous one for the shortfall
if len(seeds) > 0:
seed = seeds.pop(0)
else:
if args.iter_same_seed:
seed = iter_seed
else:
seed = None  # clear the previous one

if seed is None:
seed = seed_random.randint(0, 2**32 - 1)
if args.interactive:
print(f"seed: {seed}")

@@ -2853,6 +2929,15 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む",
)
parser.add_argument(
"--from_module",
type=str,
default=None,
help="if specified, load prompts from this module / 指定時はプロンプトをモジュールから読み込む",
)
parser.add_argument(
"--prompter_module_args", type=str, default=None, help="args for prompter module / prompterモジュールの引数"
)
parser.add_argument(
"--interactive",
action="store_true",
@@ -3067,6 +3152,13 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="max embedding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる",
)
parser.add_argument(
"--emb_normalize_mode",
type=str,
default="original",
choices=["original", "none", "abs"],
help="embedding normalization mode / embeddingの正規化モード",
)
parser.add_argument(
"--guide_image_path", type=str, default=None, nargs="*", help="image to ControlNet / ControlNetでガイドに使う画像"
)
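For readers who want to try the new --from_module option, here is a minimal sketch of a prompter module, inferred from how gen_img.py calls get_prompter and from the ListPrompter protocol above; the file name and prompt strings are hypothetical:

# my_prompter.py -- hypothetical module for --from_module
import random

class CyclingPrompter:
    def __init__(self, prompts):
        self.prompts = prompts
        self.index = 0

    def __len__(self):  # optional; only used for the "prompt i/n" log line
        return len(self.prompts)

    def shuffle(self):  # optional; called when --shuffle_prompts is set
        random.shuffle(self.prompts)

    def __call__(self, args, pipe, seed_random, iter_seed, prompt_index, global_step):
        if self.index >= len(self.prompts):
            self.index = 0  # reset, and signal the end of this iteration
            return None
        prompt = self.prompts[self.index]
        self.index += 1
        return prompt

def get_prompter(args, pipe, networks):
    # "--n" introduces the negative prompt, as in prompt files
    return CyclingPrompter(["a cat --n lowres", "a dog --n lowres"])

It would then be invoked along the lines of python gen_img.py ... --from_module my_prompter.py (all other flags omitted).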
38 changes: 31 additions & 7 deletions networks/lora.py
@@ -12,8 +12,10 @@
import torch
import re
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
@@ -248,7 +250,8 @@ def get_mask_for_x(self, x):
if mask is None:
# raise ValueError(f"mask is None for resolution {area}")
# emb_layers in SDXL doesn't have mask
# logger.info(f"mask is None for resolution {area}, {x.size()}")
# if "emb" not in self.lora_name:
# print(f"mask is None for resolution {self.lora_name}, {area}, {x.size()}")
mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts
if len(x.size()) != 4:
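For intuition, the fallback above returns a uniform all-ones mask shaped to broadcast against the LoRA output; a tiny standalone illustration (the tensor shapes are hypothetical):

import torch

def fallback_mask(x: torch.Tensor, num_sub_prompts: int) -> torch.Tensor:
    # 2D activations (e.g. emb_layers in SDXL) get a (1, features) mask;
    # otherwise keep the middle dims and broadcast over the channel dim
    mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
    return torch.ones(mask_size, dtype=x.dtype, device=x.device) / num_sub_prompts

print(fallback_mask(torch.zeros(2, 1280), 3).shape)       # torch.Size([1, 1280])
print(fallback_mask(torch.zeros(2, 4096, 320), 3).shape)  # torch.Size([1, 4096, 1])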
@@ -265,7 +268,9 @@ def regional_forward(self, x):
# apply mask for LoRA result
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
mask = self.get_mask_for_x(lx)
# logger.info(f"regional {self.lora_name} {self.network.sub_prompt_index} {lx.size()} {mask.size()}")
# print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
# if mask.ndim > lx.ndim: # in some resolution, lx is 2d and mask is 3d (the reason is not checked)
# mask = mask.squeeze(-1)
lx = lx * mask

x = self.org_forward(x)
@@ -514,7 +519,9 @@ def parse_floats(s):
len(block_dims) == num_total_blocks
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
else:
logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
logger.warning(
f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
)
block_dims = [network_dim] * num_total_blocks

if block_alphas is not None:
@@ -792,17 +799,23 @@ def __init__(
logger.info(f"create LoRA network from weights")
elif block_dims is not None:
logger.info(f"create LoRA network from block_dims")
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
logger.info(f"block_dims: {block_dims}")
logger.info(f"block_alphas: {block_alphas}")
if conv_block_dims is not None:
logger.info(f"conv_block_dims: {conv_block_dims}")
logger.info(f"conv_block_alphas: {conv_block_alphas}")
else:
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
if self.conv_lora_dim is not None:
logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
logger.info(
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
)

# create module instances
def create_modules(
@@ -929,6 +942,10 @@ def set_multiplier(self, multiplier):
for lora in self.text_encoder_loras + self.unet_loras:
lora.multiplier = self.multiplier

def set_enabled(self, is_enabled):
for lora in self.text_encoder_loras + self.unet_loras:
lora.enabled = is_enabled

def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
@@ -1116,7 +1133,7 @@ def set_region(self, sub_prompt_index, is_last_network, mask):
for lora in self.text_encoder_loras + self.unet_loras:
lora.set_network(self)

def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared, ds_ratio=None):
self.batch_size = batch_size
self.num_sub_prompts = num_sub_prompts
self.current_size = (height, width)
@@ -1142,6 +1159,13 @@ def resize_add(mh, mw):
resize_add(h, w)
if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2
resize_add(h + h % 2, w + w % 2)

# deep shrink
if ds_ratio is not None:
hd = int(h * ds_ratio)
wd = int(w * ds_ratio)
resize_add(hd, wd)

h = (h + 1) // 2
w = (w + 1) // 2
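The deep-shrink change above widens the set of mask resolutions that a regional LoRA network prepares, since with deep shrink the U-Net runs its early blocks at a reduced size. A standalone sketch of that enumeration (the starting latent size and the loop bound are our assumptions; the real code resizes the region mask via resize_add rather than merely collecting sizes):

def enumerate_mask_sizes(height: int, width: int, ds_ratio: float | None = None):
    sizes = []
    h, w = height // 8, width // 8  # assumed latent size (VAE stride 8)
    while h > 1 and w > 1:
        sizes.append((h, w))
        if h % 2 == 1 or w % 2 == 1:  # extra shape if h/w is not divisible by 2
            sizes.append((h + h % 2, w + w % 2))
        if ds_ratio is not None:  # deep shrink resolution for this level
            sizes.append((int(h * ds_ratio), int(w * ds_ratio)))
        h = (h + 1) // 2
        w = (w + 1) // 2
    return sizes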