From a35d7ef2271aa123ee12172668c5011cb17770f7 Mon Sep 17 00:00:00 2001 From: Jakaline-dev Date: Sun, 26 Mar 2023 05:26:10 +0900 Subject: [PATCH 1/2] Implement XTI --- XTI_hijack.py | 209 ++++++++++++ gen_img_diffusers.py | 124 +++++-- library/train_util.py | 27 +- train_textual_inversion_XTI.py | 586 +++++++++++++++++++++++++++++++++ 4 files changed, 925 insertions(+), 21 deletions(-) create mode 100644 XTI_hijack.py create mode 100644 train_textual_inversion_XTI.py diff --git a/XTI_hijack.py b/XTI_hijack.py new file mode 100644 index 000000000..f39cc8e7e --- /dev/null +++ b/XTI_hijack.py @@ -0,0 +1,209 @@ +import torch +from typing import Union, List, Optional, Dict, Any, Tuple +from diffusers.models.unet_2d_condition import UNet2DConditionOutput + +def unet_forward_XTI(self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. 
+ t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.config.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + down_i = 0 + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states[down_i:down_i+2], + ) + down_i += 2 + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6]) + + # 5. up + up_i = 7 + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states[up_i:up_i+3], + upsample_size=upsample_size, + ) + up_i += 3 + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + # 6. 
post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) + +def downblock_forward_XTI( + self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None +): + output_states = () + i = 0 + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i] + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample + + output_states += (hidden_states,) + i += 1 + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + +def upblock_forward_XTI( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, +): + i = 0 + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i] + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample + + i += 1 + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states \ No newline at end of file diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 38bc86e92..52b80a44a 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -84,6 +84,8 @@ import tools.original_control_net as original_control_net from tools.original_control_net import ControlNetInfo +from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI + # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う TOKENIZER_PATH = "openai/clip-vit-large-patch14" V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う @@ -475,6 +477,9 @@ def __init__( # Textual Inversion self.token_replacements = {} + # XTI + self.token_replacements_XTI = {} + # CLIP guidance self.clip_guidance_scale = clip_guidance_scale self.clip_image_guidance_scale = clip_image_guidance_scale @@ -496,15 +501,26 @@ def __init__( def add_token_replacement(self, target_token_id, 
rep_token_ids): self.token_replacements[target_token_id] = rep_token_ids - def replace_token(self, tokens): + def replace_token(self, tokens, layer=None): new_tokens = [] for token in tokens: if token in self.token_replacements: - new_tokens.extend(self.token_replacements[token]) + replacer_ = self.token_replacements[token] + if layer: + replacer = [] + for r in replacer_: + if r in self.token_replacements_XTI: + replacer.append(self.token_replacements_XTI[r][layer]) + else: + replacer = replacer_ + new_tokens.extend(replacer) else: new_tokens.append(token) return new_tokens + def add_token_replacement_XTI(self, target_token_id, rep_token_ids): + self.token_replacements_XTI[target_token_id] = rep_token_ids + def set_control_nets(self, ctrl_nets): self.control_nets = ctrl_nets @@ -723,14 +739,15 @@ def __call__( " the batch size of `prompt`." ) - text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - **kwargs, - ) + if not args.XTI_embeddings: + text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings( + pipe=self, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + **kwargs, + ) if negative_scale is not None: _, real_uncond_embeddings, _ = get_weighted_text_embeddings( @@ -742,11 +759,30 @@ def __call__( **kwargs, ) - if do_classifier_free_guidance: - if negative_scale is None: - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - else: - text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) + if args.XTI_embeddings: + text_embeddings_concat = [] + for layer in ['IN01', 'IN02', 'IN04', 'IN05', 'IN07', 'IN08', 'MID', 'OUT03', 'OUT04', 'OUT05', 'OUT06', 'OUT07', 'OUT08', 'OUT09', 'OUT10', 'OUT11']: + text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings( + pipe=self, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + layer=layer, + **kwargs, + ) + if do_classifier_free_guidance: + if negative_scale is None: + text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings])) + else: + text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])) + text_embeddings = torch.stack(text_embeddings_concat) + else: + if do_classifier_free_guidance: + if negative_scale is None: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + else: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) # CLIP guidanceで使用するembeddingsを取得する if self.clip_guidance_scale > 0: @@ -1562,7 +1598,7 @@ def multiply_range(start_position, multiplier): return res -def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: int): +def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: int, layer=None): r""" Tokenize a list of prompts and return its tokens with weights of each token. No padding, starting or ending token is included. 
@@ -1578,7 +1614,7 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: # tokenize and discard the starting and the ending token token = pipe.tokenizer(word).input_ids[1:-1] - token = pipe.replace_token(token) + token = pipe.replace_token(token, layer=layer) text_token += token # copy the weight by length of token @@ -1694,6 +1730,7 @@ def get_weighted_text_embeddings( skip_parsing: Optional[bool] = False, skip_weighting: Optional[bool] = False, clip_skip=None, + layer=None, **kwargs, ): r""" @@ -1724,11 +1761,11 @@ def get_weighted_text_embeddings( prompt = [prompt] if not skip_parsing: - prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2) + prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2, layer=layer) if uncond_prompt is not None: if isinstance(uncond_prompt, str): uncond_prompt = [uncond_prompt] - uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2) + uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2, layer=layer) else: prompt_tokens = [ token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids @@ -1961,6 +1998,9 @@ def main(args): # xformers、Hypernetwork対応 if not args.diffusers_xformers: replace_unet_modules(unet, not args.xformers, args.xformers) + diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI + diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI + diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI # tokenizerを読み込む print("loading tokenizer") @@ -2196,6 +2236,50 @@ def __getattr__(self, item): for token_id, embed in zip(token_ids, embeds): token_embeds[token_id] = embed + if args.XTI_embeddings: + XTI_layers = ['IN01', 'IN02', 'IN04', 'IN05', 'IN07', 'IN08', 'MID', 'OUT03', 'OUT04', 'OUT05', 'OUT06', 'OUT07', 'OUT08', 'OUT09', 'OUT10', 'OUT11'] + token_ids_embeds_XTI = [] + for embeds_file in args.XTI_embeddings: + if model_util.is_safetensors(embeds_file): + from safetensors.torch import load_file + data = load_file(embeds_file) + else: + data = torch.load(embeds_file, map_location="cpu") + if set(data.keys()) != set(XTI_layers): + raise ValueError("NOT XTI") + embeds = torch.concat(list(data.values())) + num_vectors_per_token = data['MID'].size()[0] + + token_string = os.path.splitext(os.path.basename(embeds_file))[0] + token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] + + # add new word to tokenizer, count is num_vectors_per_token + num_added_tokens = tokenizer.add_tokens(token_strings) + assert num_added_tokens == num_vectors_per_token, f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}" + + token_ids = tokenizer.convert_tokens_to_ids(token_strings) + print(f"Textual Inversion embeddings `{token_string}` loaded. 
Tokens are added: {token_ids}") + + #if num_vectors_per_token > 1: + pipe.add_token_replacement(token_ids[0], token_ids) + + token_strings_XTI = [] + for layer_name in XTI_layers: + token_strings_XTI += [f"{t}_{layer_name}" for t in token_strings] + tokenizer.add_tokens(token_strings_XTI) + token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI) + token_ids_embeds_XTI.append((token_ids_XTI, embeds)) + for t in token_ids: + t_XTI_dic = {} + for i, layer_name in enumerate(XTI_layers): + t_XTI_dic[layer_name] = t + (i + 1) * num_added_tokens + pipe.add_token_replacement_XTI(t, t_XTI_dic) + + text_encoder.resize_token_embeddings(len(tokenizer)) + token_embeds = text_encoder.get_input_embeddings().weight.data + for token_ids, embeds in token_ids_embeds_XTI: + for token_id, embed in zip(token_ids, embeds): + token_embeds[token_id] = embed # promptを取得する if args.from_file is not None: print(f"reading prompts from {args.from_file}") @@ -2754,6 +2838,8 @@ def setup_parser() -> argparse.ArgumentParser: help='show metadata of network model / ネットワークモデルのメタデータを表示する') parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*', help='Embeddings files of Textual Inversion / Textual Inversionのembeddings') + parser.add_argument("--XTI_embeddings", type=str, default=None, nargs='*', + help='Embeddings files of Extended Textual Inversion / Extended Textual Inversionのembeddings') parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う') parser.add_argument("--max_embeddings_multiples", type=int, default=None, help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる') diff --git a/library/train_util.py b/library/train_util.py index 97f5a7028..61466c62c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -391,6 +391,8 @@ def __init__( self.token_padding_disabled = False self.tag_frequency = {} + self.XTI_layers = None + self.token_strings = None self.enable_bucket = False self.bucket_manager: BucketManager = None # not initialized @@ -437,6 +439,10 @@ def set_tag_frequency(self, dir_name, captions): def disable_token_padding(self): self.token_padding_disabled = True + def enable_XTI(self, layers=None, token_strings=None): + self.XTI_layers = layers + self.token_strings = token_strings + def add_replacement(self, str_from, str_to): self.replacements[str_from] = str_to @@ -870,9 +876,22 @@ def __getitem__(self, index): latents_list.append(latents) caption = self.process_caption(subset, image_info.caption) - captions.append(caption) + if self.XTI_layers: + caption_layer = [] + for layer in self.XTI_layers: + token_strings_from = " ".join(self.token_strings) + token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) + caption_ = caption.replace(token_strings_from, token_strings_to) + caption_layer.append(caption_) + captions.append(caption_layer) + else: + captions.append(caption) if not self.token_padding_disabled: # this option might be omitted in future - input_ids_list.append(self.get_input_ids(caption)) + if self.XTI_layers: + token_caption = self.get_input_ids(caption_layer) + else: + token_caption = self.get_input_ids(caption) + input_ids_list.append(token_caption) example = {} example["loss_weights"] = torch.FloatTensor(loss_weights) @@ -1273,6 +1292,10 @@ def add_replacement(self, str_from, str_to): # for dataset in self.datasets: # dataset.make_buckets() + def enable_XTI(self, *args, **kwargs): + for dataset in 
self.datasets: + dataset.enable_XTI(*args, **kwargs) + def cache_latents(self, vae, vae_batch_size=1): for i, dataset in enumerate(self.datasets): print(f"[Dataset {i}]") diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py new file mode 100644 index 000000000..60bf4f7ec --- /dev/null +++ b/train_textual_inversion_XTI.py @@ -0,0 +1,586 @@ +import importlib +import argparse +import gc +import math +import os +import toml + +from tqdm import tqdm +import torch +from accelerate.utils import set_seed +import diffusers +from diffusers import DDPMScheduler + +import library.train_util as train_util +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) + +from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI + +imagenet_templates_small = [ + "a photo of a {}", + "a rendering of a {}", + "a cropped photo of the {}", + "the photo of a {}", + "a photo of a clean {}", + "a photo of a dirty {}", + "a dark photo of the {}", + "a photo of my {}", + "a photo of the cool {}", + "a close-up photo of a {}", + "a bright photo of the {}", + "a cropped photo of a {}", + "a photo of the {}", + "a good photo of the {}", + "a photo of one {}", + "a close-up photo of the {}", + "a rendition of the {}", + "a photo of the clean {}", + "a rendition of a {}", + "a photo of a nice {}", + "a good photo of a {}", + "a photo of the nice {}", + "a photo of the small {}", + "a photo of the weird {}", + "a photo of the large {}", + "a photo of a cool {}", + "a photo of a small {}", +] + +imagenet_style_templates_small = [ + "a painting in the style of {}", + "a rendering in the style of {}", + "a cropped painting in the style of {}", + "the painting in the style of {}", + "a clean painting in the style of {}", + "a dirty painting in the style of {}", + "a dark painting in the style of {}", + "a picture in the style of {}", + "a cool painting in the style of {}", + "a close-up painting in the style of {}", + "a bright painting in the style of {}", + "a cropped painting in the style of {}", + "a good painting in the style of {}", + "a close-up painting in the style of {}", + "a rendition in the style of {}", + "a nice painting in the style of {}", + "a small painting in the style of {}", + "a weird painting in the style of {}", + "a large painting in the style of {}", +] + + +def collate_fn(examples): + return examples[0] + + +def train(args): + if args.output_name is None: + args.output_name = args.token_string + use_template = args.use_object_template or args.use_style_template + + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + + cache_latents = args.cache_latents + + if args.seed is not None: + set_seed(args.seed) + + tokenizer = train_util.load_tokenizer(args) + + # acceleratorを準備する + print("prepare accelerator") + accelerator, unwrap_model = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) + + # Convert the init_word to token_id + if args.init_word is not None: + init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False) + if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token: + print( + f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 
初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}" + ) + else: + init_token_ids = None + + # add new word to tokenizer, count is num_vectors_per_token + token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)] + num_added_tokens = tokenizer.add_tokens(token_strings) + assert ( + num_added_tokens == args.num_vectors_per_token + ), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}" + + token_ids = tokenizer.convert_tokens_to_ids(token_strings) + print(f"tokens are added: {token_ids}") + assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered" + assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}" + + token_strings_XTI = [] + XTI_layers = ['IN01', 'IN02', 'IN04', 'IN05', 'IN07', 'IN08', 'MID', 'OUT03', 'OUT04', 'OUT05', 'OUT06', 'OUT07', 'OUT08', 'OUT09', 'OUT10', 'OUT11'] + for layer_name in XTI_layers: + token_strings_XTI += [f"{t}_{layer_name}" for t in token_strings] + + tokenizer.add_tokens(token_strings_XTI) + token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI) + print(f"tokens are added (XTI): {token_ids_XTI}") + # Resize the token embeddings as we are adding new special tokens to the tokenizer + text_encoder.resize_token_embeddings(len(tokenizer)) + + # Initialise the newly added placeholder token with the embeddings of the initializer token + token_embeds = text_encoder.get_input_embeddings().weight.data + if init_token_ids is not None: + for i, token_id in enumerate(token_ids_XTI): + token_embeds[token_id] = token_embeds[init_token_ids[(i // 16) % len(init_token_ids)]] + # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) + + # load weights + if args.weights is not None: + embeddings = load_weights(args.weights) + assert len(token_ids) == len( + embeddings + ), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}" + # print(token_ids, embeddings.size()) + for token_id, embedding in zip(token_ids_XTI, embeddings): + token_embeds[token_id] = embedding + # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) + print(f"weighs loaded") + + print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") + + # データセットを準備する + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False)) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + use_dreambooth_method = args.in_json is None + if use_dreambooth_method: + print("Use DreamBooth method.") + user_config = { + "datasets": [ + {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} + ] + } + else: + print("Train with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + 
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings) + + # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 + if use_template: + print("use template for training captions. is object: {args.use_object_template}") + templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small + replace_to = " ".join(token_strings) + captions = [] + for tmpl in templates: + captions.append(tmpl.format(replace_to)) + train_dataset_group.add_replacement("", captions) + + if args.num_vectors_per_token > 1: + prompt_replacement = [args.token_string, replace_to] + else: + prompt_replacement = None + else: + if args.num_vectors_per_token > 1: + replace_to = " ".join(token_strings) + train_dataset_group.add_replacement(args.token_string, replace_to) + prompt_replacement = [args.token_string, replace_to] + else: + prompt_replacement = None + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group, show_input_ids=True) + return + if len(train_dataset_group) == 0: + print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI + diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI + diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents(vae, args.vae_batch_size) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + text_encoder.gradient_checkpointing_enable() + + # 学習に必要なクラスを準備する + print("prepare optimizer, data loader etc.") + trainable_params = text_encoder.get_input_embeddings().parameters() + _, _, optimizer = train_util.get_optimizer(args, trainable_params) + + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collate_fn, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * len(train_dataloader) + print(f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # acceleratorがなんかよろしくやってくれるらしい + text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoder, optimizer, train_dataloader, lr_scheduler + ) + + index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0] + # print(len(index_no_updates), torch.sum(index_no_updates)) + orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() + + # Freeze all parameters except for the token embeddings in text encoder + text_encoder.requires_grad_(True) + text_encoder.text_model.encoder.requires_grad_(False) + text_encoder.text_model.final_layer_norm.requires_grad_(False) + text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) + # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) + + unet.requires_grad_(False) + unet.to(accelerator.device, dtype=weight_dtype) + if args.gradient_checkpointing: # according to TI example in Diffusers, train is required + unet.train() + else: + unet.eval() + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + text_encoder.to(weight_dtype) + + # resumeする + if args.resume is not None: + print(f"resume training from state: {args.resume}") + accelerator.load_state(args.resume) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + print("running training / 学習開始") + print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + print(f" num epochs / epoch数: {num_train_epochs}") + print(f" batch size per device / バッチサイズ: {args.train_batch_size}") + print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + + if accelerator.is_main_process: + accelerator.init_trackers("textual_inversion") + + for epoch in range(num_train_epochs): + print(f"epoch {epoch+1}/{num_train_epochs}") + train_dataset_group.set_current_epoch(epoch + 1) + + text_encoder.train() + + loss_total = 0 + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(text_encoder): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is 
not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] + + # Get the text embedding for conditioning + input_ids = batch["input_ids"].to(accelerator.device) + # weight_dtype) use float instead of fp16/bf16 because text encoder is float + encoder_hidden_states = torch.stack([train_util.get_hidden_states(args, s, tokenizer, text_encoder, weight_dtype) for s in torch.split(input_ids, 1, dim=1)]) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Predict the noise residual + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = text_encoder.get_input_embeddings().parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Let's make sure we don't update any embedding weights besides the newly added token + with torch.no_grad(): + unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[ + index_no_updates + ] + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + train_util.sample_images( + accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement + ) + + current_loss = loss.detach().item() + if args.logging_dir is not None: + logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} + if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value + logs["lr/d*lr"] = ( + lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] + ) + accelerator.log(logs, step=global_step) + + loss_total += current_loss + avr_loss = loss_total / (step + 1) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(train_dataloader)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + updated_embs = 
unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone() + + if args.save_every_n_epochs is not None: + model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name + + def save_func(): + ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as + ckpt_file = os.path.join(args.output_dir, ckpt_name) + print(f"saving checkpoint: {ckpt_file}") + save_weights(ckpt_file, updated_embs, save_dtype) + + def remove_old_func(old_epoch_no): + old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) + if saving and args.save_state: + train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1) + + train_util.sample_images( + accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement + ) + + # end of epoch + + is_main_process = accelerator.is_main_process + if is_main_process: + text_encoder = unwrap_model(text_encoder) + + accelerator.end_training() + + if args.save_state: + train_util.save_state_on_train_end(args, accelerator) + + updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone() + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + os.makedirs(args.output_dir, exist_ok=True) + + model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + ckpt_name = model_name + "." 
+ args.save_model_as + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + print(f"save trained model to {ckpt_file}") + save_weights(ckpt_file, updated_embs, save_dtype) + print("model saved.") + + +def save_weights(file, updated_embs, save_dtype): + updated_embs = updated_embs.reshape(16, -1, updated_embs.shape[-1]) + updated_embs = updated_embs.chunk(16) + XTI_layers = ['IN01', 'IN02', 'IN04', 'IN05', 'IN07', 'IN08', 'MID', 'OUT03', 'OUT04', 'OUT05', 'OUT06', 'OUT07', 'OUT08', 'OUT09', 'OUT10', 'OUT11'] + state_dict = {} + for i, layer_name in enumerate(XTI_layers): + state_dict[layer_name] = updated_embs[i].squeeze(0).detach().clone().to("cpu").to(save_dtype) + + # if save_dtype is not None: + # for key in list(state_dict.keys()): + # v = state_dict[key] + # v = v.detach().clone().to("cpu").to(save_dtype) + # state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + + save_file(state_dict, file) + else: + torch.save(state_dict, file) # can be loaded in Web UI + + +def load_weights(file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + data = load_file(file) + else: + raise ValueError(f"NOT XTI: {file}") + + if len(data.values()) != 16: + raise ValueError(f"NOT XTI: {file}") + + emb = torch.concat([x for x in data.values()]) + + return emb + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, True, False) + train_util.add_training_arguments(parser, True) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + + parser.add_argument( + "--save_model_as", + type=str, + default="pt", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)", + ) + + parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み") + parser.add_argument( + "--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数" + ) + parser.add_argument( + "--token_string", + type=str, + default=None, + help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること", + ) + parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可") + parser.add_argument( + "--use_object_template", + action="store_true", + help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する", + ) + parser.add_argument( + "--use_style_template", + action="store_true", + help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + train(args) From 24e3d4b4642673cebd552f1a9cbd3d99eac969a2 Mon Sep 17 00:00:00 2001 From: Jakaline-dev Date: Thu, 30 Mar 2023 02:20:04 +0900 Subject: [PATCH 2/2] disabled sampling (for now) --- gen_img_diffusers.py | 17 ++++++------- train_textual_inversion_XTI.py | 46 +++++++++++++++++++++------------- 2 files changed, 36 insertions(+), 27 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index b562d097d..cd0be71b4 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -781,21 +781,19 @@ def __call__( text_embeddings_concat = [] for 
layer in ['IN01', 'IN02', 'IN04', 'IN05', 'IN07', 'IN08', 'MID', 'OUT03', 'OUT04', 'OUT05', 'OUT06', 'OUT07', 'OUT08', 'OUT09', 'OUT10', 'OUT11']: text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - layer=layer, - **kwargs, + pipe=self, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + layer=layer, + **kwargs, ) if do_classifier_free_guidance: if negative_scale is None: text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings])) else: text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])) - - text_embeddings = torch.stack(text_embeddings_concat) else: if do_classifier_free_guidance: @@ -803,7 +801,6 @@ def __call__( text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) else: text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) - text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings( pipe=self, prompt=prompt, diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 60bf4f7ec..8d6ff430e 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -4,6 +4,7 @@ import math import os import toml +from multiprocessing import Value from tqdm import tqdm import torch @@ -17,7 +18,8 @@ ConfigSanitizer, BlueprintGenerator, ) - +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_snr_weight from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI imagenet_templates_small = [ @@ -73,10 +75,6 @@ ] -def collate_fn(examples): - return examples[0] - - def train(args): if args.output_name is None: args.output_name = args.token_string @@ -195,6 +193,10 @@ def train(args): blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings) + current_epoch = Value('i',0) + current_step = Value('i',0) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch,current_step, ds_for_collater) # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 if use_template: @@ -207,14 +209,14 @@ def train(args): train_dataset_group.add_replacement("", captions) if args.num_vectors_per_token > 1: - prompt_replacement = [args.token_string, replace_to] + prompt_replacement = (args.token_string, replace_to) else: prompt_replacement = None else: if args.num_vectors_per_token > 1: replace_to = " ".join(token_strings) train_dataset_group.add_replacement(args.token_string, replace_to) - prompt_replacement = [args.token_string, replace_to] + prompt_replacement = (args.token_string, replace_to) else: prompt_replacement = None @@ -264,16 +266,19 @@ def train(args): train_dataset_group, batch_size=1, shuffle=True, - collate_fn=collate_fn, + collate_fn=collater, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) # 学習ステップ数を計算する if args.max_train_epochs is not None: - args.max_train_steps = 
args.max_train_epochs * len(train_dataloader) + args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps) print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + # lr schedulerを用意する lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) @@ -345,12 +350,14 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset_group.set_current_epoch(epoch + 1) + current_epoch.value = epoch+1 text_encoder.train() loss_total = 0 + for step, batch in enumerate(train_dataloader): + current_step.value = global_step with accelerator.accumulate(text_encoder): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: @@ -391,6 +398,9 @@ def train(args): loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = loss.mean([1, 2, 3]) + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights @@ -416,10 +426,10 @@ def train(args): if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 - - train_util.sample_images( - accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement - ) + # TODO: fix sample_images + # train_util.sample_images( + # accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement + # ) current_loss = loss.detach().item() if args.logging_dir is not None: @@ -466,9 +476,10 @@ def remove_old_func(old_epoch_no): if saving and args.save_state: train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1) - train_util.sample_images( - accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement - ) + # TODO: fix sample_images + # train_util.sample_images( + # accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement + # ) # end of epoch @@ -543,6 +554,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_training_arguments(parser, True) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) parser.add_argument( "--save_model_as",