Skip to content

Commit

Permalink
Merge branch 'master' into sd
Browse files Browse the repository at this point in the history
  • Loading branch information
gesen2egee committed Feb 8, 2024
2 parents c1f95d7 + 0412be4 commit 9d58b3c
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 10 deletions.
45 changes: 42 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,9 +1,48 @@
logs
__pycache__
wd14_tagger_model
# Python
venv
__pycache__
*.egg-info
build
wd14_tagger_model

# IDE and Editor specific
.vscode

# CUDNN for Windows
cudnn_windows

# Cache and temporary files
.cache
.DS_Store

# Scripts and executables
locon
gui-user.bat
gui-user.ps1

# Version control
SmilingWolf
wandb

# Setup and logs
setup.log
logs

# Miscellaneous
uninstall.txt

# Test files
test/output
test/log*
test/*.json
test/ft

# Temporary requirements
requirements_tmp_for_setup.txt

# Version specific
0.13.3

*.npz
*.bat
presets/*/user_presets/*
8 changes: 7 additions & 1 deletion library/config_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ class BaseSubsetParams:
caption_separator: str = (",",)
keep_tokens: int = 0
keep_tokens_separator: str = (None,)
use_object_template: bool = False
use_style_template: bool = False
color_aug: bool = False
flip_aug: bool = False
face_crop_aug_range: Optional[Tuple[float, float]] = None
Expand All @@ -65,7 +67,7 @@ class BaseSubsetParams:
caption_suffix: Optional[str] = None
caption_dropout_rate: float = 0.0
caption_dropout_every_n_epochs: int = 0
caption_tag_dropout_rate: float = 0.0
caption_tag_dropout_rate: float = 0.0
token_warmup_min: int = 1
token_warmup_step: float = 0

Expand Down Expand Up @@ -178,6 +180,8 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
"shuffle_caption": bool,
"keep_tokens": int,
"keep_tokens_separator": str,
"use_object_template": bool,
"use_style_template": bool,
"token_warmup_min": int,
"token_warmup_step": Any(float, int),
"caption_prefix": str,
Expand Down Expand Up @@ -501,6 +505,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
shuffle_caption: {subset.shuffle_caption}
keep_tokens: {subset.keep_tokens}
keep_tokens_separator: {subset.keep_tokens_separator}
use_object_template: bool = {subset.use_object_template}
use_style_template: bool = {subset.use_style_template}
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
Expand Down
87 changes: 85 additions & 2 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,59 @@

IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]

# Textual-inversion style prompt templates for "object" subjects, selected
# when a subset enables use_object_template. In process_caption a random
# entry is prepended (as a list element) to the kept caption tokens; note
# every entry ends with a trailing space so the subject reads naturally
# after it. Order/contents matter: random.choice samples uniformly over
# this exact list.
imagenet_templates_small = [
"a photo of a ",
"a rendering of a ",
"a cropped photo of the ",
"the photo of a ",
"a photo of a clean ",
"a photo of a dirty ",
"a dark photo of the ",
"a photo of my ",
"a photo of the cool ",
"a close-up photo of a ",
"a bright photo of the ",
"a cropped photo of a ",
"a photo of the ",
"a good photo of the ",
"a photo of one ",
"a close-up photo of the ",
"a rendition of the ",
"a photo of the clean ",
"a rendition of a ",
"a photo of a nice ",
"a good photo of a ",
"a photo of the nice ",
"a photo of the small ",
"a photo of the weird ",
"a photo of the large ",
"a photo of a cool ",
"a photo of a small ",
]

# Textual-inversion style prompt templates for "style" subjects, selected
# when a subset enables use_style_template; a random entry is prepended to
# the kept caption tokens in process_caption. Entries end with a trailing
# space on purpose.
# NOTE(review): "a cropped painting in the style of " and "a close-up
# painting in the style of " each appear twice — presumably inherited from
# the upstream template list. The duplicates double those entries' weight
# under random.choice; confirm before deduplicating, as removal changes
# the sampling distribution.
imagenet_style_templates_small = [
"a painting in the style of ",
"a rendering in the style of ",
"a cropped painting in the style of ",
"the painting in the style of ",
"a clean painting in the style of ",
"a dirty painting in the style of ",
"a dark painting in the style of ",
"a picture in the style of ",
"a cool painting in the style of ",
"a close-up painting in the style of ",
"a bright painting in the style of ",
"a cropped painting in the style of ",
"a good painting in the style of ",
"a close-up painting in the style of ",
"a rendition in the style of ",
"a nice painting in the style of ",
"a small painting in the style of ",
"a weird painting in the style of ",
"a large painting in the style of ",
]


try:
import pillow_avif

Expand Down Expand Up @@ -360,6 +413,8 @@ def __init__(
caption_separator: str,
keep_tokens: int,
keep_tokens_separator: str,
use_object_template: bool,
use_style_template: bool,
color_aug: bool,
flip_aug: bool,
face_crop_aug_range: Optional[Tuple[float, float]],
Expand All @@ -378,6 +433,8 @@ def __init__(
self.caption_separator = caption_separator
self.keep_tokens = keep_tokens
self.keep_tokens_separator = keep_tokens_separator
self.use_object_template = use_object_template
self.use_style_template = use_style_template
self.color_aug = color_aug
self.flip_aug = flip_aug
self.face_crop_aug_range = face_crop_aug_range
Expand Down Expand Up @@ -406,6 +463,8 @@ def __init__(
caption_separator: str,
keep_tokens,
keep_tokens_separator,
use_object_template,
use_style_template,
color_aug,
flip_aug,
face_crop_aug_range,
Expand All @@ -427,6 +486,8 @@ def __init__(
caption_separator,
keep_tokens,
keep_tokens_separator,
use_object_template,
use_style_template,
color_aug,
flip_aug,
face_crop_aug_range,
Expand Down Expand Up @@ -462,6 +523,8 @@ def __init__(
caption_separator,
keep_tokens,
keep_tokens_separator,
use_object_template,
use_style_template,
color_aug,
flip_aug,
face_crop_aug_range,
Expand All @@ -483,6 +546,8 @@ def __init__(
caption_separator,
keep_tokens,
keep_tokens_separator,
use_object_template,
use_style_template,
color_aug,
flip_aug,
face_crop_aug_range,
Expand Down Expand Up @@ -515,6 +580,8 @@ def __init__(
caption_separator,
keep_tokens,
keep_tokens_separator,
use_object_template,
use_style_template,
color_aug,
flip_aug,
face_crop_aug_range,
Expand All @@ -536,6 +603,8 @@ def __init__(
caption_separator,
keep_tokens,
keep_tokens_separator,
use_object_template,
use_style_template,
color_aug,
flip_aug,
face_crop_aug_range,
Expand Down Expand Up @@ -659,7 +728,6 @@ def process_caption(self, subset: BaseSubset, caption):
caption = subset.caption_prefix + " " + caption
if subset.caption_suffix:
caption = caption + " " + subset.caption_suffix

# dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate
is_drop_out = (
Expand Down Expand Up @@ -688,7 +756,10 @@ def process_caption(self, subset: BaseSubset, caption):
if subset.keep_tokens > 0:
fixed_tokens = flex_tokens[: subset.keep_tokens]
flex_tokens = tokens[subset.keep_tokens :]

if subset.use_object_template or subset.use_style_template:
imagenet_templates = imagenet_templates_small if subset.use_object_template else imagenet_style_templates_small
imagenet_template = [random.choice(imagenet_templates)]
caption = imagenet_template + fixed_tokens
if subset.token_warmup_step < 1: # 初回に上書きする
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
Expand Down Expand Up @@ -1791,6 +1862,8 @@ def __init__(
subset.caption_separator,
subset.keep_tokens,
subset.keep_tokens_separator,
subset.use_object_template,
subset.use_style_template,
subset.color_aug,
subset.flip_aug,
subset.face_crop_aug_range,
Expand Down Expand Up @@ -3355,6 +3428,16 @@ def add_dataset_arguments(
help="A custom separator to divide the caption into fixed and flexible parts. Tokens before this separator will not be shuffled. If not specified, '--keep_tokens' will be used to determine the fixed number of tokens."
+ " / captionを固定部分と可変部分に分けるためのカスタム区切り文字。この区切り文字より前のトークンはシャッフルされない。指定しない場合、'--keep_tokens'が固定部分のトークン数として使用される。",
)
# Train using the built-in object templates instead of per-image captions
# (a random template is prefixed to the kept caption tokens).
parser.add_argument(
    "--use_object_template",
    action="store_true",
    help="prefix default templates for object for caption text / キャプションは使わずデフォルトの物体用テンプレートで学習する",
)
# Same, but using the style templates. Fix: help text said "stype"
# instead of "style".
parser.add_argument(
    "--use_style_template",
    action="store_true",
    help="prefix default templates for style for caption text / キャプションは使わずデフォルトのスタイル用テンプレートで学習する",
)
parser.add_argument(
"--caption_prefix",
type=str,
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ huggingface-hub==0.20.1
# for loading Diffusers' SDXL
invisible-watermark==0.2.0
lion-pytorch==0.0.6
lycoris_lora==2.0.2
lycoris_lora==2.1.0.dev9
# for BLIP captioning
# requests==2.28.2
# timm==0.6.12
Expand Down Expand Up @@ -42,5 +42,6 @@ transformers==4.36.2
voluptuous==0.13.1
wandb==0.15.11
scipy==1.11.4
beartype
# for kohya_ss library
-e . # no_verify leave this to specify not checking this a verification stage
15 changes: 12 additions & 3 deletions train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,8 +1034,7 @@ def remove_model(old_ckpt_name):
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)



# Let's make sure we don't update any embedding weights besides the added pivots
if args.continue_inversion:
with torch.no_grad():
Expand All @@ -1052,7 +1051,7 @@ def remove_model(old_ckpt_name):
emb_token_ids = embedding_to_token_ids[emb_name]
updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[emb_token_ids].data.detach().clone()
embeddings_map[emb_name] = updated_embs

if args.enable_ema:
for i, e in enumerate(emas):
if args.ema_type == "post-hoc" and ((e.step + 1) % e.post_hoc_snapshot_every) == 0 and e.step != 0:
Expand Down Expand Up @@ -1160,6 +1159,16 @@ def remove_model(old_ckpt_name):
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, network, global_step, num_train_epochs, embeddings_map, force_sync_upload=True)

if args.enable_ema and args.ema_type == 'traditional':
# save directly
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(os.path.splitext(ckpt_name)[0] + "-EMA" + os.path.splitext(ckpt_name)[1], emas[0].ema_model, global_step, num_train_epochs, force_sync_upload=True)

## save EMA - copy and save
#emas[0].copy_params_from_ema_to_model()
#ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
#save_model(os.path.splitext(ckpt_name)[0] + "-EMA_" + os.path.splitext(ckpt_name)[1], network, global_step, num_train_epochs, force_sync_upload=True)

if args.enable_ema and args.ema_type == 'traditional':
# save directly
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
Expand Down

0 comments on commit 9d58b3c

Please sign in to comment.