From 5b19bda85c2ce01e4a1c7f324b7ef14bffed3315 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 5 Nov 2023 12:35:46 -0500 Subject: [PATCH 001/348] Add validation loss --- library/train_util.py | 4 ++ train_network.py | 117 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 120 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index cc9ac4555..e26f39799 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4736,6 +4736,10 @@ def __call__(self, examples): else: dataset = self.dataset + # If we split a dataset we will get a Subset + if type(dataset) is torch.utils.data.Subset: + dataset = dataset.dataset + # set epoch and step dataset.set_current_epoch(self.current_epoch.value) dataset.set_current_step(self.current_step.value) diff --git a/train_network.py b/train_network.py index d50916b74..58767b6f7 100644 --- a/train_network.py +++ b/train_network.py @@ -345,8 +345,21 @@ def train(self, args): # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + if args.validation_ratio > 0.0: + train_ratio = 1 - args.validation_ratio + validation_ratio = args.validation_ratio + train, val = torch.utils.data.random_split( + train_dataset_group, + [train_ratio, validation_ratio] + ) + print(f"split dataset by ratio: train {train_ratio}, validation {validation_ratio}") + print(f"train images: {len(train)}, validation images: {len(val)}") + else: + train = train_dataset_group + val = [] + train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, + train, batch_size=1, shuffle=True, collate_fn=collator, @@ -354,6 +367,15 @@ def train(self, args): persistent_workers=args.persistent_data_loader_workers, ) + val_dataloader = torch.utils.data.DataLoader( + val, + shuffle=False, + batch_size=1, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + # 学習ステップ数を計算する if args.max_train_epochs is not None: args.max_train_steps = args.max_train_epochs * math.ceil( @@ -711,6 +733,8 @@ def train(self, args): ) loss_recorder = train_util.LossRecorder() + val_loss_recorder = train_util.LossRecorder() + del train_dataset_group # callback for step start @@ -752,6 +776,8 @@ def remove_model(old_ckpt_name): network.on_epoch_start(text_encoder, unet) + # TRAINING + for step, batch in enumerate(train_dataloader): current_step.value = global_step with accelerator.accumulate(network): @@ -877,6 +903,87 @@ def remove_model(old_ckpt_name): if global_step >= args.max_train_steps: break + # VALIDATION + + if len(val_dataloader) > 0: + print("Validating バリデーション処理...") + + with torch.no_grad(): + for val_step, batch in enumerate(val_dataloader): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(device=accelerator.device, dtype=vae_dtype)).latent_dist.sample() + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = latents * self.vae_scale_factor + b_size = latents.shape[0] + + # Get the text embedding for conditioning + if args.weighted_captions: + text_encoder_conds = get_weighted_text_embeddings( + tokenizer, + text_encoder, + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, + ) + else: + text_encoder_conds = self.get_text_cond( + args, accelerator, batch, tokenizers, text_encoders, weight_dtype + ) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) + + # Predict the noise residual + with accelerator.autocast(): + noise_pred = self.call_unet( + args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype + ) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight + + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + current_loss = loss.detach().item() + + val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) + + if len(val_dataloader) > 0: + avr_loss: float = val_loss_recorder.moving_average + + if args.logging_dir is not None: + logs = {"loss/validation": avr_loss} + accelerator.log(logs, step=epoch + 1) + + if args.logging_dir is not None: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) @@ -999,6 +1106,14 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) + + parser.add_argument( + "--validation_ratio", + type=float, + default=0.0, + help="Ratio for validation images out of the training dataset" + ) + return parser From 33c311ed19821c9be7094ba89371777d7478b028 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 5 Nov 2023 12:37:37 -0500 Subject: [PATCH 002/348] new ratio code --- train_network.py | 48 +++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 43 insertions(+), 5 deletions(-) diff --git a/train_network.py b/train_network.py index 58767b6f7..967c95fb4 100644 --- a/train_network.py +++ b/train_network.py @@ -345,10 +345,48 @@ def train(self, args): # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + def get_indices_without_reg(dataset: torch.utils.data.Dataset): + return [id for id, (key, item) in enumerate(dataset.image_data.items()) if item.is_reg is False] + + from typing import Sequence, Union + from torch._utils import _accumulate + import warnings + from torch.utils.data.dataset import Subset + + def random_split(dataset: torch.utils.data.Dataset, lengths: Sequence[Union[int, float]]): + indices = get_indices_without_reg(dataset) + random.shuffle(indices) + + subset_lengths = [] + + for i, frac in enumerate(lengths): + if frac < 0 or frac > 1: + raise ValueError(f"Fraction at index {i} is not between 0 and 1") + n_items_in_split = int(math.floor(len(indices) * frac)) + subset_lengths.append(n_items_in_split) + + remainder = len(indices) - sum(subset_lengths) + + for i in range(remainder): + idx_to_add_at = i % len(subset_lengths) + subset_lengths[idx_to_add_at] += 1 + + lengths = subset_lengths + for i, length in enumerate(lengths): + if length == 0: + warnings.warn(f"Length of split at index {i} is 0. " + f"This might result in an empty dataset.") + + if sum(lengths) != len(indices): + raise ValueError("Sum of input lengths does not equal the length of the input dataset!") + + return [Subset(dataset, indices[offset - length: offset]) for offset, length in zip(_accumulate(lengths), lengths)] + + if args.validation_ratio > 0.0: train_ratio = 1 - args.validation_ratio validation_ratio = args.validation_ratio - train, val = torch.utils.data.random_split( + train, val = random_split( train_dataset_group, [train_ratio, validation_ratio] ) @@ -358,6 +396,8 @@ def train(self, args): train = train_dataset_group val = [] + + train_dataloader = torch.utils.data.DataLoader( train, batch_size=1, @@ -898,7 +938,7 @@ def remove_model(old_ckpt_name): if args.logging_dir is not None: logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) - accelerator.log(logs, step=global_step) + accelerator.log(logs) if global_step >= args.max_train_steps: break @@ -973,13 +1013,11 @@ def remove_model(old_ckpt_name): loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし current_loss = loss.detach().item() - val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) if len(val_dataloader) > 0: - avr_loss: float = val_loss_recorder.moving_average - if args.logging_dir is not None: + avr_loss: float = val_loss_recorder.moving_average logs = {"loss/validation": avr_loss} accelerator.log(logs, step=epoch + 1) From 3de9e6c443037abf99832d1be60f4fc9c0d67b8c Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 5 Nov 2023 01:45:23 -0500 Subject: [PATCH 003/348] Add validation split of datasets --- library/config_util.py | 145 ++++++++++++++++++++++++++--------------- library/train_util.py | 26 ++++++++ train_network.py | 67 ++++--------------- 3 files changed, 128 insertions(+), 110 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index e8e0fda7c..1bf7ed955 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -85,6 +85,8 @@ class BaseDatasetParams: max_token_length: int = None resolution: Optional[Tuple[int, int]] = None debug_dataset: bool = False + validation_seed: Optional[int] = None + validation_split: float = 0.0 @dataclass class DreamBoothDatasetParams(BaseDatasetParams): @@ -200,6 +202,8 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] "enable_bucket": bool, "max_bucket_reso": int, "min_bucket_reso": int, + "validation_seed": int, + "validation_split": float, "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), } @@ -427,64 +431,89 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset_klass = FineTuningDataset subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, is_train=True, **asdict(dataset_blueprint.params)) datasets.append(dataset) - # print info - info = "" - for i, dataset in enumerate(datasets): - is_dreambooth = isinstance(dataset, DreamBoothDataset) - is_controlnet = isinstance(dataset, ControlNetDataset) - info += dedent(f"""\ - [Dataset {i}] - batch_size: {dataset.batch_size} - resolution: {(dataset.width, dataset.height)} - enable_bucket: {dataset.enable_bucket} - """) - - if dataset.enable_bucket: - info += indent(dedent(f"""\ - min_bucket_reso: {dataset.min_bucket_reso} - max_bucket_reso: {dataset.max_bucket_reso} - bucket_reso_steps: {dataset.bucket_reso_steps} - bucket_no_upscale: {dataset.bucket_no_upscale} - \n"""), " ") + val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] + for dataset_blueprint in dataset_group_blueprint.datasets: + if dataset_blueprint.params.validation_split <= 0.0: + continue + if dataset_blueprint.is_controlnet: + subset_klass = ControlNetSubset + dataset_klass = ControlNetDataset + elif dataset_blueprint.is_dreambooth: + subset_klass = DreamBoothSubset + dataset_klass = DreamBoothDataset else: - info += "\n" - - for j, subset in enumerate(dataset.subsets): - info += indent(dedent(f"""\ - [Subset {j} of Dataset {i}] - image_dir: "{subset.image_dir}" - image_count: {subset.img_count} - num_repeats: {subset.num_repeats} - shuffle_caption: {subset.shuffle_caption} - keep_tokens: {subset.keep_tokens} - caption_dropout_rate: {subset.caption_dropout_rate} - caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} - caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} - caption_prefix: {subset.caption_prefix} - caption_suffix: {subset.caption_suffix} - color_aug: {subset.color_aug} - flip_aug: {subset.flip_aug} - face_crop_aug_range: {subset.face_crop_aug_range} - random_crop: {subset.random_crop} - token_warmup_min: {subset.token_warmup_min}, - token_warmup_step: {subset.token_warmup_step}, - """), " ") - - if is_dreambooth: + subset_klass = FineTuningSubset + dataset_klass = FineTuningDataset + + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) + val_datasets.append(dataset) + + # print info + def print_info(_datasets): + info = "" + for i, dataset in enumerate(_datasets): + is_dreambooth = isinstance(dataset, DreamBoothDataset) + is_controlnet = isinstance(dataset, ControlNetDataset) + info += dedent(f"""\ + [Dataset {i}] + batch_size: {dataset.batch_size} + resolution: {(dataset.width, dataset.height)} + enable_bucket: {dataset.enable_bucket} + """) + + if dataset.enable_bucket: info += indent(dedent(f"""\ - is_reg: {subset.is_reg} - class_tokens: {subset.class_tokens} - caption_extension: {subset.caption_extension} - \n"""), " ") - elif not is_controlnet: + min_bucket_reso: {dataset.min_bucket_reso} + max_bucket_reso: {dataset.max_bucket_reso} + bucket_reso_steps: {dataset.bucket_reso_steps} + bucket_no_upscale: {dataset.bucket_no_upscale} + \n"""), " ") + else: + info += "\n" + + for j, subset in enumerate(dataset.subsets): info += indent(dedent(f"""\ - metadata_file: {subset.metadata_file} - \n"""), " ") - - print(info) + [Subset {j} of Dataset {i}] + image_dir: "{subset.image_dir}" + image_count: {subset.img_count} + num_repeats: {subset.num_repeats} + shuffle_caption: {subset.shuffle_caption} + keep_tokens: {subset.keep_tokens} + caption_dropout_rate: {subset.caption_dropout_rate} + caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} + caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} + caption_prefix: {subset.caption_prefix} + caption_suffix: {subset.caption_suffix} + color_aug: {subset.color_aug} + flip_aug: {subset.flip_aug} + face_crop_aug_range: {subset.face_crop_aug_range} + random_crop: {subset.random_crop} + token_warmup_min: {subset.token_warmup_min}, + token_warmup_step: {subset.token_warmup_step}, + """), " ") + + if is_dreambooth: + info += indent(dedent(f"""\ + is_reg: {subset.is_reg} + class_tokens: {subset.class_tokens} + caption_extension: {subset.caption_extension} + \n"""), " ") + elif not is_controlnet: + info += indent(dedent(f"""\ + metadata_file: {subset.metadata_file} + \n"""), " ") + + print(info) + + print_info(datasets) + + if len(val_datasets) > 0: + print("Validation dataset") + print_info(val_datasets) # make buckets first because it determines the length of dataset # and set the same seed for all datasets @@ -494,7 +523,15 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset.make_buckets() dataset.set_seed(seed) - return DatasetGroup(datasets) + for i, dataset in enumerate(val_datasets): + print(f"[Validation Dataset {i}]") + dataset.make_buckets() + dataset.set_seed(seed) + + return ( + DatasetGroup(datasets), + DatasetGroup(val_datasets) if val_datasets else None + ) def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None): diff --git a/library/train_util.py b/library/train_util.py index e26f39799..ba37ec13d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -123,6 +123,22 @@ TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" +def split_train_val(paths, is_train, validation_split, validation_seed): + if validation_seed is not None: + print(f"Using validation seed: {validation_seed}") + prevstate = random.getstate() + random.seed(validation_seed) + random.shuffle(paths) + random.setstate(prevstate) + else: + random.shuffle(paths) + + if is_train: + return paths[0:math.ceil(len(paths) * (1 - validation_split))] + else: + return paths[len(paths) - round(len(paths) * validation_split):] + + class ImageInfo: def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: self.image_key: str = image_key @@ -1314,6 +1330,7 @@ class DreamBoothDataset(BaseDataset): def __init__( self, subsets: Sequence[DreamBoothSubset], + is_train: bool, batch_size: int, tokenizer, max_token_length, @@ -1324,12 +1341,18 @@ def __init__( bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, + validation_split: float, + validation_seed: Optional[int], debug_dataset, ) -> None: super().__init__(tokenizer, max_token_length, resolution, debug_dataset) assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" + self.is_train = is_train + self.validation_split = validation_split + self.validation_seed = validation_seed + self.batch_size = batch_size self.size = min(self.width, self.height) # 短いほう self.prior_loss_weight = prior_loss_weight @@ -1382,6 +1405,9 @@ def load_dreambooth_dir(subset: DreamBoothSubset): return [], [] img_paths = glob_images(subset.image_dir, "*") + + if self.validation_split > 0.0: + img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed) print(f"found directory {subset.image_dir} contains {len(img_paths)} image files") # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う diff --git a/train_network.py b/train_network.py index 967c95fb4..97ecfe7be 100644 --- a/train_network.py +++ b/train_network.py @@ -189,10 +189,11 @@ def train(self, args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary dataset class train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) + val_dataset_group = None # placeholder until validation dataset supported for arbitrary current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -212,6 +213,10 @@ def train(self, args): assert ( train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + if val_dataset_group is not None: + assert ( + val_dataset_group.is_latent_cacheable() + ), "when caching validation latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" self.assert_extra_args(args, train_dataset_group) @@ -264,6 +269,9 @@ def train(self, args): vae.eval() with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + if val_dataset_group is not None: + print("Cache validation latents...") + val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -345,61 +353,8 @@ def train(self, args): # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで - def get_indices_without_reg(dataset: torch.utils.data.Dataset): - return [id for id, (key, item) in enumerate(dataset.image_data.items()) if item.is_reg is False] - - from typing import Sequence, Union - from torch._utils import _accumulate - import warnings - from torch.utils.data.dataset import Subset - - def random_split(dataset: torch.utils.data.Dataset, lengths: Sequence[Union[int, float]]): - indices = get_indices_without_reg(dataset) - random.shuffle(indices) - - subset_lengths = [] - - for i, frac in enumerate(lengths): - if frac < 0 or frac > 1: - raise ValueError(f"Fraction at index {i} is not between 0 and 1") - n_items_in_split = int(math.floor(len(indices) * frac)) - subset_lengths.append(n_items_in_split) - - remainder = len(indices) - sum(subset_lengths) - - for i in range(remainder): - idx_to_add_at = i % len(subset_lengths) - subset_lengths[idx_to_add_at] += 1 - - lengths = subset_lengths - for i, length in enumerate(lengths): - if length == 0: - warnings.warn(f"Length of split at index {i} is 0. " - f"This might result in an empty dataset.") - - if sum(lengths) != len(indices): - raise ValueError("Sum of input lengths does not equal the length of the input dataset!") - - return [Subset(dataset, indices[offset - length: offset]) for offset, length in zip(_accumulate(lengths), lengths)] - - - if args.validation_ratio > 0.0: - train_ratio = 1 - args.validation_ratio - validation_ratio = args.validation_ratio - train, val = random_split( - train_dataset_group, - [train_ratio, validation_ratio] - ) - print(f"split dataset by ratio: train {train_ratio}, validation {validation_ratio}") - print(f"train images: {len(train)}, validation images: {len(val)}") - else: - train = train_dataset_group - val = [] - - - train_dataloader = torch.utils.data.DataLoader( - train, + train_dataset_group, batch_size=1, shuffle=True, collate_fn=collator, @@ -408,7 +363,7 @@ def random_split(dataset: torch.utils.data.Dataset, lengths: Sequence[Union[int, ) val_dataloader = torch.utils.data.DataLoader( - val, + val_dataset_group if val_dataset_group is not None else [], shuffle=False, batch_size=1, collate_fn=collator, From a93c524b3a0e5c80a58c1317211dec93b6c137a7 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 5 Nov 2023 02:07:39 -0500 Subject: [PATCH 004/348] Update args to validation_seed and validation_split --- train_network.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/train_network.py b/train_network.py index 97ecfe7be..f9e5debdb 100644 --- a/train_network.py +++ b/train_network.py @@ -1099,12 +1099,17 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) - parser.add_argument( - "--validation_ratio", + "--validation_seed", + type=int, + default=None, + help="Validation seed" + ) + parser.add_argument( + "--validation_split", type=float, default=0.0, - help="Ratio for validation images out of the training dataset" + help="Split for validation images out of the training dataset" ) return parser From c89252101e8e8bd74cb3ab09ae33b548fd828e15 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 5 Nov 2023 16:27:36 -0500 Subject: [PATCH 005/348] Add process_batch for train_network --- train_network.py | 211 ++++++++++++++++++----------------------------- 1 file changed, 82 insertions(+), 129 deletions(-) diff --git a/train_network.py b/train_network.py index f9e5debdb..387b94b1c 100644 --- a/train_network.py +++ b/train_network.py @@ -130,6 +130,75 @@ def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_cond def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) + def process_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(accelerator.device, dtype=vae_dtype)).latent_dist.sample() + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = latents * self.vae_scale_factor + b_size = latents.shape[0] + + with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): + # Get the text embedding for conditioning + if args.weighted_captions: + text_encoder_conds = get_weighted_text_embeddings( + tokenizers[0], + text_encoders[0], + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, + ) + else: + text_encoder_conds = self.get_text_cond( + args, accelerator, batch, tokenizers, text_encoders, weight_dtype + ) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) + + # Predict the noise residual + with torch.set_grad_enabled(is_train), accelerator.autocast(): + noise_pred = self.call_unet( + args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype + ) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + return loss + + def train(self, args): session_id = random.randint(0, 2**32) training_started_at = time.time() @@ -777,71 +846,8 @@ def remove_model(old_ckpt_name): current_step.value = global_step with accelerator.accumulate(network): on_step_start(text_encoder, unet) - - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample() - - # NaNが含まれていれば警告を表示し0に置き換える - if torch.any(torch.isnan(latents)): - accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) - latents = latents * self.vae_scale_factor - b_size = latents.shape[0] - - with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): - # Get the text embedding for conditioning - if args.weighted_captions: - text_encoder_conds = get_weighted_text_embeddings( - tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) - else: - text_encoder_conds = self.get_text_cond( - args, accelerator, batch, tokenizers, text_encoders, weight_dtype - ) - - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) - - # Predict the noise residual - with accelerator.autocast(): - noise_pred = self.call_unet( - args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype - ) - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights - - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) - if args.v_pred_like_loss: - loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) - if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) - - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + is_train = True + loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=train_text_encoder) accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: @@ -893,7 +899,7 @@ def remove_model(old_ckpt_name): if args.logging_dir is not None: logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) - accelerator.log(logs) + accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break @@ -905,80 +911,27 @@ def remove_model(old_ckpt_name): with torch.no_grad(): for val_step, batch in enumerate(val_dataloader): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(device=accelerator.device, dtype=vae_dtype)).latent_dist.sample() - - # NaNが含まれていれば警告を表示し0に置き換える - if torch.any(torch.isnan(latents)): - accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) - latents = latents * self.vae_scale_factor - b_size = latents.shape[0] - - # Get the text embedding for conditioning - if args.weighted_captions: - text_encoder_conds = get_weighted_text_embeddings( - tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) - else: - text_encoder_conds = self.get_text_cond( - args, accelerator, batch, tokenizers, text_encoders, weight_dtype - ) - - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) - - # Predict the noise residual - with accelerator.autocast(): - noise_pred = self.call_unet( - args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype - ) - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight - - loss = loss * loss_weights - - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) - if args.v_pred_like_loss: - loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) - - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + is_train = False + loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) current_loss = loss.detach().item() val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) + if args.logging_dir is not None: + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/validation_current": current_loss} + accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) + if len(val_dataloader) > 0: if args.logging_dir is not None: avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/validation": avr_loss} + logs = {"loss/validation_average": avr_loss} accelerator.log(logs, step=epoch + 1) if args.logging_dir is not None: - logs = {"loss/epoch": loss_recorder.moving_average} + # logs = {"loss/epoch": loss_recorder.moving_average} + logs = {"loss/epoch_average": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) accelerator.wait_for_everyone() From e545fdfd9affabff83f8bd2e7680369bb34dd301 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 5 Nov 2023 16:56:36 -0500 Subject: [PATCH 006/348] Removed/cleanup a line --- train_network.py | 1 - 1 file changed, 1 deletion(-) diff --git a/train_network.py b/train_network.py index 387b94b1c..a4125e9f2 100644 --- a/train_network.py +++ b/train_network.py @@ -930,7 +930,6 @@ def remove_model(old_ckpt_name): if args.logging_dir is not None: - # logs = {"loss/epoch": loss_recorder.moving_average} logs = {"loss/epoch_average": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) From 9c591bdb12ce663b3fe9e91c0963d2cf71461bad Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 5 Nov 2023 16:58:20 -0500 Subject: [PATCH 007/348] Remove unnecessary subset line from collate --- library/train_util.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index ba37ec13d..1979207b0 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4762,10 +4762,6 @@ def __call__(self, examples): else: dataset = self.dataset - # If we split a dataset we will get a Subset - if type(dataset) is torch.utils.data.Subset: - dataset = dataset.dataset - # set epoch and step dataset.set_current_epoch(self.current_epoch.value) dataset.set_current_step(self.current_step.value) From 569ca72fc4cda2f4ce30e43b1c62989e79e3c3b3 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 7 Nov 2023 11:59:30 -0500 Subject: [PATCH 008/348] Set grad enabled if is_train and train_text_encoder We only want to be enabling grad if we are training. --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index a4125e9f2..edd3ff944 100644 --- a/train_network.py +++ b/train_network.py @@ -145,7 +145,7 @@ def process_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, n latents = latents * self.vae_scale_factor b_size = latents.shape[0] - with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): + with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: text_encoder_conds = get_weighted_text_embeddings( From b558a5b73d07a7e15ad90d9d15c2b55c5d2b3d61 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 10 Mar 2024 04:37:16 +0800 Subject: [PATCH 009/348] val --- library/config_util.py | 176 ++++++++++++++++++++++------------------- library/train_util.py | 22 ++++++ train_network.py | 135 ++++++++++++++++++++++++++++--- 3 files changed, 241 insertions(+), 92 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index fc4b36175..17fc17818 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -98,7 +98,8 @@ class BaseDatasetParams: resolution: Optional[Tuple[int, int]] = None network_multiplier: float = 1.0 debug_dataset: bool = False - + validation_seed: Optional[int] = None + validation_split: float = 0.0 @dataclass class DreamBoothDatasetParams(BaseDatasetParams): @@ -109,8 +110,7 @@ class DreamBoothDatasetParams(BaseDatasetParams): bucket_reso_steps: int = 64 bucket_no_upscale: bool = False prior_loss_weight: float = 1.0 - - + @dataclass class FineTuningDatasetParams(BaseDatasetParams): batch_size: int = 1 @@ -222,8 +222,11 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] "enable_bucket": bool, "max_bucket_reso": int, "min_bucket_reso": int, + "validation_seed": int, + "validation_split": float, "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), "network_multiplier": float, + } # options handled by argparse but not handled by user config @@ -460,100 +463,107 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset_klass = FineTuningDataset subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, is_train=True, **asdict(dataset_blueprint.params)) datasets.append(dataset) - # print info - info = "" - for i, dataset in enumerate(datasets): - is_dreambooth = isinstance(dataset, DreamBoothDataset) - is_controlnet = isinstance(dataset, ControlNetDataset) - info += dedent( - f"""\ - [Dataset {i}] - batch_size: {dataset.batch_size} - resolution: {(dataset.width, dataset.height)} - enable_bucket: {dataset.enable_bucket} - network_multiplier: {dataset.network_multiplier} - """ - ) + val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] + + for dataset_blueprint in dataset_group_blueprint.datasets: + if dataset_blueprint.params.validation_split <= 0.0: + continue + if dataset_blueprint.is_controlnet: + subset_klass = ControlNetSubset + dataset_klass = ControlNetDataset + elif dataset_blueprint.is_dreambooth: + subset_klass = DreamBoothSubset + dataset_klass = DreamBoothDataset + else: + subset_klass = FineTuningSubset + dataset_klass = FineTuningDataset + + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) + val_datasets.append(dataset) + + def print_info(_datasets): + info = "" + for i, dataset in enumerate(_datasets): + is_dreambooth = isinstance(dataset, DreamBoothDataset) + is_controlnet = isinstance(dataset, ControlNetDataset) + info += dedent(f"""\ + [Dataset {i}] + batch_size: {dataset.batch_size} + resolution: {(dataset.width, dataset.height)} + enable_bucket: {dataset.enable_bucket} + """) if dataset.enable_bucket: - info += indent( - dedent( - f"""\ - min_bucket_reso: {dataset.min_bucket_reso} - max_bucket_reso: {dataset.max_bucket_reso} - bucket_reso_steps: {dataset.bucket_reso_steps} - bucket_no_upscale: {dataset.bucket_no_upscale} - \n""" - ), - " ", - ) + info += indent(dedent(f"""\ + min_bucket_reso: {dataset.min_bucket_reso} + max_bucket_reso: {dataset.max_bucket_reso} + bucket_reso_steps: {dataset.bucket_reso_steps} + bucket_no_upscale: {dataset.bucket_no_upscale} + \n"""), " ") else: info += "\n" - for j, subset in enumerate(dataset.subsets): - info += indent( - dedent( - f"""\ - [Subset {j} of Dataset {i}] - image_dir: "{subset.image_dir}" - image_count: {subset.img_count} - num_repeats: {subset.num_repeats} - shuffle_caption: {subset.shuffle_caption} - keep_tokens: {subset.keep_tokens} - keep_tokens_separator: {subset.keep_tokens_separator} - caption_dropout_rate: {subset.caption_dropout_rate} - caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} - caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} - caption_prefix: {subset.caption_prefix} - caption_suffix: {subset.caption_suffix} - color_aug: {subset.color_aug} - flip_aug: {subset.flip_aug} - face_crop_aug_range: {subset.face_crop_aug_range} - random_crop: {subset.random_crop} - token_warmup_min: {subset.token_warmup_min}, - token_warmup_step: {subset.token_warmup_step}, - """ - ), - " ", - ) - - if is_dreambooth: - info += indent( - dedent( - f"""\ - is_reg: {subset.is_reg} - class_tokens: {subset.class_tokens} - caption_extension: {subset.caption_extension} - \n""" - ), - " ", - ) - elif not is_controlnet: - info += indent( - dedent( - f"""\ - metadata_file: {subset.metadata_file} - \n""" - ), - " ", - ) - - logger.info(f'{info}') - + info += indent(dedent(f"""\ + [Subset {j} of Dataset {i}] + image_dir: "{subset.image_dir}" + image_count: {subset.img_count} + num_repeats: {subset.num_repeats} + shuffle_caption: {subset.shuffle_caption} + keep_tokens: {subset.keep_tokens} + caption_dropout_rate: {subset.caption_dropout_rate} + caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} + caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} + caption_prefix: {subset.caption_prefix} + caption_suffix: {subset.caption_suffix} + color_aug: {subset.color_aug} + flip_aug: {subset.flip_aug} + face_crop_aug_range: {subset.face_crop_aug_range} + random_crop: {subset.random_crop} + token_warmup_min: {subset.token_warmup_min}, + token_warmup_step: {subset.token_warmup_step}, + """), " ") + + if is_dreambooth: + info += indent(dedent(f"""\ + is_reg: {subset.is_reg} + class_tokens: {subset.class_tokens} + caption_extension: {subset.caption_extension} + \n"""), " ") + elif not is_controlnet: + info += indent(dedent(f"""\ + metadata_file: {subset.metadata_file} + \n"""), " ") + + print(info) + + print_info(datasets) + + if len(val_datasets) > 0: + print("Validation dataset") + print_info(val_datasets) + # make buckets first because it determines the length of dataset # and set the same seed for all datasets seed = random.randint(0, 2**31) # actual seed is seed + epoch_no for i, dataset in enumerate(datasets): - logger.info(f"[Dataset {i}]") + print(f"[Dataset {i}]") + dataset.make_buckets() + dataset.set_seed(seed) + + for i, dataset in enumerate(val_datasets): + print(f"[Validation Dataset {i}]") dataset.make_buckets() dataset.set_seed(seed) - return DatasetGroup(datasets) - - + return ( + DatasetGroup(datasets), + DatasetGroup(val_datasets) if val_datasets else None + ) + def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None): def extract_dreambooth_params(name: str) -> Tuple[int, str]: tokens = name.split("_") diff --git a/library/train_util.py b/library/train_util.py index d2b69edb5..753539e04 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -134,6 +134,20 @@ TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" +def split_train_val(paths, is_train, validation_split, validation_seed): + if validation_seed is not None: + print(f"Using validation seed: {validation_seed}") + prevstate = random.getstate() + random.seed(validation_seed) + random.shuffle(paths) + random.setstate(prevstate) + else: + random.shuffle(paths) + + if is_train: + return paths[0:math.ceil(len(paths) * (1 - validation_split))] + else: + return paths[len(paths) - round(len(paths) * validation_split):] class ImageInfo: def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: @@ -1360,6 +1374,7 @@ class DreamBoothDataset(BaseDataset): def __init__( self, subsets: Sequence[DreamBoothSubset], + is_train: bool, batch_size: int, tokenizer, max_token_length, @@ -1371,12 +1386,17 @@ def __init__( bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, + validation_split: float, + validation_seed: Optional[int], debug_dataset: bool, ) -> None: super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" + self.is_train = is_train + self.validation_split = validation_split + self.validation_seed = validation_seed self.batch_size = batch_size self.size = min(self.width, self.height) # 短いほう self.prior_loss_weight = prior_loss_weight @@ -1429,6 +1449,8 @@ def load_dreambooth_dir(subset: DreamBoothSubset): return [], [] img_paths = glob_images(subset.image_dir, "*") + if self.validation_split > 0.0: + img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed) logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う diff --git a/train_network.py b/train_network.py index e0fa69458..db7000e82 100644 --- a/train_network.py +++ b/train_network.py @@ -136,6 +136,67 @@ def all_reduce_network(self, accelerator, network): def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) + def process_val_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True): + + total_loss = 0.0 + timesteps_list = [10, 350, 500, 650, 990] + + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(accelerator.device, dtype=vae_dtype)).latent_dist.sample() + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = latents * self.vae_scale_factor + b_size = latents.shape[0] + + with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): + # Get the text embedding for conditioning + if args.weighted_captions: + text_encoder_conds = get_weighted_text_embeddings( + tokenizers[0], + text_encoders[0], + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, + ) + else: + text_encoder_conds = self.get_text_cond( + args, accelerator, batch, tokenizers, text_encoders, weight_dtype + ) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, _ = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) + for timesteps in timesteps_list: + # Predict the noise residual + with torch.set_grad_enabled(is_train), accelerator.autocast(): + noise_pred = self.call_unet( + args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype + ) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + total_loss += loss + + average_loss = total_loss / len(timesteps_list) + return average_loss + def train(self, args): session_id = random.randint(0, 2**32) training_started_at = time.time() @@ -196,11 +257,12 @@ def train(self, args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary dataset class train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) - + val_dataset_group = None # placeholder until validation dataset supported for arbitrary + current_epoch = Value("i", 0) current_step = Value("i", 0) ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None @@ -219,7 +281,11 @@ def train(self, args): assert ( train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - + if val_dataset_group is not None: + assert ( + val_dataset_group.is_latent_cacheable() + ), "when caching validation latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + self.assert_extra_args(args, train_dataset_group) # acceleratorを準備する @@ -271,6 +337,9 @@ def train(self, args): vae.eval() with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + if val_dataset_group is not None: + print("Cache validation latents...") + val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -360,6 +429,15 @@ def train(self, args): num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) + + val_dataloader = torch.utils.data.DataLoader( + val_dataset_group if val_dataset_group is not None else [], + shuffle=False, + batch_size=1, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) # 学習ステップ数を計算する if args.max_train_epochs is not None: @@ -707,6 +785,8 @@ def train(self, args): ) loss_recorder = train_util.LossRecorder() + val_loss_recorder = train_util.LossRecorder() + del train_dataset_group # callback for step start @@ -755,7 +835,8 @@ def remove_model(old_ckpt_name): current_step.value = global_step with accelerator.accumulate(network): on_step_start(text_encoder, unet) - + + is_train = True with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device) @@ -780,7 +861,7 @@ def remove_model(old_ckpt_name): # print(f"set multiplier: {multipliers}") accelerator.unwrap_model(network).set_multiplier(multipliers) - with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): + with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: text_encoder_conds = get_weighted_text_embeddings( @@ -810,7 +891,7 @@ def remove_model(old_ckpt_name): t.requires_grad_(True) # Predict the noise residual - with accelerator.autocast(): + with torch.set_grad_enabled(is_train), accelerator.autocast(): noise_pred = self.call_unet( args, accelerator, @@ -844,7 +925,7 @@ def remove_model(old_ckpt_name): loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - + accelerator.backward(loss) if accelerator.sync_gradients: self.all_reduce_network(accelerator, network) # sync DDP grad manually @@ -898,14 +979,38 @@ def remove_model(old_ckpt_name): if args.logging_dir is not None: logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) accelerator.log(logs, step=global_step) - + + if global_step % 25 == 0: + if len(val_dataloader) > 0: + print("Validating バリデーション処理...") + + with torch.no_grad(): + val_dataloader_iter = iter(val_dataloader) + batch = next(val_dataloader_iter) + is_train = False + loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + + current_loss = loss.detach().item() + val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss) + + if args.logging_dir is not None: + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/validation_current": current_loss} + accelerator.log(logs, step=global_step) + if global_step >= args.max_train_steps: break if args.logging_dir is not None: - logs = {"loss/epoch": loss_recorder.moving_average} + logs = {"loss/epoch_average": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) + if len(val_dataloader) > 0: + if args.logging_dir is not None: + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/validation_epoch_average": avr_loss} + accelerator.log(logs, step=epoch + 1) + accelerator.wait_for_everyone() # 指定エポックごとにモデルを保存 @@ -1045,6 +1150,18 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) + parser.add_argument( + "--validation_seed", + type=int, + default=None, + help="Validation seed" + ) + parser.add_argument( + "--validation_split", + type=float, + default=0.0, + help="Split for validation images out of the training dataset" + ) return parser From 78cfb01922ff97bbc62ff12a4d69eaaa2d89d7c1 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 10 Mar 2024 18:55:48 +0800 Subject: [PATCH 010/348] improve --- library/config_util.py | 260 +++++++++++++++++++++++++++++------------ train_network.py | 67 +++++++---- 2 files changed, 234 insertions(+), 93 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 17fc17818..d198cee35 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -41,12 +41,17 @@ DatasetGroup, ) from .utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) + def add_config_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル") + parser.add_argument( + "--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル" + ) # TODO: inherit Params class in Subset, Dataset @@ -60,6 +65,8 @@ class BaseSubsetParams: caption_separator: str = (",",) keep_tokens: int = 0 keep_tokens_separator: str = (None,) + secondary_separator: Optional[str] = None + enable_wildcard: bool = False color_aug: bool = False flip_aug: bool = False face_crop_aug_range: Optional[Tuple[float, float]] = None @@ -181,6 +188,8 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] "shuffle_caption": bool, "keep_tokens": int, "keep_tokens_separator": str, + "secondary_separator": str, + "enable_wildcard": bool, "token_warmup_min": int, "token_warmup_step": Any(float, int), "caption_prefix": str, @@ -247,9 +256,10 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] } def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None: - assert ( - support_dreambooth or support_finetuning or support_controlnet - ), "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。" + assert support_dreambooth or support_finetuning or support_controlnet, ( + "Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more." + + " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。" + ) self.db_subset_schema = self.__merge_dict( self.SUBSET_ASCENDABLE_SCHEMA, @@ -361,7 +371,9 @@ def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> return self.argparse_config_validator(argparse_namespace) except MultipleInvalid: # XXX: this should be a bug - logger.error("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。") + logger.error( + "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。" + ) raise # NOTE: value would be overwritten by latter dict if there is already the same key @@ -447,7 +459,6 @@ def search_value(key: str, fallbacks: Sequence[dict], default_value=None): return default_value - def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint): datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] @@ -467,7 +478,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu datasets.append(dataset) val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] - + for dataset_blueprint in dataset_group_blueprint.datasets: if dataset_blueprint.params.validation_split <= 0.0: continue @@ -485,75 +496,174 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) val_datasets.append(dataset) - def print_info(_datasets): - info = "" - for i, dataset in enumerate(_datasets): - is_dreambooth = isinstance(dataset, DreamBoothDataset) - is_controlnet = isinstance(dataset, ControlNetDataset) - info += dedent(f"""\ - [Dataset {i}] - batch_size: {dataset.batch_size} - resolution: {(dataset.width, dataset.height)} - enable_bucket: {dataset.enable_bucket} - """) + # print info + info = "" + for i, dataset in enumerate(datasets): + is_dreambooth = isinstance(dataset, DreamBoothDataset) + is_controlnet = isinstance(dataset, ControlNetDataset) + info += dedent( + f"""\ + [Dataset {i}] + batch_size: {dataset.batch_size} + resolution: {(dataset.width, dataset.height)} + enable_bucket: {dataset.enable_bucket} + network_multiplier: {dataset.network_multiplier} + """ + ) if dataset.enable_bucket: - info += indent(dedent(f"""\ - min_bucket_reso: {dataset.min_bucket_reso} - max_bucket_reso: {dataset.max_bucket_reso} - bucket_reso_steps: {dataset.bucket_reso_steps} - bucket_no_upscale: {dataset.bucket_no_upscale} - \n"""), " ") + info += indent( + dedent( + f"""\ + min_bucket_reso: {dataset.min_bucket_reso} + max_bucket_reso: {dataset.max_bucket_reso} + bucket_reso_steps: {dataset.bucket_reso_steps} + bucket_no_upscale: {dataset.bucket_no_upscale} + \n""" + ), + " ", + ) else: info += "\n" + for j, subset in enumerate(dataset.subsets): - info += indent(dedent(f"""\ - [Subset {j} of Dataset {i}] - image_dir: "{subset.image_dir}" - image_count: {subset.img_count} - num_repeats: {subset.num_repeats} - shuffle_caption: {subset.shuffle_caption} - keep_tokens: {subset.keep_tokens} - caption_dropout_rate: {subset.caption_dropout_rate} - caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} - caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} - caption_prefix: {subset.caption_prefix} - caption_suffix: {subset.caption_suffix} - color_aug: {subset.color_aug} - flip_aug: {subset.flip_aug} - face_crop_aug_range: {subset.face_crop_aug_range} - random_crop: {subset.random_crop} - token_warmup_min: {subset.token_warmup_min}, - token_warmup_step: {subset.token_warmup_step}, - """), " ") - - if is_dreambooth: - info += indent(dedent(f"""\ - is_reg: {subset.is_reg} - class_tokens: {subset.class_tokens} - caption_extension: {subset.caption_extension} - \n"""), " ") - elif not is_controlnet: - info += indent(dedent(f"""\ - metadata_file: {subset.metadata_file} - \n"""), " ") - - print(info) - - print_info(datasets) - - if len(val_datasets) > 0: - print("Validation dataset") - print_info(val_datasets) - + info += indent( + dedent( + f"""\ + [Subset {j} of Dataset {i}] + image_dir: "{subset.image_dir}" + image_count: {subset.img_count} + num_repeats: {subset.num_repeats} + shuffle_caption: {subset.shuffle_caption} + keep_tokens: {subset.keep_tokens} + keep_tokens_separator: {subset.keep_tokens_separator} + caption_dropout_rate: {subset.caption_dropout_rate} + caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} + caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} + caption_prefix: {subset.caption_prefix} + caption_suffix: {subset.caption_suffix} + color_aug: {subset.color_aug} + flip_aug: {subset.flip_aug} + face_crop_aug_range: {subset.face_crop_aug_range} + random_crop: {subset.random_crop} + token_warmup_min: {subset.token_warmup_min}, + token_warmup_step: {subset.token_warmup_step}, + """ + ), + " ", + ) + + if is_dreambooth: + info += indent( + dedent( + f"""\ + is_reg: {subset.is_reg} + class_tokens: {subset.class_tokens} + caption_extension: {subset.caption_extension} + \n""" + ), + " ", + ) + elif not is_controlnet: + info += indent( + dedent( + f"""\ + metadata_file: {subset.metadata_file} + \n""" + ), + " ", + ) + + logger.info(f'{info}') + + # print validation info + info = "" + for i, dataset in enumerate(val_datasets): + is_dreambooth = isinstance(dataset, DreamBoothDataset) + is_controlnet = isinstance(dataset, ControlNetDataset) + info += dedent( + f"""\ + [Validation Dataset {i}] + batch_size: {dataset.batch_size} + resolution: {(dataset.width, dataset.height)} + enable_bucket: {dataset.enable_bucket} + network_multiplier: {dataset.network_multiplier} + """ + ) + + if dataset.enable_bucket: + info += indent( + dedent( + f"""\ + min_bucket_reso: {dataset.min_bucket_reso} + max_bucket_reso: {dataset.max_bucket_reso} + bucket_reso_steps: {dataset.bucket_reso_steps} + bucket_no_upscale: {dataset.bucket_no_upscale} + \n""" + ), + " ", + ) + else: + info += "\n" + + for j, subset in enumerate(dataset.subsets): + info += indent( + dedent( + f"""\ + [Subset {j} of Dataset {i}] + image_dir: "{subset.image_dir}" + image_count: {subset.img_count} + num_repeats: {subset.num_repeats} + shuffle_caption: {subset.shuffle_caption} + keep_tokens: {subset.keep_tokens} + keep_tokens_separator: {subset.keep_tokens_separator} + caption_dropout_rate: {subset.caption_dropout_rate} + caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} + caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} + caption_prefix: {subset.caption_prefix} + caption_suffix: {subset.caption_suffix} + color_aug: {subset.color_aug} + flip_aug: {subset.flip_aug} + face_crop_aug_range: {subset.face_crop_aug_range} + random_crop: {subset.random_crop} + token_warmup_min: {subset.token_warmup_min}, + token_warmup_step: {subset.token_warmup_step}, + """ + ), + " ", + ) + + if is_dreambooth: + info += indent( + dedent( + f"""\ + is_reg: {subset.is_reg} + class_tokens: {subset.class_tokens} + caption_extension: {subset.caption_extension} + \n""" + ), + " ", + ) + elif not is_controlnet: + info += indent( + dedent( + f"""\ + metadata_file: {subset.metadata_file} + \n""" + ), + " ", + ) + + logger.info(f'{info}') + # make buckets first because it determines the length of dataset # and set the same seed for all datasets seed = random.randint(0, 2**31) # actual seed is seed + epoch_no for i, dataset in enumerate(datasets): - print(f"[Dataset {i}]") + logger.info(f"[Dataset {i}]") dataset.make_buckets() dataset.set_seed(seed) - + for i, dataset in enumerate(val_datasets): print(f"[Validation Dataset {i}]") dataset.make_buckets() @@ -562,8 +672,8 @@ def print_info(_datasets): return ( DatasetGroup(datasets), DatasetGroup(val_datasets) if val_datasets else None - ) - + ) + def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None): def extract_dreambooth_params(name: str) -> Tuple[int, str]: tokens = name.split("_") @@ -642,13 +752,17 @@ def load_user_config(file: str) -> dict: with open(file, "r") as f: config = json.load(f) except Exception: - logger.error(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") + logger.error( + f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) raise elif file.name.lower().endswith(".toml"): try: config = toml.load(file) except Exception: - logger.error(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") + logger.error( + f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) raise else: raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}") @@ -675,13 +789,13 @@ def load_user_config(file: str) -> dict: train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning) logger.info("[argparse_namespace]") - logger.info(f'{vars(argparse_namespace)}') + logger.info(f"{vars(argparse_namespace)}") user_config = load_user_config(config_args.dataset_config) logger.info("") logger.info("[user_config]") - logger.info(f'{user_config}') + logger.info(f"{user_config}") sanitizer = ConfigSanitizer( config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout @@ -690,10 +804,10 @@ def load_user_config(file: str) -> dict: logger.info("") logger.info("[sanitized_user_config]") - logger.info(f'{sanitized_user_config}') + logger.info(f"{sanitized_user_config}") blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) logger.info("") logger.info("[blueprint]") - logger.info(f'{blueprint}') + logger.info(f"{blueprint}") diff --git a/train_network.py b/train_network.py index db7000e82..d3e34eb7e 100644 --- a/train_network.py +++ b/train_network.py @@ -44,6 +44,7 @@ setup_logging() import logging +import itertools logger = logging.getLogger(__name__) @@ -438,6 +439,7 @@ def train(self, args): num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) + cyclic_val_dataloader = itertools.cycle(val_dataloader) # 学習ステップ数を計算する if args.max_train_epochs is not None: @@ -979,23 +981,24 @@ def remove_model(old_ckpt_name): if args.logging_dir is not None: logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) accelerator.log(logs, step=global_step) - - if global_step % 25 == 0: - if len(val_dataloader) > 0: - print("Validating バリデーション処理...") - - with torch.no_grad(): - val_dataloader_iter = iter(val_dataloader) - batch = next(val_dataloader_iter) - is_train = False - loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - - current_loss = loss.detach().item() - val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss) + + if args.validation_every_n_step is not None: + if global_step % (args.validation_every_n_step) == 0: + if len(val_dataloader) > 0: + print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + for val_step in min(len(val_dataloader), args.validation_batches): + is_train = False + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + total_loss += loss.detach().item() + current_loss = total_loss / args.validation_batches + val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss) if args.logging_dir is not None: avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/validation_current": current_loss} + logs = {"loss/avr_val_loss": avr_loss} accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: @@ -1005,12 +1008,24 @@ def remove_model(old_ckpt_name): logs = {"loss/epoch_average": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) - if len(val_dataloader) > 0: - if args.logging_dir is not None: - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/validation_epoch_average": avr_loss} - accelerator.log(logs, step=epoch + 1) - + if args.validation_every_n_step is None: + if len(val_dataloader) > 0: + print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + for val_step in min(len(val_dataloader), args.validation_batches): + is_train = False + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + total_loss += loss.detach().item() + current_loss = total_loss / args.validation_batches + val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss) + + if args.logging_dir is not None: + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/val_epoch_average": avr_loss} + accelerator.log(logs, step=epoch + 1) + accelerator.wait_for_everyone() # 指定エポックごとにモデルを保存 @@ -1162,6 +1177,18 @@ def setup_parser() -> argparse.ArgumentParser: default=0.0, help="Split for validation images out of the training dataset" ) + parser.add_argument( + "--validation_every_n_step", + type=int, + default=None, + help="Number of steps for counting validation loss. By default, validation per epoch is performed" + ) + parser.add_argument( + "--validation_batches", + type=int, + default=1, + help="Number of val steps for counting validation loss. By default, validation one batch is performed" + ) return parser From 923b761ce3622a3132bf0db7768e6b97df21c607 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 10 Mar 2024 20:01:40 +0800 Subject: [PATCH 011/348] Update train_network.py --- train_network.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/train_network.py b/train_network.py index d3e34eb7e..821100666 100644 --- a/train_network.py +++ b/train_network.py @@ -988,6 +988,7 @@ def remove_model(old_ckpt_name): print("Validating バリデーション処理...") total_loss = 0.0 with torch.no_grad(): + validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader) for val_step in min(len(val_dataloader), args.validation_batches): is_train = False batch = next(cyclic_val_dataloader) @@ -1013,6 +1014,7 @@ def remove_model(old_ckpt_name): print("Validating バリデーション処理...") total_loss = 0.0 with torch.no_grad(): + validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader) for val_step in min(len(val_dataloader), args.validation_batches): is_train = False batch = next(cyclic_val_dataloader) @@ -1186,8 +1188,8 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--validation_batches", type=int, - default=1, - help="Number of val steps for counting validation loss. By default, validation one batch is performed" + default=None, + help="Number of val steps for counting validation loss. By default, validation for all val_dataset is performed" ) return parser From 47359b8fac9602415f56b1f7e3f25a00255a1d78 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 10 Mar 2024 20:17:40 +0800 Subject: [PATCH 012/348] Update train_network.py --- train_network.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_network.py b/train_network.py index 821100666..d549378cc 100644 --- a/train_network.py +++ b/train_network.py @@ -989,7 +989,7 @@ def remove_model(old_ckpt_name): total_loss = 0.0 with torch.no_grad(): validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader) - for val_step in min(len(val_dataloader), args.validation_batches): + for val_step in range(validation_steps): is_train = False batch = next(cyclic_val_dataloader) loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) @@ -1015,7 +1015,7 @@ def remove_model(old_ckpt_name): total_loss = 0.0 with torch.no_grad(): validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader) - for val_step in min(len(val_dataloader), args.validation_batches): + for val_step in range(validation_steps): is_train = False batch = next(cyclic_val_dataloader) loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) From a51723cc2a3dd50b45e60945f97bc5adfe753d1f Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Mon, 11 Mar 2024 09:42:58 +0800 Subject: [PATCH 013/348] fix timesteps --- train_network.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/train_network.py b/train_network.py index d549378cc..f0f27ea74 100644 --- a/train_network.py +++ b/train_network.py @@ -141,7 +141,6 @@ def process_val_batch(self, batch, is_train, tokenizers, text_encoders, unet, va total_loss = 0.0 timesteps_list = [10, 350, 500, 650, 990] - with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device) @@ -174,16 +173,17 @@ def process_val_batch(self, batch, is_train, tokenizers, text_encoders, unet, va # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, _ = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) - for timesteps in timesteps_list: - # Predict the noise residual + + for fixed_timesteps in timesteps_list: with torch.set_grad_enabled(is_train), accelerator.autocast(): + noise = torch.randn_like(latents, device=latents.device) + b_size = latents.shape[0] + timesteps = torch.randint(fixed_timesteps, fixed_timesteps, (b_size,), device=latents.device) + timesteps = timesteps.long() + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) noise_pred = self.call_unet( args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype ) - if args.v_parameterization: # v-parameterization training target = noise_scheduler.get_velocity(latents, noise, timesteps) @@ -988,7 +988,7 @@ def remove_model(old_ckpt_name): print("Validating バリデーション処理...") total_loss = 0.0 with torch.no_grad(): - validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader) + validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader) for val_step in range(validation_steps): is_train = False batch = next(cyclic_val_dataloader) @@ -999,7 +999,7 @@ def remove_model(old_ckpt_name): if args.logging_dir is not None: avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/avr_val_loss": avr_loss} + logs = {"loss/average_val_loss": avr_loss} accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: @@ -1014,7 +1014,7 @@ def remove_model(old_ckpt_name): print("Validating バリデーション処理...") total_loss = 0.0 with torch.no_grad(): - validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader) + validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader) for val_step in range(validation_steps): is_train = False batch = next(cyclic_val_dataloader) From 7d84ac2177a603e9aa6834fd1c0ee19a463eb5a0 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Mon, 11 Mar 2024 14:41:51 +0800 Subject: [PATCH 014/348] only use train subset to val --- library/config_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/config_util.py b/library/config_util.py index d198cee35..1a6cef971 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -492,7 +492,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu subset_klass = FineTuningSubset dataset_klass = FineTuningDataset - subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets if subset_blueprint.params.is_reg is False] dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) val_datasets.append(dataset) From befbec5335ed1f8018d22b65993b376571ea2989 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Mon, 11 Mar 2024 18:47:04 +0800 Subject: [PATCH 015/348] Update train_network.py --- train_network.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/train_network.py b/train_network.py index f0f27ea74..cbc107b6b 100644 --- a/train_network.py +++ b/train_network.py @@ -174,7 +174,7 @@ def process_val_batch(self, batch, is_train, tokenizers, text_encoders, unet, va # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - for fixed_timesteps in timesteps_list: + for fixed_timesteps in tqdm(timesteps_list, desc='Training Progress'): with torch.set_grad_enabled(is_train), accelerator.autocast(): noise = torch.randn_like(latents, device=latents.device) b_size = latents.shape[0] @@ -184,16 +184,16 @@ def process_val_batch(self, batch, is_train, tokenizers, text_encoders, unet, va noise_pred = self.call_unet( args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype ) - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - total_loss += loss + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + total_loss += loss average_loss = total_loss / len(timesteps_list) return average_loss @@ -985,7 +985,7 @@ def remove_model(old_ckpt_name): if args.validation_every_n_step is not None: if global_step % (args.validation_every_n_step) == 0: if len(val_dataloader) > 0: - print("Validating バリデーション処理...") + print(f"\nValidating バリデーション処理...") total_loss = 0.0 with torch.no_grad(): validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader) @@ -994,10 +994,12 @@ def remove_model(old_ckpt_name): batch = next(cyclic_val_dataloader) loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) total_loss += loss.detach().item() - current_loss = total_loss / args.validation_batches - val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss) + current_loss = total_loss / args.validation_batches + val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss) if args.logging_dir is not None: + logs = {"loss/current_val_loss": current_loss} + accelerator.log(logs, step=global_step) avr_loss: float = val_loss_recorder.moving_average logs = {"loss/average_val_loss": avr_loss} accelerator.log(logs, step=global_step) @@ -1011,7 +1013,7 @@ def remove_model(old_ckpt_name): if args.validation_every_n_step is None: if len(val_dataloader) > 0: - print("Validating バリデーション処理...") + print(f"\nValidating バリデーション処理...") total_loss = 0.0 with torch.no_grad(): validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader) @@ -1025,7 +1027,7 @@ def remove_model(old_ckpt_name): if args.logging_dir is not None: avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/val_epoch_average": avr_loss} + logs = {"loss/epoch_val_average": avr_loss} accelerator.log(logs, step=epoch + 1) accelerator.wait_for_everyone() From 63e58f78e3df7608045071cdc247bb26bd19a333 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Mon, 11 Mar 2024 19:15:55 +0800 Subject: [PATCH 016/348] Update train_network.py --- train_network.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/train_network.py b/train_network.py index cbc107b6b..82d72df24 100644 --- a/train_network.py +++ b/train_network.py @@ -178,8 +178,7 @@ def process_val_batch(self, batch, is_train, tokenizers, text_encoders, unet, va with torch.set_grad_enabled(is_train), accelerator.autocast(): noise = torch.randn_like(latents, device=latents.device) b_size = latents.shape[0] - timesteps = torch.randint(fixed_timesteps, fixed_timesteps, (b_size,), device=latents.device) - timesteps = timesteps.long() + timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) noise_pred = self.call_unet( args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype From a6c41c6bea0465112c7bd472dff68b7e8ecea46e Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Mon, 11 Mar 2024 19:23:48 +0800 Subject: [PATCH 017/348] Update train_network.py --- train_network.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_network.py b/train_network.py index 82d72df24..6eefdb2be 100644 --- a/train_network.py +++ b/train_network.py @@ -174,7 +174,7 @@ def process_val_batch(self, batch, is_train, tokenizers, text_encoders, unet, va # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - for fixed_timesteps in tqdm(timesteps_list, desc='Training Progress'): + for fixed_timesteps in timesteps_list: with torch.set_grad_enabled(is_train), accelerator.autocast(): noise = torch.randn_like(latents, device=latents.device) b_size = latents.shape[0] @@ -988,7 +988,7 @@ def remove_model(old_ckpt_name): total_loss = 0.0 with torch.no_grad(): validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader) - for val_step in range(validation_steps): + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): is_train = False batch = next(cyclic_val_dataloader) loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) @@ -1016,7 +1016,7 @@ def remove_model(old_ckpt_name): total_loss = 0.0 with torch.no_grad(): validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader) - for val_step in range(validation_steps): + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): is_train = False batch = next(cyclic_val_dataloader) loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) From bd7e2295b7c4d1444a9e844309e1685cb29c6961 Mon Sep 17 00:00:00 2001 From: gesen2egee Date: Wed, 13 Mar 2024 17:54:21 +0800 Subject: [PATCH 018/348] fix --- train_network.py | 38 +++++++++----------------------------- 1 file changed, 9 insertions(+), 29 deletions(-) diff --git a/train_network.py b/train_network.py index 6eefdb2be..128690fba 100644 --- a/train_network.py +++ b/train_network.py @@ -981,20 +981,19 @@ def remove_model(old_ckpt_name): logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) accelerator.log(logs, step=global_step) - if args.validation_every_n_step is not None: - if global_step % (args.validation_every_n_step) == 0: - if len(val_dataloader) > 0: + if len(val_dataloader) > 0: + if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or step == len(train_dataloader) - 1 or global_step >= args.max_train_steps: print(f"\nValidating バリデーション処理...") total_loss = 0.0 with torch.no_grad(): - validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader) + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) for val_step in tqdm(range(validation_steps), desc='Validation Steps'): is_train = False batch = next(cyclic_val_dataloader) loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) total_loss += loss.detach().item() - current_loss = total_loss / args.validation_batches - val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss) + current_loss = total_loss / validation_steps + val_loss_recorder.add(epoch=epoch, step=step, loss=current_loss) if args.logging_dir is not None: logs = {"loss/current_val_loss": current_loss} @@ -1009,25 +1008,6 @@ def remove_model(old_ckpt_name): if args.logging_dir is not None: logs = {"loss/epoch_average": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) - - if args.validation_every_n_step is None: - if len(val_dataloader) > 0: - print(f"\nValidating バリデーション処理...") - total_loss = 0.0 - with torch.no_grad(): - validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - is_train = False - batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - total_loss += loss.detach().item() - current_loss = total_loss / args.validation_batches - val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss) - - if args.logging_dir is not None: - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/epoch_val_average": avr_loss} - accelerator.log(logs, step=epoch + 1) accelerator.wait_for_everyone() @@ -1184,14 +1164,14 @@ def setup_parser() -> argparse.ArgumentParser: "--validation_every_n_step", type=int, default=None, - help="Number of steps for counting validation loss. By default, validation per epoch is performed" + help="Number of train steps for counting validation loss. By default, validation per train epoch is performed" ) parser.add_argument( - "--validation_batches", + "--max_validation_steps", type=int, default=None, - help="Number of val steps for counting validation loss. By default, validation for all val_dataset is performed" - ) + help="Number of max validation steps for counting validation loss. By default, validation will run entire validation dataset" + ) return parser From d05965dbadf430dab6a05f171292f6d2077ec946 Mon Sep 17 00:00:00 2001 From: gesen2egee Date: Wed, 13 Mar 2024 18:33:51 +0800 Subject: [PATCH 019/348] Update train_network.py --- train_network.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_network.py b/train_network.py index 864bfd708..cc9fcbbed 100644 --- a/train_network.py +++ b/train_network.py @@ -987,8 +987,8 @@ def remove_model(old_ckpt_name): accelerator.log(logs, step=global_step) if len(val_dataloader) > 0: - if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or step == len(train_dataloader) - 1 or global_step >= args.max_train_steps: - print(f"\nValidating バリデーション処理...") + if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: + accelerator.print("Validating バリデーション処理...") total_loss = 0.0 with torch.no_grad(): validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) @@ -998,7 +998,7 @@ def remove_model(old_ckpt_name): loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) total_loss += loss.detach().item() current_loss = total_loss / validation_steps - val_loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) if args.logging_dir is not None: logs = {"loss/current_val_loss": current_loss} From b5e8045df40ed4a437492ed2b6ea6d5be7282080 Mon Sep 17 00:00:00 2001 From: gesen2egee Date: Sat, 16 Mar 2024 11:51:11 +0800 Subject: [PATCH 020/348] fix control net --- library/config_util.py | 6 ++++-- library/train_util.py | 15 ++++++++++++--- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index ec6ef4b2b..0da0b1437 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -491,8 +491,10 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu else: subset_klass = FineTuningSubset dataset_klass = FineTuningDataset - - subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets if subset_blueprint.params.is_reg is False] + if subset_klass == DreamBoothSubset: + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets if subset_blueprint.params.is_reg is False] + else: + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) val_datasets.append(dataset) diff --git a/library/train_util.py b/library/train_util.py index 892979628..ae7968d73 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1816,6 +1816,7 @@ class ControlNetDataset(BaseDataset): def __init__( self, subsets: Sequence[ControlNetSubset], + is_train: bool, batch_size: int, tokenizer, max_token_length, @@ -1826,6 +1827,8 @@ def __init__( max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, + validation_split: float, + validation_seed: Optional[int], debug_dataset: float, ) -> None: super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) @@ -1860,6 +1863,7 @@ def __init__( self.dreambooth_dataset_delegate = DreamBoothDataset( db_subsets, + is_train, batch_size, tokenizer, max_token_length, @@ -1871,6 +1875,8 @@ def __init__( bucket_reso_steps, bucket_no_upscale, 1.0, + validation_split, + validation_seed, debug_dataset, ) @@ -1878,7 +1884,10 @@ def __init__( self.image_data = self.dreambooth_dataset_delegate.image_data self.batch_size = batch_size self.num_train_images = self.dreambooth_dataset_delegate.num_train_images - self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images + self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images + self.is_train = is_train + self.validation_split = validation_split + self.validation_seed = validation_seed # assert all conditioning data exists missing_imgs = [] @@ -1911,8 +1920,8 @@ def __init__( [cond_img_path for cond_img_path in conditioning_img_paths if cond_img_path not in cond_imgs_with_img] ) - assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}" - assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}" + #assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}" + #assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}" self.conditioning_image_transforms = IMAGE_TRANSFORMS From 36d4023431d10718b00673d5ba34f426690c62de Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Thu, 11 Apr 2024 01:39:17 +0800 Subject: [PATCH 021/348] Update config_util.py --- library/config_util.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index a7e0024e3..c6667690e 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -498,10 +498,21 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu else: subset_klass = FineTuningSubset dataset_klass = FineTuningDataset - if subset_klass == DreamBoothSubset: - subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets if subset_blueprint.params.is_reg is False] - else: - subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + + subsets = [] + for subset_blueprint in dataset_blueprint.subsets: + subset_blueprint.params.num_repeats = 1 + subset_blueprint.params.color_aug = False + subset_blueprint.params.flip_aug = False + subset_blueprint.params.random_crop = False + subset_blueprint.params.random_crop = None + subset_blueprint.params.caption_dropout_rate = 0.0 + subset_blueprint.params.caption_dropout_every_n_epochs = 0 + subset_blueprint.params.caption_tag_dropout_rate = 0.0 + subset_blueprint.params.token_warmup_step = 0 + if subset_klass != DreamBoothSubset or not subset_blueprint.params.is_reg: + subsets.append(subset_klass(**asdict(subset_blueprint.params))) + dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) val_datasets.append(dataset) From 229c5a38ef4e93e2023d748b4fa1588d490340ad Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Thu, 11 Apr 2024 01:45:49 +0800 Subject: [PATCH 022/348] Update train_util.py --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 832be75d5..b143e85a8 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3123,7 +3123,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: ) parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") parser.add_argument( - "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / grandient checkpointingを有効にする" + "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする" ) parser.add_argument( "--gradient_accumulation_steps", From 3b251b758dae6e4f11e0bbc7e544dc9542c836ff Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Thu, 11 Apr 2024 01:50:32 +0800 Subject: [PATCH 023/348] Update config_util.py --- library/config_util.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index c6667690e..8f01e1f60 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -510,8 +510,10 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu subset_blueprint.params.caption_dropout_every_n_epochs = 0 subset_blueprint.params.caption_tag_dropout_rate = 0.0 subset_blueprint.params.token_warmup_step = 0 - if subset_klass != DreamBoothSubset or not subset_blueprint.params.is_reg: - subsets.append(subset_klass(**asdict(subset_blueprint.params))) + + if subset_klass != DreamBoothSubset or (subset_klass == DreamBoothSubset and not subset_blueprint.params.is_reg): + subset = subset_klass(**asdict(subset_blueprint.params)) + subsets.append(subset) dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) val_datasets.append(dataset) From 459b12539b0ae1a92da98e38568ea0a61db1e89f Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Thu, 11 Apr 2024 01:52:14 +0800 Subject: [PATCH 024/348] Update config_util.py --- library/config_util.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 8f01e1f60..6f243aac3 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -512,8 +512,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu subset_blueprint.params.token_warmup_step = 0 if subset_klass != DreamBoothSubset or (subset_klass == DreamBoothSubset and not subset_blueprint.params.is_reg): - subset = subset_klass(**asdict(subset_blueprint.params)) - subsets.append(subset) + subsets.append(subset_klass(**asdict(subset_blueprint.params))) dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) val_datasets.append(dataset) From 89ad69b6a0d35791627cb58630a711befc6bb3b5 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Thu, 11 Apr 2024 08:42:31 +0800 Subject: [PATCH 025/348] Update train_util.py --- library/train_util.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index b143e85a8..8bf6823bb 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1511,17 +1511,6 @@ def load_dreambooth_dir(subset: DreamBoothSubset): logger.warning(f"not directory: {subset.image_dir}") return [], [] - img_paths = glob_images(subset.image_dir, "*") - if self.validation_split > 0.0: - img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed) - logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") - - # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う - captions = [] - missing_captions = [] - for img_path in img_paths: - cap_for_img = read_caption(img_path, subset.caption_extension) - if cap_for_img is None and subset.class_tokens is None: info_cache_file = os.path.join(subset.image_dir, self.IMAGE_INFO_CACHE_FILE) use_cached_info_for_subset = subset.cache_info if use_cached_info_for_subset: @@ -1545,6 +1534,8 @@ def load_dreambooth_dir(subset: DreamBoothSubset): # we may need to check image size and existence of image files, but it takes time, so user should check it before training else: img_paths = glob_images(subset.image_dir, "*") + if self.validation_split > 0.0: + img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed) sizes = [None] * len(img_paths) logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") From fde8026c2d92fe4991927eed6fa1ff373e8d38d2 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Thu, 11 Apr 2024 11:29:26 +0800 Subject: [PATCH 026/348] Update config_util.py --- library/config_util.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 6f243aac3..a1b02bd1e 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -636,19 +636,11 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu [Subset {j} of Dataset {i}] image_dir: "{subset.image_dir}" image_count: {subset.img_count} - num_repeats: {subset.num_repeats} shuffle_caption: {subset.shuffle_caption} keep_tokens: {subset.keep_tokens} keep_tokens_separator: {subset.keep_tokens_separator} - caption_dropout_rate: {subset.caption_dropout_rate} - caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} - caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} caption_prefix: {subset.caption_prefix} caption_suffix: {subset.caption_suffix} - color_aug: {subset.color_aug} - flip_aug: {subset.flip_aug} - face_crop_aug_range: {subset.face_crop_aug_range} - random_crop: {subset.random_crop} token_warmup_min: {subset.token_warmup_min}, token_warmup_step: {subset.token_warmup_step}, """ @@ -688,7 +680,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset.set_seed(seed) for i, dataset in enumerate(val_datasets): - print(f"[Validation Dataset {i}]") + logger.info(f"[Validation Dataset {i}]") dataset.make_buckets() dataset.set_seed(seed) From e5268286bf90ddcc53ad1deb31aba857cfa967d5 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 15 Jun 2024 22:20:24 +0900 Subject: [PATCH 027/348] add sd3 models and inference script --- library/sd3_models.py | 1796 ++++++++++++++++++++++++++++++++++++++ library/sd3_utils.py | 113 +++ sd3_minimal_inference.py | 347 ++++++++ 3 files changed, 2256 insertions(+) create mode 100644 library/sd3_models.py create mode 100644 library/sd3_utils.py create mode 100644 sd3_minimal_inference.py diff --git a/library/sd3_models.py b/library/sd3_models.py new file mode 100644 index 000000000..294a69b06 --- /dev/null +++ b/library/sd3_models.py @@ -0,0 +1,1796 @@ +# some modules/classes are copied and modified from https://github.com/mcmonkey4eva/sd3-ref +# the original code is licensed under the MIT License + +# and some module/classes are contributed from KohakuBlueleaf. Thanks for the contribution! + +from functools import partial +import math +from typing import Dict, Optional +import einops +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint +from transformers import CLIPTokenizer, T5TokenizerFast + + +memory_efficient_attention = None +try: + import xformers +except: + pass + +try: + from xformers.ops import memory_efficient_attention +except: + memory_efficient_attention = None + + +# region tokenizer +class SDTokenizer: + def __init__( + self, max_length=77, pad_with_end=True, tokenizer=None, has_start_token=True, pad_to_max_length=True, min_length=None + ): + """ + サブクラスで各種の設定を行ってる。このクラスはその設定に基づき重み付きのトークン化を行うようだ。 + Some settings are done in subclasses. This class seems to perform tokenization with weights based on those settings. + """ + self.tokenizer = tokenizer + self.max_length = max_length + self.min_length = min_length + empty = self.tokenizer("")["input_ids"] + if has_start_token: + self.tokens_start = 1 + self.start_token = empty[0] + self.end_token = empty[1] + else: + self.tokens_start = 0 + self.start_token = None + self.end_token = empty[0] + self.pad_with_end = pad_with_end + self.pad_to_max_length = pad_to_max_length + vocab = self.tokenizer.get_vocab() + self.inv_vocab = {v: k for k, v in vocab.items()} + self.max_word_length = 8 + + def tokenize_with_weights(self, text: str): + """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. + The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3.""" + """ + ja: テキストをトークン化し、重み値を持ちます - すべての値に1.0を仮定し、他の機能を無視します。 + 詳細は参考実装には関係なく、重み自体はSD3に対して弱い影響しかありません。へぇ~ + """ + if self.pad_with_end: + pad_token = self.end_token + else: + pad_token = 0 + batch = [] + if self.start_token is not None: + batch.append((self.start_token, 1.0)) + to_tokenize = text.replace("\n", " ").split(" ") + to_tokenize = [x for x in to_tokenize if x != ""] + for word in to_tokenize: + batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start : -1]]) + batch.append((self.end_token, 1.0)) + if self.pad_to_max_length: + batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch))) + if self.min_length is not None and len(batch) < self.min_length: + batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch))) + return [batch] + + +class T5XXLTokenizer(SDTokenizer): + """Wraps the T5 Tokenizer from HF into the SDTokenizer interface""" + + def __init__(self): + super().__init__( + pad_with_end=False, + tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"), + has_start_token=False, + pad_to_max_length=False, + max_length=99999999, + min_length=77, + ) + + +class SDXLClipGTokenizer(SDTokenizer): + def __init__(self, tokenizer): + super().__init__(pad_with_end=False, tokenizer=tokenizer) + + +class SD3Tokenizer: + def __init__(self, t5xxl=True): + # TODO cache tokenizer settings locally or hold them in the repo like ComfyUI + clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + self.clip_l = SDTokenizer(tokenizer=clip_tokenizer) + self.clip_g = SDXLClipGTokenizer(clip_tokenizer) + self.t5xxl = T5XXLTokenizer() if t5xxl else None + + def tokenize_with_weights(self, text: str): + return ( + self.clip_l.tokenize_with_weights(text), + self.clip_g.tokenize_with_weights(text), + self.t5xxl.tokenize_with_weights(text) if self.t5xxl is not None else None, + ) + + +# endregion + +# region mmdit + + +def get_2d_sincos_pos_embed( + embed_dim, + grid_size, + scaling_factor=None, + offset=None, +): + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + if scaling_factor is not None: + grid = grid / scaling_factor + if offset is not None: + grid = grid - offset + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid_torch( + embed_dim, + pos, + device=None, + dtype=torch.float32, +): + omega = torch.arange(embed_dim // 2, device=device, dtype=dtype) + omega *= 2.0 / embed_dim + omega = 1.0 / 10000**omega + out = torch.outer(pos.reshape(-1), omega) + emb = torch.cat([out.sin(), out.cos()], dim=1) + return emb + + +def get_2d_sincos_pos_embed_torch( + embed_dim, + w, + h, + val_center=7.5, + val_magnitude=7.5, + device=None, + dtype=torch.float32, +): + small = min(h, w) + val_h = (h / small) * val_magnitude + val_w = (w / small) * val_magnitude + grid_h, grid_w = torch.meshgrid( + torch.linspace(-val_h + val_center, val_h + val_center, h, device=device, dtype=dtype), + torch.linspace(-val_w + val_center, val_w + val_center, w, device=device, dtype=dtype), + indexing="ij", + ) + emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype) + emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype) + emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D) + return emb + + +def modulate(x, shift, scale): + if shift is None: + shift = torch.zeros_like(scale) + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def default(x, default_value): + if x is None: + return default_value + return x + + +def timestep_embedding(t, dim, max_period=10000): + half = dim // 2 + # freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + # device=t.device, dtype=t.dtype + # ) + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(dtype=t.dtype) + return embedding + + +def rmsnorm(x, eps=1e-6): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + +class PatchEmbed(nn.Module): + def __init__( + self, + img_size=256, + patch_size=4, + in_channels=3, + embed_dim=512, + norm_layer=None, + flatten=True, + bias=True, + strict_img_size=True, + dynamic_img_pad=True, + ): + super().__init__() + self.patch_size = patch_size + self.flatten = flatten + self.strict_img_size = strict_img_size + self.dynamic_img_pad = dynamic_img_pad + if img_size is not None: + self.img_size = img_size + self.grid_size = img_size // patch_size + self.num_patches = self.grid_size**2 + else: + self.img_size = None + self.grid_size = None + self.num_patches = None + + self.proj = nn.Conv2d(in_channels, embed_dim, patch_size, patch_size, bias=bias) + self.norm = nn.Identity() if norm_layer is None else norm_layer(embed_dim) + + def forward(self, x): + B, C, H, W = x.shape + + if self.dynamic_img_pad: + # Pad input so we won't have partial patch + pad_h = (self.patch_size - H % self.patch_size) % self.patch_size + pad_w = (self.patch_size - W % self.patch_size) % self.patch_size + x = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="reflect") + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + return x + + +# FinalLayer in mmdit.py +class UnPatch(nn.Module): + def __init__(self, hidden_size=512, patch_size=4, out_channels=3): + super().__init__() + self.patch_size = patch_size + self.c = out_channels + + # eps is default in mmdit.py + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size**2 * out_channels) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size), + ) + + def forward(self, x: torch.Tensor, cmod, H=None, W=None): + b, n, _ = x.shape + p = self.patch_size + c = self.c + if H is None and W is None: + w = h = int(n**0.5) + assert h * w == n + else: + h = H // p if H else n // (W // p) + w = W // p if W else n // h + assert h * w == n + + shift, scale = self.adaLN_modulation(cmod).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + + x = x.view(b, h, w, p, p, c) + x = x.permute(0, 5, 1, 3, 2, 4).contiguous() + x = x.view(b, c, h * p, w * p) + return x + + +class MLP(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=lambda: nn.GELU(), + norm_layer=None, + bias=True, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.use_conv = use_conv + + layer = partial(nn.Conv1d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = layer(in_features, hidden_features, bias=bias) + self.fc2 = layer(hidden_features, out_features, bias=bias) + self.act = act_layer() + self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.norm(x) + x = self.fc2(x) + return x + + +class TimestepEmbedding(nn.Module): + def __init__(self, hidden_size, freq_embed_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(freq_embed_size, hidden_size), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size), + ) + self.freq_embed_size = freq_embed_size + + def forward(self, t, dtype=None, **kwargs): + t_freq = timestep_embedding(t, self.freq_embed_size).to(dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class Embedder(nn.Module): + def __init__(self, input_dim, hidden_size): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(input_dim, hidden_size), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size), + ) + + def forward(self, x): + return self.mlp(x) + + +class RMSNorm(torch.nn.Module): + def __init__( + self, + dim: int, + elementwise_affine: bool = False, + eps: float = 1e-6, + device=None, + dtype=None, + ): + """ + Initialize the RMSNorm normalization layer. + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + """ + super().__init__() + self.eps = eps + self.learnable_scale = elementwise_affine + if self.learnable_scale: + self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) + else: + self.register_parameter("weight", None) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + Args: + x (torch.Tensor): The input tensor. + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + """ + x = rmsnorm(x, eps=self.eps) + if self.learnable_scale: + return x * self.weight.to(device=x.device, dtype=x.dtype) + else: + return x + + +class SwiGLUFeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: float = None, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) + + +# Linears for SelfAttention in mmdit.py +class AttentionLinears(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + pre_only: bool = False, + qk_norm: str = None, + ): + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + if not pre_only: + self.proj = nn.Linear(dim, dim) + self.pre_only = pre_only + + if qk_norm == "rms": + self.ln_q = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6) + self.ln_k = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6) + elif qk_norm == "ln": + self.ln_q = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6) + self.ln_k = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6) + elif qk_norm is None: + self.ln_q = nn.Identity() + self.ln_k = nn.Identity() + else: + raise ValueError(qk_norm) + + def pre_attention(self, x: torch.Tensor) -> torch.Tensor: + """ + output: + q, k, v: [B, L, D] + """ + B, L, C = x.shape + qkv: torch.Tensor = self.qkv(x) + q, k, v = qkv.reshape(B, L, -1, self.head_dim).chunk(3, dim=2) + q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1) + k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1) + return (q, k, v) + + def post_attention(self, x: torch.Tensor) -> torch.Tensor: + assert not self.pre_only + x = self.proj(x) + return x + + +MEMORY_LAYOUTS = { + "torch": ( + lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim).transpose(1, 2), + lambda x: x.transpose(1, 2).reshape(x.shape[0], x.shape[2], -1), + lambda x: (1, x, 1, 1), + ), + "xformers": ( + lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim), + lambda x: x.reshape(x.shape[0], x.shape[1], -1), + lambda x: (1, 1, x, 1), + ), + "math": ( + lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim).transpose(1, 2), + lambda x: x.transpose(1, 2).reshape(x.shape[0], x.shape[2], -1), + lambda x: (1, x, 1, 1), + ), +} +# ATTN_FUNCTION = { +# "torch": F.scaled_dot_product_attention, +# "xformers": memory_efficient_attention, +# } + + +def vanilla_attention(q, k, v, mask, scale=None): + if scale is None: + scale = math.sqrt(q.size(-1)) + scores = torch.bmm(q, k.transpose(-1, -2)) / scale + if mask is not None: + mask = einops.rearrange(mask, "b ... -> b (...)") + max_neg_value = -torch.finfo(scores.dtype).max + mask = einops.repeat(mask, "b j -> (b h) j", h=q.size(-3)) + scores = scores.masked_fill(~mask, max_neg_value) + p_attn = F.softmax(scores, dim=-1) + return torch.bmm(p_attn, v) + + +def attention(q, k, v, head_dim, mask=None, scale=None, mode="xformers"): + """ + q, k, v: [B, L, D] + """ + pre_attn_layout = MEMORY_LAYOUTS[mode][0] + post_attn_layout = MEMORY_LAYOUTS[mode][1] + q = pre_attn_layout(q, head_dim) + k = pre_attn_layout(k, head_dim) + v = pre_attn_layout(v, head_dim) + + # scores = ATTN_FUNCTION[mode](q, k.to(q), v.to(q), mask, scale=scale) + if mode == "torch": + assert scale is None + scores = F.scaled_dot_product_attention(q, k.to(q), v.to(q), mask) # , scale=scale) + elif mode == "xformers": + scores = memory_efficient_attention(q, k.to(q), v.to(q), mask, scale=scale) + else: + scores = vanilla_attention(q, k.to(q), v.to(q), mask, scale=scale) + + scores = post_attn_layout(scores) + return scores + + +class SelfAttention(AttentionLinears): + def __init__(self, dim, num_heads=8, mode="xformers"): + super().__init__(dim, num_heads, qkv_bias=True, pre_only=False) + assert mode in MEMORY_LAYOUTS + self.head_dim = dim // num_heads + self.attn_mode = mode + + def set_attn_mode(self, mode): + self.attn_mode = mode + + def forward(self, x): + q, k, v = self.pre_attention(x) + attn_score = attention(q, k, v, self.head_dim, mode=self.attn_mode) + return self.post_attention(attn_score) + + +class TransformerBlock(nn.Module): + def __init__(self, context_size, mode="xformers"): + super().__init__() + self.context_size = context_size + self.norm1 = nn.LayerNorm(context_size, elementwise_affine=False, eps=1e-6) + self.attn = SelfAttention(context_size, mode=mode) + self.norm2 = nn.LayerNorm(context_size, elementwise_affine=False, eps=1e-6) + self.mlp = MLP( + in_features=context_size, + hidden_features=context_size * 4, + act_layer=lambda: nn.GELU(approximate="tanh"), + ) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, context_size, num_layers, mode="xformers"): + super().__init__() + self.layers = nn.ModuleList([TransformerBlock(context_size, mode) for _ in range(num_layers)]) + self.norm = nn.LayerNorm(context_size, elementwise_affine=False, eps=1e-6) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return self.norm(x) + + +# DismantledBlock in mmdit.py +class SingleDiTBlock(nn.Module): + """ + A DiT block with gated adaptive layer norm (adaLN) conditioning. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: str = "xformers", + qkv_bias: bool = False, + pre_only: bool = False, + rmsnorm: bool = False, + scale_mod_only: bool = False, + swiglu: bool = False, + qk_norm: Optional[str] = None, + **block_kwargs, + ): + super().__init__() + assert attn_mode in MEMORY_LAYOUTS + self.attn_mode = attn_mode + if not rmsnorm: + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + else: + self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = AttentionLinears( + dim=hidden_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + pre_only=pre_only, + qk_norm=qk_norm, + ) + if not pre_only: + if not rmsnorm: + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + else: + self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + if not pre_only: + if not swiglu: + self.mlp = MLP( + in_features=hidden_size, + hidden_features=mlp_hidden_dim, + act_layer=lambda: nn.GELU(approximate="tanh"), + ) + else: + self.mlp = SwiGLUFeedForward( + dim=hidden_size, + hidden_dim=mlp_hidden_dim, + multiple_of=256, + ) + self.scale_mod_only = scale_mod_only + if not scale_mod_only: + n_mods = 6 if not pre_only else 2 + else: + n_mods = 4 if not pre_only else 1 + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, n_mods * hidden_size)) + self.pre_only = pre_only + + def pre_attention(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + if not self.pre_only: + if not self.scale_mod_only: + ( + shift_msa, + scale_msa, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + ) = self.adaLN_modulation( + c + ).chunk(6, dim=-1) + else: + shift_msa = None + shift_mlp = None + ( + scale_msa, + gate_msa, + scale_mlp, + gate_mlp, + ) = self.adaLN_modulation( + c + ).chunk(4, dim=-1) + qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) + return qkv, ( + x, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + ) + else: + if not self.scale_mod_only: + ( + shift_msa, + scale_msa, + ) = self.adaLN_modulation( + c + ).chunk(2, dim=-1) + else: + shift_msa = None + scale_msa = self.adaLN_modulation(c) + qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) + return qkv, None + + def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp): + assert not self.pre_only + x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +# JointBlock + block_mixing in mmdit.py +class MMDiTBlock(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + pre_only = kwargs.pop("pre_only") + self.context_block = SingleDiTBlock(*args, pre_only=pre_only, **kwargs) + self.x_block = SingleDiTBlock(*args, pre_only=False, **kwargs) + self.head_dim = self.x_block.attn.head_dim + self.mode = self.x_block.attn_mode + self.gradient_checkpointing = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def _forward(self, context, x, c): + ctx_qkv, ctx_intermediate = self.context_block.pre_attention(context, c) + x_qkv, x_intermediate = self.x_block.pre_attention(x, c) + + ctx_len = ctx_qkv[0].size(1) + + q = torch.concat((ctx_qkv[0], x_qkv[0]), dim=1) + k = torch.concat((ctx_qkv[1], x_qkv[1]), dim=1) + v = torch.concat((ctx_qkv[2], x_qkv[2]), dim=1) + + attn = attention(q, k, v, head_dim=self.head_dim, mode=self.mode) + ctx_attn_out = attn[:, :ctx_len] + x_attn_out = attn[:, ctx_len:] + + x = self.x_block.post_attention(x_attn_out, *x_intermediate) + if not self.context_block.pre_only: + context = self.context_block.post_attention(ctx_attn_out, *ctx_intermediate) + else: + context = None + return context, x + + def forward(self, *args, **kwargs): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + else: + return self._forward(*args, **kwargs) + + +class MMDiT(nn.Module): + """ + Diffusion model with a Transformer backbone. + """ + + def __init__( + self, + input_size: int = 32, + patch_size: int = 2, + in_channels: int = 4, + depth: int = 28, + # hidden_size: Optional[int] = None, + # num_heads: Optional[int] = None, + mlp_ratio: float = 4.0, + learn_sigma: bool = False, + adm_in_channels: Optional[int] = None, + context_embedder_config: Optional[Dict] = None, + use_checkpoint: bool = False, + register_length: int = 0, + attn_mode: str = "torch", + rmsnorm: bool = False, + scale_mod_only: bool = False, + swiglu: bool = False, + out_channels: Optional[int] = None, + pos_embed_scaling_factor: Optional[float] = None, + pos_embed_offset: Optional[float] = None, + pos_embed_max_size: Optional[int] = None, + num_patches=None, + qk_norm: Optional[str] = None, + qkv_bias: bool = True, + context_processor_layers=None, + context_size=4096, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + default_out_channels = in_channels * 2 if learn_sigma else in_channels + self.out_channels = default(out_channels, default_out_channels) + self.patch_size = patch_size + self.pos_embed_scaling_factor = pos_embed_scaling_factor + self.pos_embed_offset = pos_embed_offset + self.pos_embed_max_size = pos_embed_max_size + self.gradient_checkpointing = use_checkpoint + + # hidden_size = default(hidden_size, 64 * depth) + # num_heads = default(num_heads, hidden_size // 64) + + # apply magic --> this defines a head_size of 64 + self.hidden_size = 64 * depth + num_heads = depth + + self.num_heads = num_heads + + self.x_embedder = PatchEmbed( + input_size, + patch_size, + in_channels, + self.hidden_size, + bias=True, + strict_img_size=self.pos_embed_max_size is None, + ) + self.t_embedder = TimestepEmbedding(self.hidden_size) + + self.y_embedder = None + if adm_in_channels is not None: + assert isinstance(adm_in_channels, int) + self.y_embedder = Embedder(adm_in_channels, self.hidden_size) + + if context_processor_layers is not None: + self.context_processor = Transformer(context_size, context_processor_layers, attn_mode) + else: + self.context_processor = None + + self.context_embedder = nn.Linear(context_size, self.hidden_size) + self.register_length = register_length + if self.register_length > 0: + self.register = nn.Parameter(torch.randn(1, register_length, self.hidden_size)) + + # num_patches = self.x_embedder.num_patches + # Will use fixed sin-cos embedding: + # just use a buffer already + if num_patches is not None: + self.register_buffer( + "pos_embed", + torch.empty(1, num_patches, self.hidden_size), + ) + else: + self.pos_embed = None + + self.use_checkpoint = use_checkpoint + self.joint_blocks = nn.ModuleList( + [ + MMDiTBlock( + self.hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + qkv_bias=qkv_bias, + pre_only=i == depth - 1, + rmsnorm=rmsnorm, + scale_mod_only=scale_mod_only, + swiglu=swiglu, + qk_norm=qk_norm, + ) + for i in range(depth) + ] + ) + for block in self.joint_blocks: + block.gradient_checkpointing = use_checkpoint + + self.final_layer = UnPatch(self.hidden_size, patch_size, self.out_channels) + # self.initialize_weights() + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + for block in self.joint_blocks: + block.enable_gradient_checkpointing() + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + for block in self.joint_blocks: + block.disable_gradient_checkpointing() + + def initialize_weights(self): + # TODO: Init context_embedder? + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize (and freeze) pos_embed by sin-cos embedding + if self.pos_embed is not None: + pos_embed = get_2d_sincos_pos_embed( + self.pos_embed.shape[-1], + int(self.pos_embed.shape[-2] ** 0.5), + scaling_factor=self.pos_embed_scaling_factor, + ) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d) + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + if getattr(self, "y_embedder", None) is not None: + nn.init.normal_(self.y_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.y_embedder.mlp[2].weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + for block in self.joint_blocks: + nn.init.constant_(block.x_block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.x_block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(block.context_block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.context_block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def cropped_pos_embed(self, h, w, device=None): + p = self.x_embedder.patch_size + # patched size + h = (h + 1) // p + w = (w + 1) // p + if self.pos_embed is None: + return get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=device) + assert self.pos_embed_max_size is not None + assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size) + assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size) + top = (self.pos_embed_max_size - h) // 2 + left = (self.pos_embed_max_size - w) // 2 + spatial_pos_embed = self.pos_embed.reshape( + 1, + self.pos_embed_max_size, + self.pos_embed_max_size, + self.pos_embed.shape[-1], + ) + spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :] + spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) + return spatial_pos_embed + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, + y: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass of DiT. + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N, D) tensor of class labels + """ + + if self.context_processor is not None: + context = self.context_processor(context) + + B, C, H, W = x.shape + x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device).to(dtype=x.dtype) + c = self.t_embedder(t, dtype=x.dtype) # (N, D) + if y is not None and self.y_embedder is not None: + y = self.y_embedder(y) # (N, D) + c = c + y # (N, D) + + if context is not None: + context = self.context_embedder(context) + + if self.register_length > 0: + context = torch.cat( + ( + einops.repeat(self.register, "1 ... -> b ...", b=x.shape[0]), + default(context, torch.Tensor([]).type_as(x)), + ), + 1, + ) + + for block in self.joint_blocks: + context, x = block(context, x, c) + x = self.final_layer(x, c, H, W) # Our final layer combined UnPatchify + return x[:, :, :H, :W] + + +def create_mmdit_sd3_medium_configs(attn_mode: str): + # {'patch_size': 2, 'depth': 24, 'num_patches': 36864, + # 'pos_embed_max_size': 192, 'adm_in_channels': 2048, 'context_embedder': + # {'target': 'torch.nn.Linear', 'params': {'in_features': 4096, 'out_features': 1536}}} + mmdit = MMDiT( + input_size=None, + pos_embed_max_size=192, + patch_size=2, + in_channels=16, + adm_in_channels=2048, + depth=24, + mlp_ratio=4, + qk_norm=None, + num_patches=36864, + context_size=4096, + attn_mode=attn_mode, + ) + return mmdit + + +# endregion + +# region VAE + + +def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) + + +class ResnetBlock(torch.nn.Module): + def __init__(self, *, in_channels, out_channels=None, dtype=torch.float32, device=None): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = Normalize(in_channels, dtype=dtype, device=device) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + self.norm2 = Normalize(out_channels, dtype=dtype, device=device) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + if self.in_channels != self.out_channels: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device + ) + else: + self.nin_shortcut = None + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, x): + hidden = x + hidden = self.norm1(hidden) + hidden = self.swish(hidden) + hidden = self.conv1(hidden) + hidden = self.norm2(hidden) + hidden = self.swish(hidden) + hidden = self.conv2(hidden) + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + return x + hidden + + +class AttnBlock(torch.nn.Module): + def __init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.norm = Normalize(in_channels, dtype=dtype, device=device) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + + def forward(self, x): + hidden = self.norm(x) + q = self.q(hidden) + k = self.k(hidden) + v = self.v(hidden) + b, c, h, w = q.shape + q, k, v = map(lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)) + hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default + hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + hidden = self.proj_out(hidden) + return x + hidden + + +class Downsample(torch.nn.Module): + def __init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0, dtype=dtype, device=device) + + def forward(self, x): + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(torch.nn.Module): + def __init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class VAEEncoder(torch.nn.Module): + def __init__( + self, ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2, in_channels=3, z_channels=16, dtype=torch.float32, device=None + ): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = torch.nn.ModuleList() + for i_level in range(self.num_resolutions): + block = torch.nn.ModuleList() + attn = torch.nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device)) + block_in = block_out + down = torch.nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, dtype=dtype, device=device) + self.down.append(down) + # middle + self.mid = torch.nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + # end + self.norm_out = Normalize(block_in, dtype=dtype, device=device) + self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, x): + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = self.swish(h) + h = self.conv_out(h) + return h + + +class VAEDecoder(torch.nn.Module): + def __init__( + self, + ch=128, + out_ch=3, + ch_mult=(1, 2, 4, 4), + num_res_blocks=2, + resolution=256, + z_channels=16, + dtype=torch.float32, + device=None, + ): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + # middle + self.mid = torch.nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + # upsampling + self.up = torch.nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = torch.nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device)) + block_in = block_out + up = torch.nn.Module() + up.block = block + if i_level != 0: + up.upsample = Upsample(block_in, dtype=dtype, device=device) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + # end + self.norm_out = Normalize(block_in, dtype=dtype, device=device) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, z): + # z to block_in + hidden = self.conv_in(z) + # middle + hidden = self.mid.block_1(hidden) + hidden = self.mid.attn_1(hidden) + hidden = self.mid.block_2(hidden) + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + hidden = self.up[i_level].block[i_block](hidden) + if i_level != 0: + hidden = self.up[i_level].upsample(hidden) + # end + hidden = self.norm_out(hidden) + hidden = self.swish(hidden) + hidden = self.conv_out(hidden) + return hidden + + +class SDVAE(torch.nn.Module): + def __init__(self, dtype=torch.float32, device=None): + super().__init__() + self.encoder = VAEEncoder(dtype=dtype, device=device) + self.decoder = VAEDecoder(dtype=dtype, device=device) + + @torch.autocast("cuda", dtype=torch.float16) + def decode(self, latent): + return self.decoder(latent) + + @torch.autocast("cuda", dtype=torch.float16) + def encode(self, image): + hidden = self.encoder(image) + mean, logvar = torch.chunk(hidden, 2, dim=1) + logvar = torch.clamp(logvar, -30.0, 20.0) + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + + +# endregion + + +# region Text Encoder +class CLIPAttention(torch.nn.Module): + def __init__(self, embed_dim, heads, dtype, device, mode="xformers"): + super().__init__() + self.heads = heads + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.attn_mode = mode + + def set_attn_mode(self, mode): + self.attn_mode = mode + + def forward(self, x, mask=None): + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + out = attention(q, k, v, self.heads, mask, mode=self.attn_mode) + return self.out_proj(out) + + +ACTIVATIONS = { + "quick_gelu": lambda: (lambda a: a * torch.sigmoid(1.702 * a)), + # "gelu": torch.nn.functional.gelu, + "gelu": lambda: nn.GELU(), +} + + +class CLIPLayer(torch.nn.Module): + def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device): + super().__init__() + self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + self.self_attn = CLIPAttention(embed_dim, heads, dtype, device) + self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + # # self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device) + # self.mlp = Mlp( + # embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation], dtype=dtype, device=device + # ) + self.mlp = MLP(embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation]) + self.mlp.to(device=device, dtype=dtype) + + def forward(self, x, mask=None): + x += self.self_attn(self.layer_norm1(x), mask) + x += self.mlp(self.layer_norm2(x)) + return x + + +class CLIPEncoder(torch.nn.Module): + def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device): + super().__init__() + self.layers = torch.nn.ModuleList( + [CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) for i in range(num_layers)] + ) + + def forward(self, x, mask=None, intermediate_output=None): + if intermediate_output is not None: + if intermediate_output < 0: + intermediate_output = len(self.layers) + intermediate_output + intermediate = None + for i, l in enumerate(self.layers): + x = l(x, mask) + if i == intermediate_output: + intermediate = x.clone() + return x, intermediate + + +class CLIPEmbeddings(torch.nn.Module): + def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None): + super().__init__() + self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device) + self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device) + + def forward(self, input_tokens): + return self.token_embedding(input_tokens) + self.position_embedding.weight + + +class CLIPTextModel_(torch.nn.Module): + def __init__(self, config_dict, dtype, device): + num_layers = config_dict["num_hidden_layers"] + embed_dim = config_dict["hidden_size"] + heads = config_dict["num_attention_heads"] + intermediate_size = config_dict["intermediate_size"] + intermediate_activation = config_dict["hidden_act"] + super().__init__() + self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device) + self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) + self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + + def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True): + x = self.embeddings(input_tokens) + + if x.dtype == torch.bfloat16: + causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=torch.float32, device=x.device).fill_(float("-inf")).triu_(1) + causal_mask = causal_mask.to(dtype=x.dtype) + else: + causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) + + x, i = self.encoder(x, mask=causal_mask, intermediate_output=intermediate_output) + x = self.final_layer_norm(x) + if i is not None and final_layer_norm_intermediate: + i = self.final_layer_norm(i) + pooled_output = x[ + torch.arange(x.shape[0], device=x.device), + input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1), + ] + return x, i, pooled_output + + +class CLIPTextModel(torch.nn.Module): + def __init__(self, config_dict, dtype, device): + super().__init__() + self.num_layers = config_dict["num_hidden_layers"] + self.text_model = CLIPTextModel_(config_dict, dtype, device) + embed_dim = config_dict["hidden_size"] + self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device) + self.text_projection.weight.copy_(torch.eye(embed_dim)) + self.dtype = dtype + + def get_input_embeddings(self): + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, embeddings): + self.text_model.embeddings.token_embedding = embeddings + + def forward(self, *args, **kwargs): + x = self.text_model(*args, **kwargs) + out = self.text_projection(x[2]) + return (x[0], x[1], out, x[2]) + + +class ClipTokenWeightEncoder: + def encode_token_weights(self, token_weight_pairs): + tokens = list(map(lambda a: a[0], token_weight_pairs[0])) + out, pooled = self([tokens]) + if pooled is not None: + first_pooled = pooled[0:1].cpu() + else: + first_pooled = pooled + output = [out[0:1]] + return torch.cat(output, dim=-2).cpu(), first_pooled + + +class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): + """Uses the CLIP transformer encoder for text (from huggingface)""" + + LAYERS = ["last", "pooled", "hidden"] + + def __init__( + self, + device="cpu", + max_length=77, + layer="last", + layer_idx=None, + textmodel_json_config=None, + dtype=None, + model_class=CLIPTextModel, + special_tokens={"start": 49406, "end": 49407, "pad": 49407}, + layer_norm_hidden_state=True, + return_projected_pooled=True, + ): + super().__init__() + assert layer in self.LAYERS + self.transformer = model_class(textmodel_json_config, dtype, device) + self.num_layers = self.transformer.num_layers + self.max_length = max_length + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + self.layer = layer + self.layer_idx = None + self.special_tokens = special_tokens + self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) + self.layer_norm_hidden_state = layer_norm_hidden_state + self.return_projected_pooled = return_projected_pooled + if layer == "hidden": + assert layer_idx is not None + assert abs(layer_idx) < self.num_layers + self.set_clip_options({"layer": layer_idx}) + self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled) + + def set_attn_mode(self, mode): + raise NotImplementedError("This model does not support setting the attention mode") + + def set_clip_options(self, options): + layer_idx = options.get("layer", self.layer_idx) + self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) + if layer_idx is None or abs(layer_idx) > self.num_layers: + self.layer = "last" + else: + self.layer = "hidden" + self.layer_idx = layer_idx + + def forward(self, tokens): + backup_embeds = self.transformer.get_input_embeddings() + device = backup_embeds.weight.device + tokens = torch.LongTensor(tokens).to(device) + outputs = self.transformer( + tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state + ) + self.transformer.set_input_embeddings(backup_embeds) + if self.layer == "last": + z = outputs[0] + else: + z = outputs[1] + pooled_output = None + if len(outputs) >= 3: + if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None: + pooled_output = outputs[3].float() + elif outputs[2] is not None: + pooled_output = outputs[2].float() + return z.float(), pooled_output + + def set_attn_mode(self, mode): + clip_text_model = self.transformer.text_model + for layer in clip_text_model.encoder.layers: + layer.self_attn.set_attn_mode(mode) + + +class SDXLClipG(SDClipModel): + """Wraps the CLIP-G model into the SD-CLIP-Model interface""" + + def __init__(self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None): + if layer == "penultimate": + layer = "hidden" + layer_idx = -2 + super().__init__( + device=device, + layer=layer, + layer_idx=layer_idx, + textmodel_json_config=config, + dtype=dtype, + special_tokens={"start": 49406, "end": 49407, "pad": 0}, + layer_norm_hidden_state=False, + ) + + def set_attn_mode(self, mode): + clip_text_model = self.transformer.text_model + for layer in clip_text_model.encoder.layers: + layer.self_attn.set_attn_mode(mode) + + +class T5XXLModel(SDClipModel): + """Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience""" + + def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None): + super().__init__( + device=device, + layer=layer, + layer_idx=layer_idx, + textmodel_json_config=config, + dtype=dtype, + special_tokens={"end": 1, "pad": 0}, + model_class=T5, + ) + + def set_attn_mode(self, mode): + t5: T5 = self.transformer + for t5block in t5.encoder.block: + t5block: T5Block + t5layer: T5LayerSelfAttention = t5block.layer[0] + t5SaSa: T5Attention = t5layer.SelfAttention + t5SaSa.set_attn_mode(mode) + + +################################################################################################# +### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl +################################################################################################# + + +class T5XXLTokenizer(SDTokenizer): + """Wraps the T5 Tokenizer from HF into the SDTokenizer interface""" + + def __init__(self): + super().__init__( + pad_with_end=False, + tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"), + has_start_token=False, + pad_to_max_length=False, + max_length=99999999, + min_length=77, + ) + + +class T5LayerNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device)) + self.variance_epsilon = eps + + def forward(self, x): + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + return self.weight.to(device=x.device, dtype=x.dtype) * x + + +class T5DenseGatedActDense(torch.nn.Module): + def __init__(self, model_dim, ff_dim, dtype, device): + super().__init__() + self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) + self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) + self.wo = nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device) + + def forward(self, x): + hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh") + hidden_linear = self.wi_1(x) + x = hidden_gelu * hidden_linear + x = self.wo(x) + return x + + +class T5LayerFF(torch.nn.Module): + def __init__(self, model_dim, ff_dim, dtype, device): + super().__init__() + self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device) + self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) + + def forward(self, x): + forwarded_states = self.layer_norm(x) + forwarded_states = self.DenseReluDense(forwarded_states) + x += forwarded_states + return x + + +class T5Attention(torch.nn.Module): + def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device): + super().__init__() + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.k = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.v = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.o = nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device) + self.num_heads = num_heads + self.relative_attention_bias = None + if relative_attention_bias: + self.relative_attention_num_buckets = 32 + self.relative_attention_max_distance = 128 + self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device) + + self.attn_mode = "xformers" # TODO 何とかする + + def set_attn_mode(self, mode): + self.attn_mode = mode + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length, device): + """Compute binned relative position bias""" + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=True, + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + def forward(self, x, past_bias=None): + q = self.q(x) + k = self.k(x) + v = self.v(x) + if self.relative_attention_bias is not None: + past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device) + if past_bias is not None: + mask = past_bias + out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask, mode=self.attn_mode) + return self.o(out), past_bias + + +class T5LayerSelfAttention(torch.nn.Module): + def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device): + super().__init__() + self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device) + self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) + + def forward(self, x, past_bias=None): + output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias) + x += output + return x, past_bias + + +class T5Block(torch.nn.Module): + def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device): + super().__init__() + self.layer = torch.nn.ModuleList() + self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device)) + self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device)) + + def forward(self, x, past_bias=None): + x, past_bias = self.layer[0](x, past_bias) + x = self.layer[-1](x) + return x, past_bias + + +class T5Stack(torch.nn.Module): + def __init__(self, num_layers, model_dim, inner_dim, ff_dim, num_heads, vocab_size, dtype, device): + super().__init__() + self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device) + self.block = torch.nn.ModuleList( + [ + T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) + for i in range(num_layers) + ] + ) + self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) + + def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True): + intermediate = None + x = self.embed_tokens(input_ids) + past_bias = None + for i, l in enumerate(self.block): + # print(i, x.mean(), x.std()) + x, past_bias = l(x, past_bias) + if i == intermediate_output: + intermediate = x.clone() + # print(x.mean(), x.std()) + x = self.final_layer_norm(x) + if intermediate is not None and final_layer_norm_intermediate: + intermediate = self.final_layer_norm(intermediate) + # print(x.mean(), x.std()) + return x, intermediate + + +class T5(torch.nn.Module): + def __init__(self, config_dict, dtype, device): + super().__init__() + self.num_layers = config_dict["num_layers"] + self.encoder = T5Stack( + self.num_layers, + config_dict["d_model"], + config_dict["d_model"], + config_dict["d_ff"], + config_dict["num_heads"], + config_dict["vocab_size"], + dtype, + device, + ) + self.dtype = dtype + + def get_input_embeddings(self): + return self.encoder.embed_tokens + + def set_input_embeddings(self, embeddings): + self.encoder.embed_tokens = embeddings + + def forward(self, *args, **kwargs): + return self.encoder(*args, **kwargs) + + +def create_clip_l(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[str, torch.Tensor]] = None): + r""" + state_dict is not loaded, but updated with missing keys + """ + CLIPL_CONFIG = { + "hidden_act": "quick_gelu", + "hidden_size": 768, + "intermediate_size": 3072, + "num_attention_heads": 12, + "num_hidden_layers": 12, + } + with torch.no_grad(): + clip_l = SDClipModel( + layer="hidden", + layer_idx=-2, + device=device, + dtype=dtype, + layer_norm_hidden_state=False, + return_projected_pooled=False, + textmodel_json_config=CLIPL_CONFIG, + ) + if state_dict is not None: + # update state_dict if provided to include logit_scale and text_projection.weight avoid errors + if "logit_scale" not in state_dict: + state_dict["logit_scale"] = clip_l.logit_scale + if "transformer.text_projection.weight" not in state_dict: + state_dict["transformer.text_projection.weight"] = clip_l.transformer.text_projection.weight + return clip_l + + +def create_clip_g(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[str, torch.Tensor]] = None): + r""" + state_dict is not loaded, but updated with missing keys + """ + CLIPG_CONFIG = { + "hidden_act": "gelu", + "hidden_size": 1280, + "intermediate_size": 5120, + "num_attention_heads": 20, + "num_hidden_layers": 32, + } + with torch.no_grad(): + clip_g = SDXLClipG(CLIPG_CONFIG, device=device, dtype=dtype) + if state_dict is not None: + if "logit_scale" not in state_dict: + state_dict["logit_scale"] = clip_g.logit_scale + return clip_g + + +def create_t5xxl(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> T5XXLModel: + T5_CONFIG = {"d_ff": 10240, "d_model": 4096, "num_heads": 64, "num_layers": 24, "vocab_size": 32128} + with torch.no_grad(): + t5 = T5XXLModel(T5_CONFIG, dtype=dtype, device=device) + if state_dict is not None: + if "logit_scale" not in state_dict: + state_dict["logit_scale"] = t5.logit_scale + if "transformer.shared.weight" in state_dict: + state_dict.pop("transformer.shared.weight") + return t5 + + +# endregion diff --git a/library/sd3_utils.py b/library/sd3_utils.py new file mode 100644 index 000000000..6f8c361fd --- /dev/null +++ b/library/sd3_utils.py @@ -0,0 +1,113 @@ +import math +from typing import Dict +import torch + +from library import sd3_models + + +def get_cond( + prompt: str, + tokenizer: sd3_models.SD3Tokenizer, + clip_l: sd3_models.SDClipModel, + clip_g: sd3_models.SDXLClipG, + t5xxl: sd3_models.T5XXLModel, +): + l_tokens, g_tokens, t5_tokens = tokenizer.tokenize_with_weights(prompt) + l_out, l_pooled = clip_l.encode_token_weights(l_tokens) + g_out, g_pooled = clip_g.encode_token_weights(g_tokens) + lg_out = torch.cat([l_out, g_out], dim=-1) + lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) + + if t5_tokens is None: + t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device) + else: + t5_out, t5_pooled = t5xxl.encode_token_weights(t5_tokens) # t5_out is [1, 77, 4096], t5_pooled is None + t5_out = t5_out.to(lg_out.dtype) + + return torch.cat([lg_out, t5_out], dim=-2), torch.cat((l_pooled, g_pooled), dim=-1) + + +# used if other sd3 models is available +r""" +def get_sd3_configs(state_dict: Dict): + # Important configuration values can be quickly determined by checking shapes in the source file + # Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change) + # prefix = "model.diffusion_model." + prefix = "" + + patch_size = state_dict[prefix + "x_embedder.proj.weight"].shape[2] + depth = state_dict[prefix + "x_embedder.proj.weight"].shape[0] // 64 + num_patches = state_dict[prefix + "pos_embed"].shape[1] + pos_embed_max_size = round(math.sqrt(num_patches)) + adm_in_channels = state_dict[prefix + "y_embedder.mlp.0.weight"].shape[1] + context_shape = state_dict[prefix + "context_embedder.weight"].shape + context_embedder_config = { + "target": "torch.nn.Linear", + "params": {"in_features": context_shape[1], "out_features": context_shape[0]}, + } + return { + "patch_size": patch_size, + "depth": depth, + "num_patches": num_patches, + "pos_embed_max_size": pos_embed_max_size, + "adm_in_channels": adm_in_channels, + "context_embedder": context_embedder_config, + } + + +def create_mmdit_from_sd3_checkpoint(state_dict: Dict, attn_mode: str = "xformers"): + "" + Doesn't load state dict. + "" + sd3_configs = get_sd3_configs(state_dict) + + mmdit = sd3_models.MMDiT( + input_size=None, + pos_embed_max_size=sd3_configs["pos_embed_max_size"], + patch_size=sd3_configs["patch_size"], + in_channels=16, + adm_in_channels=sd3_configs["adm_in_channels"], + depth=sd3_configs["depth"], + mlp_ratio=4, + qk_norm=None, + num_patches=sd3_configs["num_patches"], + context_size=4096, + attn_mode=attn_mode, + ) + return mmdit +""" + + +class ModelSamplingDiscreteFlow: + """Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models""" + + def __init__(self, shift=1.0): + self.shift = shift + timesteps = 1000 + self.sigmas = self.sigma(torch.arange(1, timesteps + 1, 1)) + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def timestep(self, sigma): + return sigma * 1000 + + def sigma(self, timestep: torch.Tensor): + timestep = timestep / 1000.0 + if self.shift == 1.0: + return timestep + return self.shift * timestep / (1 + (self.shift - 1) * timestep) + + def calculate_denoised(self, sigma, model_output, model_input): + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + return model_input - model_output * sigma + + def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): + # assert max_denoise is False, "max_denoise not implemented" + # max_denoise is always True, I'm not sure why it's there + return sigma * noise + (1.0 - sigma) * latent_image diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py new file mode 100644 index 000000000..e14f784d4 --- /dev/null +++ b/sd3_minimal_inference.py @@ -0,0 +1,347 @@ +# Minimum Inference Code for SD3 + +import argparse +import datetime +import math +import os +import random +from typing import Optional, Tuple +import numpy as np + +import torch +from safetensors.torch import safe_open, load_file +from tqdm import tqdm +from PIL import Image + +from library.device_utils import init_ipex, get_preferred_device + +init_ipex() + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from library import sd3_models, sd3_utils + + +def get_noise(seed, latent): + generator = torch.manual_seed(seed) + return torch.randn(latent.size(), dtype=torch.float32, layout=latent.layout, generator=generator, device="cpu").to(latent.dtype) + + +def get_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps): + start = sampling.timestep(sampling.sigma_max) + end = sampling.timestep(sampling.sigma_min) + timesteps = torch.linspace(start, end, steps) + sigs = [] + for x in range(len(timesteps)): + ts = timesteps[x] + sigs.append(sampling.sigma(ts)) + sigs += [0.0] + return torch.FloatTensor(sigs) + + +def max_denoise(model_sampling, sigmas): + max_sigma = float(model_sampling.sigma_max) + sigma = float(sigmas[0]) + return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma + + +def do_sample( + height: int, + width: int, + initial_latent: Optional[torch.Tensor], + seed: int, + cond: Tuple[torch.Tensor, torch.Tensor], + neg_cond: Tuple[torch.Tensor, torch.Tensor], + mmdit: sd3_models.MMDiT, + steps: int, + guidance_scale: float, + dtype: torch.dtype, + device: str, +): + if initial_latent is None: + latent = torch.ones(1, 16, height // 8, width // 8, device=device) * 0.0609 + else: + latent = initial_latent + + latent = latent.to(dtype).to(device) + + noise = get_noise(seed, latent).to(device) + + model_sampling = sd3_utils.ModelSamplingDiscreteFlow() + + sigmas = get_sigmas(model_sampling, steps).to(device) + # sigmas = sigmas[int(steps * (1 - denoise)) :] # do not support i2i + + # conditioning = fix_cond(conditioning) + # neg_cond = fix_cond(neg_cond) + # extra_args = {"cond": cond, "uncond": neg_cond, "cond_scale": guidance_scale} + + noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas)) + + c_crossattn = torch.cat([cond[0], neg_cond[0]]).to(device).to(dtype) + y = torch.cat([cond[1], neg_cond[1]]).to(device).to(dtype) + + x = noise_scaled.to(device).to(dtype) + # print(x.shape) + + with torch.no_grad(): + for i in tqdm(range(len(sigmas) - 1)): + sigma_hat = sigmas[i] + + timestep = model_sampling.timestep(sigma_hat).float() + timestep = torch.FloatTensor([timestep, timestep]).to(device) + + x_c_nc = torch.cat([x, x], dim=0) + # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape) + + model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) + model_output = model_output.float() + batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) + + pos_out, neg_out = batched.chunk(2) + denoised = neg_out + (pos_out - neg_out) * guidance_scale + # print(denoised.shape) + + # d = to_d(x, sigma_hat, denoised) + dims_to_append = x.ndim - sigma_hat.ndim + sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append] + # print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape) + """Converts a denoiser output to a Karras ODE derivative.""" + d = (x - denoised) / sigma_hat_dims + + dt = sigmas[i + 1] - sigma_hat + + # Euler method + x = x + d * dt + x = x.to(dtype) + + latent = x + scale_factor = 1.5305 + shift_factor = 0.0609 + # def process_out(self, latent): + # return (latent / self.scale_factor) + self.shift_factor + latent = (latent / scale_factor) + shift_factor + return latent + + +if __name__ == "__main__": + target_height = 1024 + target_width = 1024 + + # steps = 50 # 28 # 50 + guidance_scale = 5 + # seed = 1 # None # 1 + + device = get_preferred_device() + + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt_path", type=str, required=True) + parser.add_argument("--clip_g", type=str, required=False) + parser.add_argument("--clip_l", type=str, required=False) + parser.add_argument("--t5xxl", type=str, required=False) + parser.add_argument("--prompt", type=str, default="A photo of a cat") + # parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders + parser.add_argument("--negative_prompt", type=str, default="") + parser.add_argument("--output_dir", type=str, default=".") + parser.add_argument("--do_not_use_t5xxl", action="store_true") + parser.add_argument("--attn_mode", type=str, default="torch", help="torch (SDPA) or xformers. default: torch") + parser.add_argument("--fp16", action="store_true") + parser.add_argument("--bf16", action="store_true") + parser.add_argument("--seed", type=int, default=1) + parser.add_argument("--steps", type=int, default=50) + # parser.add_argument( + # "--lora_weights", + # type=str, + # nargs="*", + # default=[], + # help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)", + # ) + # parser.add_argument("--interactive", action="store_true") + args = parser.parse_args() + + seed = args.seed + steps = args.steps + + sd3_dtype = torch.float32 + if args.fp16: + sd3_dtype = torch.float16 + elif args.bf16: + sd3_dtype = torch.bfloat16 + + # TODO test with separated safetenors files for each model + + # load state dict + logger.info(f"Loading SD3 models from {args.ckpt_path}...") + state_dict = load_file(args.ckpt_path) + + if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_g: remove prefix "text_encoders.clip_g." + logger.info("clip_g is included in the checkpoint") + clip_g_sd = {} + prefix = "text_encoders.clip_g." + for k, v in list(state_dict.items()): + if k.startswith(prefix): + clip_g_sd[k[len(prefix) :]] = state_dict.pop(k) + else: + logger.info(f"Lodaing clip_g from {args.clip_g}...") + clip_g_sd = load_file(args.clip_g) + for key in list(clip_g_sd.keys()): + clip_g_sd["transformer." + key] = clip_g_sd.pop(key) + + if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_l: remove prefix "text_encoders.clip_l." + logger.info("clip_l is included in the checkpoint") + clip_l_sd = {} + prefix = "text_encoders.clip_l." + for k, v in list(state_dict.items()): + if k.startswith(prefix): + clip_l_sd[k[len(prefix) :]] = state_dict.pop(k) + else: + logger.info(f"Lodaing clip_l from {args.clip_l}...") + clip_l_sd = load_file(args.clip_l) + for key in list(clip_l_sd.keys()): + clip_l_sd["transformer." + key] = clip_l_sd.pop(key) + + if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict: + # found t5xxl: remove prefix "text_encoders.t5xxl." + logger.info("t5xxl is included in the checkpoint") + if not args.do_not_use_t5xxl: + t5xxl_sd = {} + prefix = "text_encoders.t5xxl." + for k, v in list(state_dict.items()): + if k.startswith(prefix): + t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k) + else: + logger.info("but not used") + for key in list(state_dict.keys()): + if key.startswith("text_encoders.t5xxl."): + state_dict.pop(key) + t5xxl_sd = None + elif args.t5xxl: + assert not args.do_not_use_t5xxl, "t5xxl is not used but specified" + logger.info(f"Lodaing t5xxl from {args.t5xxl}...") + t5xxl_sd = load_file(args.t5xxl) + for key in list(t5xxl_sd.keys()): + t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key) + else: + logger.info("t5xxl is not used") + t5xxl_sd = None + + use_t5xxl = t5xxl_sd is not None + + # MMDiT and VAE + vae_sd = {} + vae_prefix = "first_stage_model." + mmdit_prefix = "model.diffusion_model." + for k, v in list(state_dict.items()): + if k.startswith(vae_prefix): + vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k) + elif k.startswith(mmdit_prefix): + state_dict[k[len(mmdit_prefix) :]] = state_dict.pop(k) + + # load tokenizers + logger.info("Loading tokenizers...") + tokenizer = sd3_models.SD3Tokenizer(use_t5xxl) # combined tokenizer + + # load models + # logger.info("Create MMDiT from SD3 checkpoint...") + # mmdit = sd3_utils.create_mmdit_from_sd3_checkpoint(state_dict) + logger.info("Create MMDiT") + mmdit = sd3_models.create_mmdit_sd3_medium_configs(args.attn_mode) + + logger.info("Loading state dict...") + info = mmdit.load_state_dict(state_dict) + logger.info(f"Loaded MMDiT: {info}") + + logger.info(f"Move MMDiT to {device} and {sd3_dtype}...") + mmdit.to(device, dtype=sd3_dtype) + mmdit.eval() + + # load VAE + logger.info("Create VAE") + vae = sd3_models.SDVAE() + logger.info("Loading state dict...") + info = vae.load_state_dict(vae_sd) + logger.info(f"Loaded VAE: {info}") + + logger.info(f"Move VAE to {device} and {sd3_dtype}...") + vae.to(device, dtype=sd3_dtype) + vae.eval() + + # load text encoders + logger.info("Create clip_l") + clip_l = sd3_models.create_clip_l(device, sd3_dtype, clip_l_sd) + + logger.info("Loading state dict...") + info = clip_l.load_state_dict(clip_l_sd) + logger.info(f"Loaded clip_l: {info}") + + logger.info(f"Move clip_l to {device} and {sd3_dtype}...") + clip_l.to(device, dtype=sd3_dtype) + clip_l.eval() + logger.info(f"Set attn_mode to {args.attn_mode}...") + clip_l.set_attn_mode(args.attn_mode) + + logger.info("Create clip_g") + clip_g = sd3_models.create_clip_g(device, sd3_dtype, clip_g_sd) + + logger.info("Loading state dict...") + info = clip_g.load_state_dict(clip_g_sd) + logger.info(f"Loaded clip_g: {info}") + + logger.info(f"Move clip_g to {device} and {sd3_dtype}...") + clip_g.to(device, dtype=sd3_dtype) + clip_g.eval() + logger.info(f"Set attn_mode to {args.attn_mode}...") + clip_g.set_attn_mode(args.attn_mode) + + if use_t5xxl: + logger.info("Create t5xxl") + t5xxl = sd3_models.create_t5xxl(device, sd3_dtype, t5xxl_sd) + + logger.info("Loading state dict...") + info = t5xxl.load_state_dict(t5xxl_sd) + logger.info(f"Loaded t5xxl: {info}") + + logger.info(f"Move t5xxl to {device} and {sd3_dtype}...") + t5xxl.to(device, dtype=sd3_dtype) + # t5xxl.to("cpu", dtype=torch.float32) # run on CPU + t5xxl.eval() + logger.info(f"Set attn_mode to {args.attn_mode}...") + t5xxl.set_attn_mode(args.attn_mode) + else: + t5xxl = None + + # prepare embeddings + logger.info("Encoding prompts...") + # embeds, pooled_embed + cond = sd3_utils.get_cond(args.prompt, tokenizer, clip_l, clip_g, t5xxl) + neg_cond = sd3_utils.get_cond(args.negative_prompt, tokenizer, clip_l, clip_g, t5xxl) + + # generate image + logger.info("Generating image...") + latent_sampled = do_sample( + target_height, target_width, None, seed, cond, neg_cond, mmdit, steps, guidance_scale, sd3_dtype, device + ) + + # latent to image + with torch.no_grad(): + image = vae.decode(latent_sampled) + image = image.float() + image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] + decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2) + decoded_np = decoded_np.astype(np.uint8) + out_image = Image.fromarray(decoded_np) + + # save image + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png") + out_image.save(output_path) + + logger.info(f"Saved image to {output_path}") From d53ea22b2a8366e6bc9f14aaeec057cd817f60d3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 23 Jun 2024 23:38:20 +0900 Subject: [PATCH 028/348] sd3 training --- README.md | 25 + library/sai_model_spec.py | 20 +- library/sd3_models.py | 102 ++++- library/sd3_train_utils.py | 544 ++++++++++++++++++++++ library/sd3_utils.py | 211 ++++++++- library/train_util.py | 137 +++++- sd3_minimal_inference.py | 7 +- sd3_train.py | 907 +++++++++++++++++++++++++++++++++++++ 8 files changed, 1909 insertions(+), 44 deletions(-) create mode 100644 library/sd3_train_utils.py create mode 100644 sd3_train.py diff --git a/README.md b/README.md index 946df58f3..34aa2bb2f 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,30 @@ This repository contains training, generation and utility scripts for Stable Diffusion. +## SD3 training + +SD3 training is done with `sd3_train.py`. + +`optimizer_type = "adafactor"` is recommended for 24GB VRAM GPUs. `cache_text_encoder_outputs_to_disk` and `cache_latents_to_disk` are necessary currently. + +`clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them. + +t5xxl doesn't seem to work with `fp16`, so use`bf16` or `fp32`. + +There are `t5xxl_device` and `t5xxl_dtype` options for `t5xxl` device and dtype. + +```toml +learning_rate = 1e-5 # seems to be too high +optimizer_type = "adafactor" +optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ] +cache_text_encoder_outputs = true +cache_text_encoder_outputs_to_disk = true +vae_batch_size = 1 +cache_latents = true +cache_latents_to_disk = true +``` + +--- + [__Change History__](#change-history) is moved to the bottom of the page. 更新履歴は[ページ末尾](#change-history)に移しました。 diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index a63bd82ec..f7bf644d7 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -6,8 +6,10 @@ from typing import List, Optional, Tuple, Union import safetensors from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) r""" @@ -55,11 +57,14 @@ ARCH_SD_V2_512 = "stable-diffusion-v2-512" ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v" ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base" +ARCH_SD3_M = "stable-diffusion-3-medium" +ARCH_SD3_UNKNOWN = "stable-diffusion-3" ADAPTER_LORA = "lora" ADAPTER_TEXTUAL_INVERSION = "textual-inversion" IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models" +IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI" IMPL_DIFFUSERS = "diffusers" PRED_TYPE_EPSILON = "epsilon" @@ -113,7 +118,11 @@ def build_metadata( merged_from: Optional[str] = None, timesteps: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, + sd3: str = None, ): + """ + sd3: only supports "m" + """ # if state_dict is None, hash is not calculated metadata = {} @@ -126,6 +135,11 @@ def build_metadata( if sdxl: arch = ARCH_SD_XL_V1_BASE + elif sd3 is not None: + if sd3 == "m": + arch = ARCH_SD3_M + else: + arch = ARCH_SD3_UNKNOWN elif v2: if v_parameterization: arch = ARCH_SD_V2_768_V @@ -142,7 +156,7 @@ def build_metadata( metadata["modelspec.architecture"] = arch if not lora and not textual_inversion and is_stable_diffusion_ckpt is None: - is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion + is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: # Stable Diffusion ckpt, TI, SDXL LoRA @@ -236,7 +250,7 @@ def build_metadata( # assert all([v is not None for v in metadata.values()]), metadata if not all([v is not None for v in metadata.values()]): logger.error(f"Internal error: some metadata values are None: {metadata}") - + return metadata @@ -250,7 +264,7 @@ def get_title(metadata: dict) -> Optional[str]: def load_metadata_from_safetensors(model: str) -> dict: if not model.endswith(".safetensors"): return {} - + with safetensors.safe_open(model, framework="pt") as f: metadata = f.metadata() if metadata is None: diff --git a/library/sd3_models.py b/library/sd3_models.py index 294a69b06..a4fe400e3 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -1,11 +1,13 @@ -# some modules/classes are copied and modified from https://github.com/mcmonkey4eva/sd3-ref +# some modules/classes are copied and modified from https://github.com/mcmonkey4eva/sd3-ref # the original code is licensed under the MIT License # and some module/classes are contributed from KohakuBlueleaf. Thanks for the contribution! +from ast import Tuple from functools import partial import math -from typing import Dict, Optional +from types import SimpleNamespace +from typing import Dict, List, Optional, Union import einops import numpy as np import torch @@ -106,6 +108,8 @@ def __init__(self, t5xxl=True): self.clip_l = SDTokenizer(tokenizer=clip_tokenizer) self.clip_g = SDXLClipGTokenizer(clip_tokenizer) self.t5xxl = T5XXLTokenizer() if t5xxl else None + # t5xxl has 99999999 max length, clip has 77 + self.model_max_length = self.clip_l.max_length # 77 def tokenize_with_weights(self, text: str): return ( @@ -870,6 +874,10 @@ def __init__( self.final_layer = UnPatch(self.hidden_size, patch_size, self.out_channels) # self.initialize_weights() + @property + def model_type(self): + return "m" # only support medium + def enable_gradient_checkpointing(self): self.gradient_checkpointing = True for block in self.joint_blocks: @@ -1013,6 +1021,10 @@ def create_mmdit_sd3_medium_configs(attn_mode: str): # endregion # region VAE +# TODO support xformers + +VAE_SCALE_FACTOR = 1.5305 +VAE_SHIFT_FACTOR = 0.0609 def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None): @@ -1222,6 +1234,14 @@ def __init__(self, dtype=torch.float32, device=None): self.encoder = VAEEncoder(dtype=dtype, device=device) self.decoder = VAEDecoder(dtype=dtype, device=device) + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + @torch.autocast("cuda", dtype=torch.float16) def decode(self, latent): return self.decoder(latent) @@ -1234,6 +1254,43 @@ def encode(self, image): std = torch.exp(0.5 * logvar) return mean + std * torch.randn_like(mean) + @staticmethod + def process_in(latent): + return (latent - VAE_SHIFT_FACTOR) * VAE_SCALE_FACTOR + + @staticmethod + def process_out(latent): + return (latent / VAE_SCALE_FACTOR) + VAE_SHIFT_FACTOR + + +class VAEOutput: + def __init__(self, latent): + self.latent = latent + + @property + def latent_dist(self): + return self + + def sample(self): + return self.latent + + +class VAEWrapper: + def __init__(self, vae): + self.vae = vae + + @property + def device(self): + return self.vae.device + + @property + def dtype(self): + return self.vae.dtype + + # latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") + def encode(self, image): + return VAEOutput(self.vae.encode(image)) + # endregion @@ -1370,15 +1427,39 @@ def forward(self, *args, **kwargs): class ClipTokenWeightEncoder: - def encode_token_weights(self, token_weight_pairs): - tokens = list(map(lambda a: a[0], token_weight_pairs[0])) - out, pooled = self([tokens]) - if pooled is not None: - first_pooled = pooled[0:1].cpu() + # def encode_token_weights(self, token_weight_pairs): + # tokens = list(map(lambda a: a[0], token_weight_pairs[0])) + # out, pooled = self([tokens]) + # if pooled is not None: + # first_pooled = pooled[0:1] + # else: + # first_pooled = pooled + # output = [out[0:1]] + # return torch.cat(output, dim=-2), first_pooled + + # fix to support batched inputs + # : Union[List[Tuple[torch.Tensor, torch.Tensor]], List[List[Tuple[torch.Tensor, torch.Tensor]]]] + def encode_token_weights(self, list_of_token_weight_pairs): + has_batch = isinstance(list_of_token_weight_pairs[0][0], list) + + if has_batch: + list_of_tokens = [] + for pairs in list_of_token_weight_pairs: + tokens = [a[0] for a in pairs[0]] # I'm not sure why this is [0] + list_of_tokens.append(tokens) else: - first_pooled = pooled - output = [out[0:1]] - return torch.cat(output, dim=-2).cpu(), first_pooled + list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]] + + out, pooled = self(list_of_tokens) + if has_batch: + return out, pooled + else: + if pooled is not None: + first_pooled = pooled[0:1] + else: + first_pooled = pooled + output = [out[0:1]] + return torch.cat(output, dim=-2), first_pooled class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): @@ -1694,6 +1775,7 @@ def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermed x = self.embed_tokens(input_ids) past_bias = None for i, l in enumerate(self.block): + # uncomment to debug layerwise output: fp16 may cause issues # print(i, x.mean(), x.std()) x, past_bias = l(x, past_bias) if i == intermediate_output: diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py new file mode 100644 index 000000000..4e45871f4 --- /dev/null +++ b/library/sd3_train_utils.py @@ -0,0 +1,544 @@ +import argparse +import math +import os +from typing import Optional, Tuple + +import torch +from safetensors.torch import save_file + +from library import sd3_models, sd3_utils, train_util +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from accelerate import init_empty_weights +from tqdm import tqdm + +# from transformers import CLIPTokenizer +# from library import model_util +# , sdxl_model_util, train_util, sdxl_original_unet +# from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from .sdxl_train_util import match_mixed_precision + + +def load_target_model(args, accelerator, attn_mode, weight_dtype, t5xxl_device, t5xxl_dtype) -> Tuple[ + sd3_models.MMDiT, + Optional[sd3_models.SDClipModel], + Optional[sd3_models.SDXLClipG], + Optional[sd3_models.T5XXLModel], + sd3_models.SDVAE, +]: + model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16 + + for pi in range(accelerator.state.num_processes): + if pi == accelerator.state.local_process_index: + logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") + + mmdit, clip_l, clip_g, t5xxl, vae = sd3_utils.load_models( + args.pretrained_model_name_or_path, + args.clip_l, + args.clip_g, + args.t5xxl, + args.vae, + attn_mode, + accelerator.device if args.lowram else "cpu", + weight_dtype, + args.disable_mmap_load_safetensors, + t5xxl_device, + t5xxl_dtype, + ) + + # work on low-ram device + if args.lowram: + if clip_l is not None: + clip_l.to(accelerator.device) + if clip_g is not None: + clip_g.to(accelerator.device) + if t5xxl is not None: + t5xxl.to(accelerator.device) + vae.to(accelerator.device) + mmdit.to(accelerator.device) + + clean_memory_on_device(accelerator.device) + accelerator.wait_for_everyone() + + return mmdit, clip_l, clip_g, t5xxl, vae + + +def save_models( + ckpt_path: str, + mmdit: sd3_models.MMDiT, + vae: sd3_models.SDVAE, + clip_l: sd3_models.SDClipModel, + clip_g: sd3_models.SDXLClipG, + t5xxl: Optional[sd3_models.T5XXLModel], + sai_metadata: Optional[dict], + save_dtype: Optional[torch.dtype] = None, +): + r""" + Save models to checkpoint file. Only supports unified checkpoint format. + """ + + state_dict = {} + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + if save_dtype is not None: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + update_sd("model.diffusion_model.", mmdit.state_dict()) + update_sd("first_stage_model.", vae.state_dict()) + + if clip_l is not None: + update_sd("text_encoders.clip_l.", clip_l.state_dict()) + if clip_g is not None: + update_sd("text_encoders.clip_g.", clip_g.state_dict()) + if t5xxl is not None: + update_sd("text_encoders.t5xxl.", t5xxl.state_dict()) + + save_file(state_dict, ckpt_path, metadata=sai_metadata) + + +def save_sd3_model_on_train_end( + args: argparse.Namespace, + save_dtype: torch.dtype, + epoch: int, + global_step: int, + clip_l: sd3_models.SDClipModel, + clip_g: sd3_models.SDXLClipG, + t5xxl: Optional[sd3_models.T5XXLModel], + mmdit: sd3_models.MMDiT, + vae: sd3_models.SDVAE, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec( + None, args, False, False, False, is_stable_diffusion_ckpt=True, sd3=mmdit.model_type + ) + save_models(ckpt_file, mmdit, vae, clip_l, clip_g, t5xxl, sai_metadata, save_dtype) + + train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None) + + +# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している +# on_epoch_end: Trueならepoch終了時、Falseならstep経過時 +def save_sd3_model_on_epoch_end_or_stepwise( + args: argparse.Namespace, + on_epoch_end: bool, + accelerator, + save_dtype: torch.dtype, + epoch: int, + num_train_epochs: int, + global_step: int, + clip_l: sd3_models.SDClipModel, + clip_g: sd3_models.SDXLClipG, + t5xxl: Optional[sd3_models.T5XXLModel], + mmdit: sd3_models.MMDiT, + vae: sd3_models.SDVAE, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec( + None, args, False, False, False, is_stable_diffusion_ckpt=True, sd3=mmdit.model_type + ) + save_models(ckpt_file, mmdit, vae, clip_l, clip_g, t5xxl, sai_metadata, save_dtype) + + train_util.save_sd_model_on_epoch_end_or_stepwise_common( + args, + on_epoch_end, + accelerator, + True, + True, + epoch, + num_train_epochs, + global_step, + sd_saver, + None, + ) + + +def add_sd3_training_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" + ) + parser.add_argument( + "--cache_text_encoder_outputs_to_disk", + action="store_true", + help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", + ) + parser.add_argument( + "--disable_mmap_load_safetensors", + action="store_true", + help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる", + ) + + parser.add_argument( + "--clip_l", + type=str, + required=False, + help="CLIP-L model path. if not specified, use ckpt's state_dict / CLIP-Lモデルのパス。指定しない場合はckptのstate_dictを使用", + ) + parser.add_argument( + "--clip_g", + type=str, + required=False, + help="CLIP-G model path. if not specified, use ckpt's state_dict / CLIP-Gモデルのパス。指定しない場合はckptのstate_dictを使用", + ) + parser.add_argument( + "--t5xxl", + type=str, + required=False, + help="T5-XXL model path. if not specified, use ckpt's state_dict / T5-XXLモデルのパス。指定しない場合はckptのstate_dictを使用", + ) + parser.add_argument( + "--save_clip", action="store_true", help="save CLIP models to checkpoint / CLIPモデルをチェックポイントに保存する" + ) + parser.add_argument( + "--save_t5xxl", action="store_true", help="save T5-XXL model to checkpoint / T5-XXLモデルをチェックポイントに保存する" + ) + + parser.add_argument( + "--t5xxl_device", + type=str, + default=None, + help="T5-XXL device. if not specified, use accelerator's device / T5-XXLデバイス。指定しない場合はacceleratorのデバイスを使用", + ) + parser.add_argument( + "--t5xxl_dtype", + type=str, + default=None, + help="T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtype(mixed precisionから)を使用", + ) + + # copy from Diffusers + parser.add_argument( + "--weighting_scheme", + type=str, + default="logit_normal", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"], + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.") + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + + +def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): + assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" + if args.v_parameterization: + logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります") + + if args.clip_skip is not None: + logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません") + + # if args.multires_noise_iterations: + # logger.info( + # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります" + # ) + # else: + # if args.noise_offset is None: + # args.noise_offset = DEFAULT_NOISE_OFFSET + # elif args.noise_offset != DEFAULT_NOISE_OFFSET: + # logger.info( + # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています" + # ) + # logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") + + assert ( + not hasattr(args, "weighted_captions") or not args.weighted_captions + ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" + + if supportTextEncoderCaching: + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + args.cache_text_encoder_outputs = True + logger.warning( + "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / " + + "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました" + ) + + +def sample_images(*args, **kwargs): + return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs) + + +# region Diffusers + + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils.torch_utils import randn_tensor +from diffusers.utils import BaseOutput + + +@dataclass +class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Euler scheduler. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + ): + timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) + + sigmas = timesteps / num_train_timesteps + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.timesteps = sigmas * num_train_timesteps + + self._step_index = None + self._begin_index = None + + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def scale_noise( + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Forward process in flow-matching + + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + sample = sigma * noise + (1.0 - sigma) * sample + + return sample + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + timesteps = np.linspace(self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps) + + sigmas = timesteps / self.config.num_train_timesteps + sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + + timesteps = sigmas * self.config.num_train_timesteps + self.timesteps = timesteps.to(device=device) + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + self._step_index = None + self._begin_index = None + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or + tuple. + + Returns: + [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + + if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + sigma = self.sigmas[self.step_index] + + gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 + + noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator) + + eps = noise * s_noise + sigma_hat = sigma * (gamma + 1) + + if gamma > 0: + sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + # NOTE: "original_sample" should not be an expected prediction_type but is left in for + # backwards compatibility + + # if self.config.prediction_type == "vector_field": + + denoised = sample - model_output * sigma + # 2. Convert to an ODE derivative + derivative = (sample - denoised) / sigma_hat + + dt = self.sigmas[self.step_index + 1] - sigma_hat + + prev_sample = sample + derivative * dt + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) + + def __len__(self): + return self.config.num_train_timesteps + + +# endregion diff --git a/library/sd3_utils.py b/library/sd3_utils.py index 6f8c361fd..c2c914123 100644 --- a/library/sd3_utils.py +++ b/library/sd3_utils.py @@ -1,30 +1,226 @@ import math -from typing import Dict +from typing import Dict, Optional, Union import torch +import safetensors +from safetensors.torch import load_file +from accelerate import init_empty_weights + +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) from library import sd3_models +# TODO move some of functions to model_util.py +from library import sdxl_model_util + +# region models + + +def load_models( + ckpt_path: str, + clip_l_path: str, + clip_g_path: str, + t5xxl_path: str, + vae_path: str, + attn_mode: str, + device: Union[str, torch.device], + weight_dtype: torch.dtype, + disable_mmap: bool = False, + t5xxl_device: Optional[str] = None, + t5xxl_dtype: Optional[str] = None, +): + def load_state_dict(path: str, dvc: Union[str, torch.device] = device): + if disable_mmap: + return safetensors.torch.load(open(path, "rb").read()) + else: + try: + return load_file(path, device=dvc) + except: + return load_file(path) # prevent device invalid Error + + t5xxl_device = t5xxl_device or device + + logger.info(f"Loading SD3 models from {ckpt_path}...") + state_dict = load_state_dict(ckpt_path) + + # load clip_l + clip_l_sd = None + if clip_l_path: + logger.info(f"Loading clip_l from {clip_l_path}...") + clip_l_sd = load_state_dict(clip_l_path) + for key in list(clip_l_sd.keys()): + clip_l_sd["transformer." + key] = clip_l_sd.pop(key) + else: + if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_l: remove prefix "text_encoders.clip_l." + logger.info("clip_l is included in the checkpoint") + clip_l_sd = {} + prefix = "text_encoders.clip_l." + for k in list(state_dict.keys()): + if k.startswith(prefix): + clip_l_sd[k[len(prefix) :]] = state_dict.pop(k) + + # load clip_g + clip_g_sd = None + if clip_g_path: + logger.info(f"Loading clip_g from {clip_g_path}...") + clip_g_sd = load_state_dict(clip_g_path) + for key in list(clip_g_sd.keys()): + clip_g_sd["transformer." + key] = clip_g_sd.pop(key) + else: + if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_g: remove prefix "text_encoders.clip_g." + logger.info("clip_g is included in the checkpoint") + clip_g_sd = {} + prefix = "text_encoders.clip_g." + for k in list(state_dict.keys()): + if k.startswith(prefix): + clip_g_sd[k[len(prefix) :]] = state_dict.pop(k) + + # load t5xxl + t5xxl_sd = None + if t5xxl_path: + logger.info(f"Loading t5xxl from {t5xxl_path}...") + t5xxl_sd = load_state_dict(t5xxl_path, t5xxl_device) + for key in list(t5xxl_sd.keys()): + t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key) + else: + if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict: + # found t5xxl: remove prefix "text_encoders.t5xxl." + logger.info("t5xxl is included in the checkpoint") + t5xxl_sd = {} + prefix = "text_encoders.t5xxl." + for k in list(state_dict.keys()): + if k.startswith(prefix): + t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k) + + # MMDiT and VAE + vae_sd = {} + if vae_path: + logger.info(f"Loading VAE from {vae_path}...") + vae_sd = load_state_dict(vae_path) + else: + # remove prefix "first_stage_model." + vae_sd = {} + vae_prefix = "first_stage_model." + for k in list(state_dict.keys()): + if k.startswith(vae_prefix): + vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k) + + mmdit_prefix = "model.diffusion_model." + for k in list(state_dict.keys()): + if k.startswith(mmdit_prefix): + state_dict[k[len(mmdit_prefix) :]] = state_dict.pop(k) + else: + state_dict.pop(k) # remove other keys + + # load MMDiT + logger.info("Building MMDit") + with init_empty_weights(): + mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode) + + logger.info("Loading state dict...") + info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, weight_dtype) + logger.info(f"Loaded MMDiT: {info}") + + # load ClipG and ClipL + if clip_l_sd is None: + clip_l = None + else: + logger.info("Building ClipL") + clip_l = sd3_models.create_clip_l(device, weight_dtype, clip_l_sd) + logger.info("Loading state dict...") + info = clip_l.load_state_dict(clip_l_sd) + logger.info(f"Loaded ClipL: {info}") + clip_l.set_attn_mode(attn_mode) + + if clip_g_sd is None: + clip_g = None + else: + logger.info("Building ClipG") + clip_g = sd3_models.create_clip_g(device, weight_dtype, clip_g_sd) + logger.info("Loading state dict...") + info = clip_g.load_state_dict(clip_g_sd) + logger.info(f"Loaded ClipG: {info}") + clip_g.set_attn_mode(attn_mode) + + # load T5XXL + if t5xxl_sd is None: + t5xxl = None + else: + logger.info("Building T5XXL") + t5xxl = sd3_models.create_t5xxl(t5xxl_device, t5xxl_dtype, t5xxl_sd) + logger.info("Loading state dict...") + info = t5xxl.load_state_dict(t5xxl_sd) + logger.info(f"Loaded T5XXL: {info}") + t5xxl.set_attn_mode(attn_mode) + + # load VAE + logger.info("Building VAE") + vae = sd3_models.SDVAE() + logger.info("Loading state dict...") + info = vae.load_state_dict(vae_sd) + logger.info(f"Loaded VAE: {info}") + + return mmdit, clip_l, clip_g, t5xxl, vae + + +# endregion +# region utils + def get_cond( prompt: str, tokenizer: sd3_models.SD3Tokenizer, clip_l: sd3_models.SDClipModel, clip_g: sd3_models.SDXLClipG, - t5xxl: sd3_models.T5XXLModel, + t5xxl: Optional[sd3_models.T5XXLModel] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ): l_tokens, g_tokens, t5_tokens = tokenizer.tokenize_with_weights(prompt) + return get_cond_from_tokens(l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, device=device, dtype=dtype) + + +def get_cond_from_tokens( + l_tokens, + g_tokens, + t5_tokens, + clip_l: sd3_models.SDClipModel, + clip_g: sd3_models.SDXLClipG, + t5xxl: Optional[sd3_models.T5XXLModel] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, +): l_out, l_pooled = clip_l.encode_token_weights(l_tokens) g_out, g_pooled = clip_g.encode_token_weights(g_tokens) lg_out = torch.cat([l_out, g_out], dim=-1) lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) + if device is not None: + lg_out = lg_out.to(device=device) + l_pooled = l_pooled.to(device=device) + g_pooled = g_pooled.to(device=device) + if dtype is not None: + lg_out = lg_out.to(dtype=dtype) + l_pooled = l_pooled.to(dtype=dtype) + g_pooled = g_pooled.to(dtype=dtype) + # t5xxl may be in another device (eg. cpu) if t5_tokens is None: - t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device) + t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype) else: - t5_out, t5_pooled = t5xxl.encode_token_weights(t5_tokens) # t5_out is [1, 77, 4096], t5_pooled is None - t5_out = t5_out.to(lg_out.dtype) + t5_out, _ = t5xxl.encode_token_weights(t5_tokens) # t5_out is [1, 77, 4096], t5_pooled is None + if device is not None: + t5_out = t5_out.to(device=device) + if dtype is not None: + t5_out = t5_out.to(dtype=dtype) - return torch.cat([lg_out, t5_out], dim=-2), torch.cat((l_pooled, g_pooled), dim=-1) + # return torch.cat([lg_out, t5_out], dim=-2), torch.cat((l_pooled, g_pooled), dim=-1) + return lg_out, t5_out, torch.cat((l_pooled, g_pooled), dim=-1) # used if other sd3 models is available @@ -111,3 +307,6 @@ def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): # assert max_denoise is False, "max_denoise not implemented" # max_denoise is always True, I'm not sure why it's there return sigma * noise + (1.0 - sigma) * latent_image + + +# endregion diff --git a/library/train_util.py b/library/train_util.py index 4736ff4ff..c67e8737c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -58,7 +58,7 @@ KDPM2AncestralDiscreteScheduler, AutoencoderKL, ) -from library import custom_train_functions +from library import custom_train_functions, sd3_utils from library.original_unet import UNet2DConditionModel from huggingface_hub import hf_hub_download import numpy as np @@ -135,6 +135,7 @@ ) TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" +TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz" class ImageInfo: @@ -985,7 +986,7 @@ def is_text_encoder_output_cacheable(self): ] ) - def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと logger.info("caching latents.") @@ -1006,7 +1007,7 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc # check disk cache exists and size of latents if cache_to_disk: - info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz" + info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix if not is_main_process: # store to info only continue @@ -1040,14 +1041,43 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc for batch in tqdm(batches, smoothing=1, total=len(batches)): cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) - # weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる - # SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する - # SD1/2に対応するにはv2のフラグを持つ必要があるので後回し + # if weight_dtype is specified, Text Encoder itself and output will be converted to the dtype + # this method is only for SDXL, but it should be implemented here because it needs to be a method of dataset + # to support SD1/2, it needs a flag for v2, but it is postponed def cache_text_encoder_outputs( - self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True + self, tokenizers, text_encoders, device, output_dtype, cache_to_disk=False, is_main_process=True ): assert len(tokenizers) == 2, "only support SDXL" + return self.cache_text_encoder_outputs_common( + tokenizers, text_encoders, [device, device], output_dtype, [output_dtype], cache_to_disk, is_main_process + ) + # same as above, but for SD3 + def cache_text_encoder_outputs_sd3( + self, tokenizer, text_encoders, devices, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True + ): + return self.cache_text_encoder_outputs_common( + [tokenizer], + text_encoders, + devices, + output_dtype, + te_dtypes, + cache_to_disk, + is_main_process, + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3, + ) + + def cache_text_encoder_outputs_common( + self, + tokenizers, + text_encoders, + devices, + output_dtype, + te_dtypes, + cache_to_disk=False, + is_main_process=True, + file_suffix=TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX, + ): # latentsのキャッシュと同様に、ディスクへのキャッシュに対応する # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと logger.info("caching text encoder outputs.") @@ -1058,13 +1088,14 @@ def cache_text_encoder_outputs( for info in tqdm(image_infos): # subset = self.image_to_subset[info.image_key] if cache_to_disk: - te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX + te_out_npz = os.path.splitext(info.absolute_path)[0] + file_suffix info.text_encoder_outputs_npz = te_out_npz if not is_main_process: # store to info only continue if os.path.exists(te_out_npz): + # TODO check varidity of cache here continue image_infos_to_cache.append(info) @@ -1073,18 +1104,23 @@ def cache_text_encoder_outputs( return # prepare tokenizers and text encoders - for text_encoder in text_encoders: + for text_encoder, device, te_dtype in zip(text_encoders, devices, te_dtypes): text_encoder.to(device) - if weight_dtype is not None: - text_encoder.to(dtype=weight_dtype) + if te_dtype is not None: + text_encoder.to(dtype=te_dtype) # create batch + is_sd3 = len(tokenizers) == 1 batch = [] batches = [] for info in image_infos_to_cache: - input_ids1 = self.get_input_ids(info.caption, tokenizers[0]) - input_ids2 = self.get_input_ids(info.caption, tokenizers[1]) - batch.append((info, input_ids1, input_ids2)) + if not is_sd3: + input_ids1 = self.get_input_ids(info.caption, tokenizers[0]) + input_ids2 = self.get_input_ids(info.caption, tokenizers[1]) + batch.append((info, input_ids1, input_ids2)) + else: + l_tokens, g_tokens, t5_tokens = tokenizers[0].tokenize_with_weights(info.caption) + batch.append((info, l_tokens, g_tokens, t5_tokens)) if len(batch) >= self.batch_size: batches.append(batch) @@ -1095,13 +1131,32 @@ def cache_text_encoder_outputs( # iterate batches: call text encoder and cache outputs for memory or disk logger.info("caching text encoder outputs...") - for batch in tqdm(batches): - infos, input_ids1, input_ids2 = zip(*batch) - input_ids1 = torch.stack(input_ids1, dim=0) - input_ids2 = torch.stack(input_ids2, dim=0) - cache_batch_text_encoder_outputs( - infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, weight_dtype - ) + if not is_sd3: + for batch in tqdm(batches): + infos, input_ids1, input_ids2 = zip(*batch) + input_ids1 = torch.stack(input_ids1, dim=0) + input_ids2 = torch.stack(input_ids2, dim=0) + cache_batch_text_encoder_outputs( + infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, output_dtype + ) + else: + for batch in tqdm(batches): + infos, l_tokens, g_tokens, t5_tokens = zip(*batch) + + # stack tokens + # l_tokens = [tokens[0] for tokens in l_tokens] + # g_tokens = [tokens[0] for tokens in g_tokens] + # t5_tokens = [tokens[0] for tokens in t5_tokens] + + cache_batch_text_encoder_outputs_sd3( + infos, + tokenizers[0], + text_encoders, + self.max_token_length, + cache_to_disk, + (l_tokens, g_tokens, t5_tokens), + output_dtype, + ) def get_image_size(self, image_path): return imagesize.get(image_path) @@ -1332,6 +1387,7 @@ def __getitem__(self, index): captions.append(caption) if not self.token_padding_disabled: # this option might be omitted in future + # TODO get_input_ids must support SD3 if self.XTI_layers: token_caption = self.get_input_ids(caption_layer, self.tokenizers[0]) else: @@ -2140,10 +2196,10 @@ def enable_XTI(self, *args, **kwargs): for dataset in self.datasets: dataset.enable_XTI(*args, **kwargs) - def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") - dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) + dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix) def cache_text_encoder_outputs( self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True @@ -2152,6 +2208,15 @@ def cache_text_encoder_outputs( logger.info(f"[Dataset {i}]") dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process) + def cache_text_encoder_outputs_sd3( + self, tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True + ): + for i, dataset in enumerate(self.datasets): + logger.info(f"[Dataset {i}]") + dataset.cache_text_encoder_outputs_sd3( + tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process + ) + def set_caching_mode(self, caching_mode): for dataset in self.datasets: dataset.set_caching_mode(caching_mode) @@ -2585,6 +2650,30 @@ def cache_batch_text_encoder_outputs( info.text_encoder_pool2 = pool2 +def cache_batch_text_encoder_outputs_sd3( + image_infos, tokenizer, text_encoders, max_token_length, cache_to_disk, input_ids, output_dtype +): + # make input_ids for each text encoder + l_tokens, g_tokens, t5_tokens = input_ids + + clip_l, clip_g, t5xxl = text_encoders + with torch.no_grad(): + b_lg_out, b_t5_out, b_pool = sd3_utils.get_cond_from_tokens( + l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, "cpu", output_dtype + ) + b_lg_out = b_lg_out.detach() + b_t5_out = b_t5_out.detach() + b_pool = b_pool.detach() + + for info, lg_out, t5_out, pool in zip(image_infos, b_lg_out, b_t5_out, b_pool): + if cache_to_disk: + save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, lg_out, t5_out, pool) + else: + info.text_encoder_outputs1 = lg_out + info.text_encoder_outputs2 = t5_out + info.text_encoder_pool2 = pool + + def save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2): np.savez( npz_path, @@ -2907,6 +2996,7 @@ def get_sai_model_spec( lora: bool, textual_inversion: bool, is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA + sd3: str = None, ): timestamp = time.time() @@ -2940,6 +3030,7 @@ def get_sai_model_spec( tags=args.metadata_tags, timesteps=timesteps, clip_skip=args.clip_skip, # None or int + sd3=sd3, ) return metadata diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py index e14f784d4..96e9da4ac 100644 --- a/sd3_minimal_inference.py +++ b/sd3_minimal_inference.py @@ -320,8 +320,11 @@ def do_sample( # prepare embeddings logger.info("Encoding prompts...") # embeds, pooled_embed - cond = sd3_utils.get_cond(args.prompt, tokenizer, clip_l, clip_g, t5xxl) - neg_cond = sd3_utils.get_cond(args.negative_prompt, tokenizer, clip_l, clip_g, t5xxl) + lg_out, t5_out, pooled = sd3_utils.get_cond(args.prompt, tokenizer, clip_l, clip_g, t5xxl) + cond = torch.cat([lg_out, t5_out], dim=-2), pooled + + lg_out, t5_out, pooled = sd3_utils.get_cond(args.negative_prompt, tokenizer, clip_l, clip_g, t5xxl) + neg_cond = torch.cat([lg_out, t5_out], dim=-2), pooled # generate image logger.info("Generating image...") diff --git a/sd3_train.py b/sd3_train.py new file mode 100644 index 000000000..0721b2ae4 --- /dev/null +++ b/sd3_train.py @@ -0,0 +1,907 @@ +# training with captions + +import argparse +import copy +import math +import os +from multiprocessing import Value +from typing import List +import toml + +from tqdm import tqdm + +import torch +from library.device_utils import init_ipex, clean_memory_on_device + + +init_ipex() + +from accelerate.utils import set_seed +from diffusers import DDPMScheduler +from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils + +# , sdxl_model_util + +import library.train_util as train_util + +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import library.config_util as config_util + +# import library.sdxl_train_util as sdxl_train_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.custom_train_functions as custom_train_functions + +# from library.custom_train_functions import ( +# apply_snr_weight, +# prepare_scheduler_for_custom_training, +# scale_v_prediction_loss_like_noise_prediction, +# add_v_prediction_like_loss, +# apply_debiased_estimation, +# apply_masked_loss, +# ) + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + # sdxl_train_util.verify_sdxl_training_args(args) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) + + assert ( + not args.weighted_captions + ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + assert ( + not args.train_text_encoder or not args.cache_text_encoder_outputs + ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" + + # if args.block_lr: + # block_lrs = [float(lr) for lr in args.block_lr.split(",")] + # assert ( + # len(block_lrs) == UNET_NUM_BLOCKS_FOR_BLOCK_LR + # ), f"block_lr must have {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / block_lrは{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値を指定してください" + # else: + # block_lrs = None + + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + # load tokenizer + sd3_tokenizer = sd3_models.SD3Tokenizer() + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + logger.info("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + logger.info("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=[sd3_tokenizer]) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args, [sd3_tokenizer]) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(8) # TODO これでいいか確認 + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group, True) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = weight_dtype # torch.float32 if args.no_half_vae else weight_dtype # SD3 VAE works with fp16 + + t5xxl_dtype = weight_dtype + if args.t5xxl_dtype is not None: + if args.t5xxl_dtype == "fp16": + t5xxl_dtype = torch.float16 + elif args.t5xxl_dtype == "bf16": + t5xxl_dtype = torch.bfloat16 + elif args.t5xxl_dtype == "fp32" or args.t5xxl_dtype == "float": + t5xxl_dtype = torch.float32 + else: + raise ValueError(f"unexpected t5xxl_dtype: {args.t5xxl_dtype}") + t5xxl_device = accelerator.device if args.t5xxl_device is None else args.t5xxl_device + + # モデルを読み込む + attn_mode = "xformers" if args.xformers else "torch" + + assert ( + attn_mode == "torch" + ), f"attn_mode {attn_mode} is not supported. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" + + mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model( + args, accelerator, attn_mode, weight_dtype, t5xxl_device, t5xxl_dtype + ) + assert clip_l is not None, "clip_l is required / clip_lは必須です" + assert clip_g is not None, "clip_g is required / clip_gは必須です" + # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype) + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=vae_dtype) + vae.requires_grad_(False) + vae.eval() + vae_wrapper = sd3_models.VAEWrapper(vae) # make SD/SDXL compatible + with torch.no_grad(): + train_dataset_group.cache_latents( + vae_wrapper, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process, file_suffix="_sd3.npz" + ) + vae.to("cpu") + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # 学習を準備する:モデルを適切な状態にする + if args.gradient_checkpointing: + mmdit.enable_gradient_checkpointing() + train_mmdit = args.learning_rate != 0 + train_clip_l = False + train_clip_g = False + train_t5xxl = False + + # if args.train_text_encoder: + # # TODO each option for two text encoders? + # accelerator.print("enable text encoder training") + # if args.gradient_checkpointing: + # text_encoder1.gradient_checkpointing_enable() + # text_encoder2.gradient_checkpointing_enable() + # lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train + # lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train + # train_clip_l = lr_te1 != 0 + # train_clip_g = lr_te2 != 0 + + # # caching one text encoder output is not supported + # if not train_clip_l: + # text_encoder1.to(weight_dtype) + # if not train_clip_g: + # text_encoder2.to(weight_dtype) + # text_encoder1.requires_grad_(train_clip_l) + # text_encoder2.requires_grad_(train_clip_g) + # text_encoder1.train(train_clip_l) + # text_encoder2.train(train_clip_g) + # else: + clip_l.to(weight_dtype) + clip_g.to(weight_dtype) + clip_l.requires_grad_(False) + clip_g.requires_grad_(False) + clip_l.eval() + clip_g.eval() + if t5xxl is not None: + t5xxl.to(t5xxl_dtype) + t5xxl.requires_grad_(False) + t5xxl.eval() + + # TextEncoderの出力をキャッシュする + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad + + with torch.no_grad(), accelerator.autocast(): + train_dataset_group.cache_text_encoder_outputs_sd3( + sd3_tokenizer, + (clip_l, clip_g, t5xxl), + (accelerator.device, accelerator.device, t5xxl_device), + None, + (None, None, None), + args.cache_text_encoder_outputs_to_disk, + accelerator.is_main_process, + ) + accelerator.wait_for_everyone() + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=vae_dtype) + + mmdit.requires_grad_(train_mmdit) + if not train_mmdit: + mmdit.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared + + training_models = [] + params_to_optimize = [] + # if train_unet: + training_models.append(mmdit) + # if block_lrs is None: + params_to_optimize.append({"params": list(mmdit.parameters()), "lr": args.learning_rate}) + # else: + # params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs)) + + # if train_clip_l: + # training_models.append(text_encoder1) + # params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate}) + # if train_clip_g: + # training_models.append(text_encoder2) + # params_to_optimize.append({"params": list(text_encoder2.parameters()), "lr": args.learning_rate_te2 or args.learning_rate}) + + # calculate number of trainable parameters + n_params = 0 + for group in params_to_optimize: + for p in group["params"]: + n_params += p.numel() + + accelerator.print(f"train mmdit: {train_mmdit}") # , text_encoder1: {train_clip_l}, text_encoder2: {train_clip_g}") + accelerator.print(f"number of models: {len(training_models)}") + accelerator.print(f"number of trainable parameters: {n_params}") + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + if args.fused_optimizer_groups: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters. + # This balances memory usage and management complexity. + + # calculate total number of parameters + n_total_params = sum(len(params["params"]) for params in params_to_optimize) + params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups) + + # split params into groups, keeping the learning rate the same for all params in a group + # this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders) + grouped_params = [] + param_group = [] + param_group_lr = -1 + for group in params_to_optimize: + lr = group["lr"] + for p in group["params"]: + # if the learning rate is different for different params, start a new group + if lr != param_group_lr: + if param_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + param_group = [] + param_group_lr = lr + + param_group.append(p) + + # if the group has enough parameters, start a new group + if len(param_group) == params_per_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + param_group = [] + param_group_lr = -1 + + if param_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups") + + else: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + + # dataloaderを準備する + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + if args.fused_optimizer_groups: + # prepare lr schedulers for each optimizer + lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + mmdit.to(weight_dtype) + clip_l.to(weight_dtype) + clip_g.to(weight_dtype) + if t5xxl is not None: + t5xxl.to(weight_dtype) # TODO check works with fp16 or not + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + mmdit.to(weight_dtype) + clip_l.to(weight_dtype) + clip_g.to(weight_dtype) + if t5xxl is not None: + t5xxl.to(weight_dtype) + + # TODO check if this is necessary. SD3 uses pool for clip_l and clip_g + # # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer + # if train_clip_l: + # text_encoder1.text_model.encoder.layers[-1].requires_grad_(False) + # text_encoder1.text_model.final_layer_norm.requires_grad_(False) + + if args.deepspeed: + ds_model = deepspeed_utils.prepare_deepspeed_model( + args, + mmdit=mmdit, + # mmdie=mmdit if train_mmdit else None, + # text_encoder1=text_encoder1 if train_clip_l else None, + # text_encoder2=text_encoder2 if train_clip_g else None, + ) + # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] + + else: + # acceleratorがなんかよろしくやってくれるらしい + if train_mmdit: + mmdit = accelerator.prepare(mmdit) + # if train_clip_l: + # text_encoder1 = accelerator.prepare(text_encoder1) + # if train_clip_g: + # text_encoder2 = accelerator.prepare(text_encoder2) + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + + # TextEncoderの出力をキャッシュするときには、すでに出力を取得済みなのでCPUへ移動する + if args.cache_text_encoder_outputs: + # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 + clip_l.to("cpu", dtype=torch.float32) + clip_g.to("cpu", dtype=torch.float32) + if t5xxl is not None: + t5xxl.to("cpu", dtype=torch.float32) + clean_memory_on_device(accelerator.device) + else: + # make sure Text Encoders are on GPU + # TODO support CPU for text encoders + clip_l.to(accelerator.device) + clip_g.to(accelerator.device) + if t5xxl is not None: + t5xxl.to(accelerator.device) + + # TODO cache sample prompt's embeddings to free text encoder's memory + if args.cache_text_encoder_outputs: + if not args.save_t5xxl: + t5xxl = None # free memory + clean_memory_on_device(accelerator.device) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. + # -> But we think it's ok to patch accelerator even if deepspeed is enabled. + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + parameter.register_post_accumulate_grad_hook(__grad_hook) + + elif args.fused_optimizer_groups: + # prepare for additional optimizers and lr schedulers + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + # counters are used to determine when to step the optimizer + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def optimizer_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(optimizer_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # accelerator.print( + # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + # ) + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + # noise_scheduler = DDPMScheduler( + # beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + # ) + + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + + # prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + # if args.zero_terminal_snr: + # custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + # # For --sample_at_first + # sd3_train_utils.sample_images( + # accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], mmdit + # ) + + # following function will be moved to sd3_train_utils + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + def compute_density_for_timestep_sampling( + weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None + ): + """Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): + """Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + elif weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) + return weighting + + loss_recorder = train_util.LossRecorder() + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for m in training_models: + m.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + + if args.fused_optimizer_groups: + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step + + with accelerator.accumulate(*training_models): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + with torch.no_grad(): + # encode images to latents. images are [-1, 1] + latents = vae.encode(batch["images"].to(vae_dtype)).to(weight_dtype) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + # latents = latents * sdxl_model_util.VAE_SCALE_FACTOR + latents = sd3_models.SDVAE.process_in(latents) + + if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: + # not cached, get text encoder outputs + # XXX This does not work yet + input_ids_clip_l, input_ids_clip_g, input_ids_t5xxl = batch["input_ids"] + with torch.set_grad_enabled(args.train_text_encoder): + # TODO support weighted captions + # TODO support length > 75 + input_ids_clip_l = input_ids_clip_l.to(accelerator.device) + input_ids_clip_g = input_ids_clip_g.to(accelerator.device) + input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device) + + # get text encoder outputs: outputs are concatenated + context, pool = sd3_utils.get_cond_from_tokens( + input_ids_clip_l, input_ids_clip_g, input_ids_t5xxl, clip_l, clip_g, t5xxl + ) + else: + # encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) + # encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) + # pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) + # TODO this reuses SDXL keys, it should be fixed + lg_out = batch["text_encoder_outputs1_list"] + t5_out = batch["text_encoder_outputs2_list"] + pool = batch["text_encoder_pool2_list"] + context = torch.cat([lg_out, t5_out], dim=-2) + + # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = noise_scheduler_copy.timesteps[indices].to(device=accelerator.device) + + # Add noise according to flow matching. + sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + + # call model + with accelerator.autocast(): + model_pred = mmdit(noisy_model_input, timesteps, context=context, y=pool) + + # Follow: Section 5 of https://arxiv.org/abs/2206.00364. + # Preconditioning of the model outputs. + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = latents + + # Compute regular loss. TODO simplify this + loss = torch.mean( + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + 1, + ) + loss = loss.mean() + + accelerator.backward(loss) + + if not (args.fused_backward_pass or args.fused_optimizer_groups): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.fused_optimizer_groups: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + # sdxl_train_util.sample_images( + # accelerator, + # args, + # None, + # global_step, + # accelerator.device, + # vae, + # [tokenizer1, tokenizer2], + # [text_encoder1, text_encoder2], + # mmdit, + # ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + sd3_train_utils.save_sd3_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + clip_l if args.save_clip else None, + clip_g if args.save_clip else None, + t5xxl if args.save_t5xxl else None, + mmdit, + vae, + ) + + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if args.logging_dir is not None: + logs = {"loss": current_loss} + train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_mmdit) + + accelerator.log(logs, step=global_step) + + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + sd3_train_utils.save_sd3_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + clip_l if args.save_clip else None, + clip_g if args.save_clip else None, + t5xxl if args.save_t5xxl else None, + mmdit, + vae, + ) + + # sdxl_train_util.sample_images( + # accelerator, + # args, + # epoch + 1, + # global_step, + # accelerator.device, + # vae, + # [tokenizer1, tokenizer2], + # [text_encoder1, text_encoder2], + # mmdit, + # ) + + is_main_process = accelerator.is_main_process + # if is_main_process: + mmdit = accelerator.unwrap_model(mmdit) + clip_l = accelerator.unwrap_model(clip_l) + clip_g = accelerator.unwrap_model(clip_g) + if t5xxl is not None: + t5xxl = accelerator.unwrap_model(t5xxl) + + accelerator.end_training() + + if args.save_state or args.save_state_on_train_end: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + sd3_train_utils.save_sd3_model_on_train_end( + args, + save_dtype, + epoch, + global_step, + clip_l if args.save_clip else None, + clip_g if args.save_clip else None, + t5xxl if args.save_t5xxl else None, + mmdit, + vae, + ) + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + sd3_train_utils.add_sd3_training_arguments(parser) + + # TE training is disabled temporarily + + # parser.add_argument( + # "--learning_rate_te1", + # type=float, + # default=None, + # help="learning rate for text encoder 1 (ViT-L) / text encoder 1 (ViT-L)の学習率", + # ) + # parser.add_argument( + # "--learning_rate_te2", + # type=float, + # default=None, + # help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率", + # ) + + # parser.add_argument( + # "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する" + # ) + # parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") + # parser.add_argument( + # "--no_half_vae", + # action="store_true", + # help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + # ) + # parser.add_argument( + # "--block_lr", + # type=str, + # default=None, + # help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / " + # + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値", + # ) + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) From 0fe4eafac996fa5139a311aadc86aca28ddc6930 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 24 Jun 2024 23:12:48 +0900 Subject: [PATCH 029/348] fix to use zero for initial latent --- sd3_minimal_inference.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py index 96e9da4ac..7f5f28cea 100644 --- a/sd3_minimal_inference.py +++ b/sd3_minimal_inference.py @@ -64,7 +64,8 @@ def do_sample( device: str, ): if initial_latent is None: - latent = torch.ones(1, 16, height // 8, width // 8, device=device) * 0.0609 + # latent = torch.ones(1, 16, height // 8, width // 8, device=device) * 0.0609 + latent = torch.zeros(1, 16, height // 8, width // 8, device=device) else: latent = initial_latent From 4802e4aaec74429f733fae289e41c5618ebb0e92 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 24 Jun 2024 23:13:14 +0900 Subject: [PATCH 030/348] workaround for long caption ref #1382 --- library/sd3_models.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index a4fe400e3..c19aec6aa 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -56,7 +56,7 @@ def __init__( self.inv_vocab = {v: k for k, v in vocab.items()} self.max_word_length = 8 - def tokenize_with_weights(self, text: str): + def tokenize_with_weights(self, text: str, truncate_to_max_length=True, truncate_length=None): """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3.""" """ @@ -79,6 +79,14 @@ def tokenize_with_weights(self, text: str): batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch))) if self.min_length is not None and len(batch) < self.min_length: batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch))) + + # truncate to max_length + # print(f"batch: {batch}, truncate: {truncate}, len(batch): {len(batch)}, max_length: {self.max_length}") + if truncate_to_max_length and len(batch) > self.max_length: + batch = batch[: self.max_length] + if truncate_length is not None and len(batch) > truncate_length: + batch = batch[:truncate_length] + return [batch] @@ -112,10 +120,15 @@ def __init__(self, t5xxl=True): self.model_max_length = self.clip_l.max_length # 77 def tokenize_with_weights(self, text: str): + # temporary truncate to max_length even for t5xxl return ( self.clip_l.tokenize_with_weights(text), self.clip_g.tokenize_with_weights(text), - self.t5xxl.tokenize_with_weights(text) if self.t5xxl is not None else None, + ( + self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.model_max_length) + if self.t5xxl is not None + else None + ), ) From 8f2ba27869e4c5b9225a309aeed275a47d8eed6a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 26 Jun 2024 20:36:22 +0900 Subject: [PATCH 031/348] support text_encoder_batch_size for caching --- library/sd3_train_utils.py | 7 +++++++ library/train_util.py | 14 ++++++++++---- sd3_train.py | 1 + 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 4e45871f4..70c83c0ba 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -173,6 +173,13 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): action="store_true", help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", ) + parser.add_argument( + "--text_encoder_batch_size", + type=int, + default=None, + help="text encoder batch size (default: None, use dataset's batch size)" + + " / text encoderのバッチサイズ(デフォルト: None, データセットのバッチサイズを使用)", + ) parser.add_argument( "--disable_mmap_load_safetensors", action="store_true", diff --git a/library/train_util.py b/library/train_util.py index c67e8737c..96d32e3bc 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1054,7 +1054,7 @@ def cache_text_encoder_outputs( # same as above, but for SD3 def cache_text_encoder_outputs_sd3( - self, tokenizer, text_encoders, devices, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True + self, tokenizer, text_encoders, devices, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True, batch_size=None ): return self.cache_text_encoder_outputs_common( [tokenizer], @@ -1065,6 +1065,7 @@ def cache_text_encoder_outputs_sd3( cache_to_disk, is_main_process, TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3, + batch_size, ) def cache_text_encoder_outputs_common( @@ -1077,10 +1078,15 @@ def cache_text_encoder_outputs_common( cache_to_disk=False, is_main_process=True, file_suffix=TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX, + batch_size=None, ): # latentsのキャッシュと同様に、ディスクへのキャッシュに対応する # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと logger.info("caching text encoder outputs.") + + if batch_size is None: + batch_size = self.batch_size + image_infos = list(self.image_data.values()) logger.info("checking cache existence...") @@ -1122,7 +1128,7 @@ def cache_text_encoder_outputs_common( l_tokens, g_tokens, t5_tokens = tokenizers[0].tokenize_with_weights(info.caption) batch.append((info, l_tokens, g_tokens, t5_tokens)) - if len(batch) >= self.batch_size: + if len(batch) >= batch_size: batches.append(batch) batch = [] @@ -2209,12 +2215,12 @@ def cache_text_encoder_outputs( dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process) def cache_text_encoder_outputs_sd3( - self, tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True + self, tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True, batch_size=None ): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") dataset.cache_text_encoder_outputs_sd3( - tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process + tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size ) def set_caching_mode(self, caching_mode): diff --git a/sd3_train.py b/sd3_train.py index 0721b2ae4..8216a62b3 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -254,6 +254,7 @@ def train(args): (None, None, None), args.cache_text_encoder_outputs_to_disk, accelerator.is_main_process, + args.text_encoder_batch_size, ) accelerator.wait_for_everyone() From 828a581e2968935c00d22e7e03ca32c1281aa5dd Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 26 Jun 2024 20:43:31 +0900 Subject: [PATCH 032/348] fix assertion for experimental impl ref #1389 --- sd3_train.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/sd3_train.py b/sd3_train.py index 8216a62b3..ea9a11049 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -60,9 +60,19 @@ def train(args): assert ( not args.weighted_captions ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + # assert ( + # not args.train_text_encoder or not args.cache_text_encoder_outputs + # ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" + + # training text encoder is not supported + assert ( + not args.train_text_encoder + ), "training text encoder is not supported currently / text encoderの学習は現在サポートされていません" + + # training without text encoder cache is not supported assert ( - not args.train_text_encoder or not args.cache_text_encoder_outputs - ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" + args.cache_text_encoder_outputs + ), "training without text encoder cache is not supported currently / text encoderのキャッシュなしの学習は現在サポートされていません" # if args.block_lr: # block_lrs = [float(lr) for lr in args.block_lr.split(",")] From 381598c8bbd3d4e50ec4327fa27d5d0072ec2a67 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 26 Jun 2024 21:15:02 +0900 Subject: [PATCH 033/348] fix resolution in metadata for sd3 --- library/sai_model_spec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index f7bf644d7..af073677e 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -216,7 +216,7 @@ def build_metadata( reso = (reso[0], reso[0]) else: # resolution is defined in dataset, so use default - if sdxl: + if sdxl or sd3 is not None: reso = 1024 elif v2 and v_parameterization: reso = 768 From 66cf43547972647389fbd2addb53cff2ab478660 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 27 Jun 2024 13:14:09 +0900 Subject: [PATCH 034/348] re-fix assertion ref #1389 --- sd3_train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sd3_train.py b/sd3_train.py index ea9a11049..b6c932c4c 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -64,10 +64,10 @@ def train(args): # not args.train_text_encoder or not args.cache_text_encoder_outputs # ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" - # training text encoder is not supported - assert ( - not args.train_text_encoder - ), "training text encoder is not supported currently / text encoderの学習は現在サポートされていません" + # # training text encoder is not supported + # assert ( + # not args.train_text_encoder + # ), "training text encoder is not supported currently / text encoderの学習は現在サポートされていません" # training without text encoder cache is not supported assert ( From 19086465e8040c01c38d38eec5c53f966f0dad8b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 29 Jun 2024 17:21:25 +0900 Subject: [PATCH 035/348] Fix fp16 mixed precision, model is in bf16 without full_bf16 --- README.md | 11 +++++++-- library/sd3_train_utils.py | 10 +++++---- library/sd3_utils.py | 46 +++++++++++++++++++++++++++++++++----- sd3_train.py | 9 +++++--- 4 files changed, 61 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 34aa2bb2f..3eed636c5 100644 --- a/README.md +++ b/README.md @@ -4,21 +4,28 @@ This repository contains training, generation and utility scripts for Stable Dif SD3 training is done with `sd3_train.py`. +__Jun 29, 2024__: Fixed mixed precision training with fp16 is not working. Fixed the model is in bf16 dtype even without `--full_bf16` option (this could worsen the training result). + +`fp16` and `bf16` are available for mixed precision training. We are not sure which is better. + `optimizer_type = "adafactor"` is recommended for 24GB VRAM GPUs. `cache_text_encoder_outputs_to_disk` and `cache_latents_to_disk` are necessary currently. `clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them. -t5xxl doesn't seem to work with `fp16`, so use`bf16` or `fp32`. +t5xxl doesn't seem to work with `fp16`, so 1) use`bf16` for mixed precision, or 2) use `bf16` or `float32` for `t5xxl_dtype`. There are `t5xxl_device` and `t5xxl_dtype` options for `t5xxl` device and dtype. +`text_encoder_batch_size` is added experimentally for caching faster. + ```toml -learning_rate = 1e-5 # seems to be too high +learning_rate = 1e-6 # seems to depend on the batch size optimizer_type = "adafactor" optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ] cache_text_encoder_outputs = true cache_text_encoder_outputs_to_disk = true vae_batch_size = 1 +text_encoder_batch_size = 4 cache_latents = true cache_latents_to_disk = true ``` diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 70c83c0ba..c8d52e1c8 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -28,14 +28,14 @@ from .sdxl_train_util import match_mixed_precision -def load_target_model(args, accelerator, attn_mode, weight_dtype, t5xxl_device, t5xxl_dtype) -> Tuple[ +def load_target_model(args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype) -> Tuple[ sd3_models.MMDiT, Optional[sd3_models.SDClipModel], Optional[sd3_models.SDXLClipG], Optional[sd3_models.T5XXLModel], sd3_models.SDVAE, ]: - model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16 + model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16, None or fp16/bf16 for pi in range(accelerator.state.num_processes): if pi == accelerator.state.local_process_index: @@ -49,13 +49,15 @@ def load_target_model(args, accelerator, attn_mode, weight_dtype, t5xxl_device, args.vae, attn_mode, accelerator.device if args.lowram else "cpu", - weight_dtype, + model_dtype, args.disable_mmap_load_safetensors, + clip_dtype, t5xxl_device, t5xxl_dtype, + vae_dtype, ) - # work on low-ram device + # work on low-ram device: models are already loaded on accelerator.device, but we ensure they are on device if args.lowram: if clip_l is not None: clip_l.to(accelerator.device) diff --git a/library/sd3_utils.py b/library/sd3_utils.py index c2c914123..45b49b04b 100644 --- a/library/sd3_utils.py +++ b/library/sd3_utils.py @@ -28,11 +28,41 @@ def load_models( vae_path: str, attn_mode: str, device: Union[str, torch.device], - weight_dtype: torch.dtype, + default_dtype: Optional[Union[str, torch.dtype]] = None, disable_mmap: bool = False, - t5xxl_device: Optional[str] = None, - t5xxl_dtype: Optional[str] = None, + clip_dtype: Optional[Union[str, torch.dtype]] = None, + t5xxl_device: Optional[Union[str, torch.device]] = None, + t5xxl_dtype: Optional[Union[str, torch.dtype]] = None, + vae_dtype: Optional[Union[str, torch.dtype]] = None, ): + """ + Load SD3 models from checkpoint files. + + Args: + ckpt_path: Path to the SD3 checkpoint file. + clip_l_path: Path to the clip_l checkpoint file. + clip_g_path: Path to the clip_g checkpoint file. + t5xxl_path: Path to the t5xxl checkpoint file. + vae_path: Path to the VAE checkpoint file. + attn_mode: Attention mode for MMDiT model. + device: Device for MMDiT model. + default_dtype: Default dtype for each model. In training, it's usually None. None means using float32. + disable_mmap: Disable memory mapping when loading state dict. + clip_dtype: Dtype for Clip models, or None to use default dtype. + t5xxl_device: Device for T5XXL model to load T5XXL in another device (eg. gpu). Default is None to use device. + t5xxl_dtype: Dtype for T5XXL model, or None to use default dtype. + vae_dtype: Dtype for VAE model, or None to use default dtype. + + Returns: + Tuple of MMDiT, ClipL, ClipG, T5XXL, and VAE models. + """ + + # In SD1/2 and SDXL, the model is created with empty weights and then loaded with state dict. + # However, in SD3, Clip and T5XXL models are created with dtype, so we need to set dtype before loading state dict. + # Therefore, we need clip_dtype and t5xxl_dtype. + + # default_dtype is used for full_fp16/full_bf16 training. + def load_state_dict(path: str, dvc: Union[str, torch.device] = device): if disable_mmap: return safetensors.torch.load(open(path, "rb").read()) @@ -43,6 +73,9 @@ def load_state_dict(path: str, dvc: Union[str, torch.device] = device): return load_file(path) # prevent device invalid Error t5xxl_device = t5xxl_device or device + clip_dtype = clip_dtype or default_dtype or torch.float32 + t5xxl_dtype = t5xxl_dtype or default_dtype or torch.float32 + vae_dtype = vae_dtype or default_dtype or torch.float32 logger.info(f"Loading SD3 models from {ckpt_path}...") state_dict = load_state_dict(ckpt_path) @@ -124,7 +157,7 @@ def load_state_dict(path: str, dvc: Union[str, torch.device] = device): mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode) logger.info("Loading state dict...") - info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, weight_dtype) + info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, default_dtype) logger.info(f"Loaded MMDiT: {info}") # load ClipG and ClipL @@ -132,7 +165,7 @@ def load_state_dict(path: str, dvc: Union[str, torch.device] = device): clip_l = None else: logger.info("Building ClipL") - clip_l = sd3_models.create_clip_l(device, weight_dtype, clip_l_sd) + clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd) logger.info("Loading state dict...") info = clip_l.load_state_dict(clip_l_sd) logger.info(f"Loaded ClipL: {info}") @@ -142,7 +175,7 @@ def load_state_dict(path: str, dvc: Union[str, torch.device] = device): clip_g = None else: logger.info("Building ClipG") - clip_g = sd3_models.create_clip_g(device, weight_dtype, clip_g_sd) + clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd) logger.info("Loading state dict...") info = clip_g.load_state_dict(clip_g_sd) logger.info(f"Loaded ClipG: {info}") @@ -165,6 +198,7 @@ def load_state_dict(path: str, dvc: Union[str, torch.device] = device): logger.info("Loading state dict...") info = vae.load_state_dict(vae_sd) logger.info(f"Loaded VAE: {info}") + vae.to(device=device, dtype=vae_dtype) return mmdit, clip_l, clip_g, t5xxl, vae diff --git a/sd3_train.py b/sd3_train.py index b6c932c4c..bd30cdc72 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -182,6 +182,8 @@ def train(args): raise ValueError(f"unexpected t5xxl_dtype: {args.t5xxl_dtype}") t5xxl_device = accelerator.device if args.t5xxl_device is None else args.t5xxl_device + clip_dtype = weight_dtype # if not args.train_text_encoder else None + # モデルを読み込む attn_mode = "xformers" if args.xformers else "torch" @@ -189,8 +191,9 @@ def train(args): attn_mode == "torch" ), f"attn_mode {attn_mode} is not supported. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" + # models are usually loaded on CPU and moved to GPU later. This is to avoid OOM on GPU0. mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model( - args, accelerator, attn_mode, weight_dtype, t5xxl_device, t5xxl_dtype + args, accelerator, attn_mode, None, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype ) assert clip_l is not None, "clip_l is required / clip_lは必須です" assert clip_g is not None, "clip_g is required / clip_gは必須です" @@ -868,8 +871,9 @@ def setup_parser() -> argparse.ArgumentParser: custom_train_functions.add_custom_train_arguments(parser) sd3_train_utils.add_sd3_training_arguments(parser) - # TE training is disabled temporarily + # parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") + # TE training is disabled temporarily # parser.add_argument( # "--learning_rate_te1", # type=float, @@ -886,7 +890,6 @@ def setup_parser() -> argparse.ArgumentParser: # parser.add_argument( # "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する" # ) - # parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") # parser.add_argument( # "--no_half_vae", # action="store_true", From ea18d5ba6d856995d5c44be4b449b63ac66fe5db Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 29 Jun 2024 17:45:50 +0900 Subject: [PATCH 036/348] Fix to work full_bf16 and full_fp16. --- library/sd3_models.py | 8 ++++++++ library/sd3_utils.py | 14 ++++++-------- sd3_train.py | 20 ++++++++++---------- 3 files changed, 24 insertions(+), 18 deletions(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index c19aec6aa..7041420cb 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -891,6 +891,14 @@ def __init__( def model_type(self): return "m" # only support medium + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + def enable_gradient_checkpointing(self): self.gradient_checkpointing = True for block in self.joint_blocks: diff --git a/library/sd3_utils.py b/library/sd3_utils.py index 45b49b04b..9dc9e7967 100644 --- a/library/sd3_utils.py +++ b/library/sd3_utils.py @@ -28,7 +28,7 @@ def load_models( vae_path: str, attn_mode: str, device: Union[str, torch.device], - default_dtype: Optional[Union[str, torch.dtype]] = None, + weight_dtype: Optional[Union[str, torch.dtype]] = None, disable_mmap: bool = False, clip_dtype: Optional[Union[str, torch.dtype]] = None, t5xxl_device: Optional[Union[str, torch.device]] = None, @@ -46,7 +46,7 @@ def load_models( vae_path: Path to the VAE checkpoint file. attn_mode: Attention mode for MMDiT model. device: Device for MMDiT model. - default_dtype: Default dtype for each model. In training, it's usually None. None means using float32. + weight_dtype: Default dtype of weights for all models. This is weight dtype, so the model dtype may be different. disable_mmap: Disable memory mapping when loading state dict. clip_dtype: Dtype for Clip models, or None to use default dtype. t5xxl_device: Device for T5XXL model to load T5XXL in another device (eg. gpu). Default is None to use device. @@ -61,8 +61,6 @@ def load_models( # However, in SD3, Clip and T5XXL models are created with dtype, so we need to set dtype before loading state dict. # Therefore, we need clip_dtype and t5xxl_dtype. - # default_dtype is used for full_fp16/full_bf16 training. - def load_state_dict(path: str, dvc: Union[str, torch.device] = device): if disable_mmap: return safetensors.torch.load(open(path, "rb").read()) @@ -73,9 +71,9 @@ def load_state_dict(path: str, dvc: Union[str, torch.device] = device): return load_file(path) # prevent device invalid Error t5xxl_device = t5xxl_device or device - clip_dtype = clip_dtype or default_dtype or torch.float32 - t5xxl_dtype = t5xxl_dtype or default_dtype or torch.float32 - vae_dtype = vae_dtype or default_dtype or torch.float32 + clip_dtype = clip_dtype or weight_dtype or torch.float32 + t5xxl_dtype = t5xxl_dtype or weight_dtype or torch.float32 + vae_dtype = vae_dtype or weight_dtype or torch.float32 logger.info(f"Loading SD3 models from {ckpt_path}...") state_dict = load_state_dict(ckpt_path) @@ -157,7 +155,7 @@ def load_state_dict(path: str, dvc: Union[str, torch.device] = device): mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode) logger.info("Loading state dict...") - info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, default_dtype) + info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, weight_dtype) logger.info(f"Loaded MMDiT: {info}") # load ClipG and ClipL diff --git a/sd3_train.py b/sd3_train.py index bd30cdc72..de763ac6d 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -182,7 +182,7 @@ def train(args): raise ValueError(f"unexpected t5xxl_dtype: {args.t5xxl_dtype}") t5xxl_device = accelerator.device if args.t5xxl_device is None else args.t5xxl_device - clip_dtype = weight_dtype # if not args.train_text_encoder else None + clip_dtype = weight_dtype # if not args.train_text_encoder else None # モデルを読み込む attn_mode = "xformers" if args.xformers else "torch" @@ -193,7 +193,7 @@ def train(args): # models are usually loaded on CPU and moved to GPU later. This is to avoid OOM on GPU0. mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model( - args, accelerator, attn_mode, None, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype + args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype ) assert clip_l is not None, "clip_l is required / clip_lは必須です" assert clip_g is not None, "clip_g is required / clip_gは必須です" @@ -769,10 +769,10 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): epoch, num_train_epochs, global_step, - clip_l if args.save_clip else None, - clip_g if args.save_clip else None, - t5xxl if args.save_t5xxl else None, - mmdit, + accelerator.unwrap_model(clip_l) if args.save_clip else None, + accelerator.unwrap_model(clip_g) if args.save_clip else None, + accelerator.unwrap_model(t5xxl) if args.save_t5xxl else None, + accelerator.unwrap_model(mmdit), vae, ) @@ -807,10 +807,10 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): epoch, num_train_epochs, global_step, - clip_l if args.save_clip else None, - clip_g if args.save_clip else None, - t5xxl if args.save_t5xxl else None, - mmdit, + accelerator.unwrap_model(clip_l) if args.save_clip else None, + accelerator.unwrap_model(clip_g) if args.save_clip else None, + accelerator.unwrap_model(t5xxl) if args.save_t5xxl else None, + accelerator.unwrap_model(mmdit), vae, ) From 50e3d6247459c9f59facaef42e03b34cd8d6287d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 8 Jul 2024 19:46:23 +0900 Subject: [PATCH 037/348] fix to work T5XXL with fp16 --- library/sd3_models.py | 144 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 138 insertions(+), 6 deletions(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index 7041420cb..e4c0790d9 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -1124,7 +1124,12 @@ def __init__(self, in_channels, dtype=torch.float32, device=None): self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) def forward(self, x): + org_dtype = x.dtype + if x.dtype == torch.bfloat16: + x = x.to(torch.float32) x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if x.dtype != org_dtype: + x = x.to(org_dtype) x = self.conv(x) return x @@ -1263,11 +1268,11 @@ def device(self): def dtype(self): return next(self.parameters()).dtype - @torch.autocast("cuda", dtype=torch.float16) + # @torch.autocast("cuda", dtype=torch.float16) def decode(self, latent): return self.decoder(latent) - @torch.autocast("cuda", dtype=torch.float16) + # @torch.autocast("cuda", dtype=torch.float16) def encode(self, image): hidden = self.encoder(image) mean, logvar = torch.chunk(hidden, 2, dim=1) @@ -1630,10 +1635,25 @@ def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None): self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device)) self.variance_epsilon = eps - def forward(self, x): - variance = x.pow(2).mean(-1, keepdim=True) - x = x * torch.rsqrt(variance + self.variance_epsilon) - return self.weight.to(device=x.device, dtype=x.dtype) * x + # def forward(self, x): + # variance = x.pow(2).mean(-1, keepdim=True) + # x = x * torch.rsqrt(variance + self.variance_epsilon) + # return self.weight.to(device=x.device, dtype=x.dtype) * x + + # copy from transformers' T5LayerNorm + def forward(self, hidden_states): + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states class T5DenseGatedActDense(torch.nn.Module): @@ -1775,7 +1795,27 @@ def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_b def forward(self, x, past_bias=None): x, past_bias = self.layer[0](x, past_bias) + + # copy from transformers' T5Block + # clamp inf values to enable fp16 training + if x.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(x).any(), + torch.finfo(x.dtype).max - 1000, + torch.finfo(x.dtype).max, + ) + x = torch.clamp(x, min=-clamp_value, max=clamp_value) + x = self.layer[-1](x) + # clamp inf values to enable fp16 training + if x.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(x).any(), + torch.finfo(x.dtype).max - 1000, + torch.finfo(x.dtype).max, + ) + x = torch.clamp(x, min=-clamp_value, max=clamp_value) + return x, past_bias @@ -1896,4 +1936,96 @@ def create_t5xxl(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[st return t5 +""" + # snippet for using the T5 model from transformers + + from transformers import T5EncoderModel, T5Config + import accelerate + import json + + T5_CONFIG_JSON = "" +{ + "architectures": [ + "T5EncoderModel" + ], + "classifier_dropout": 0.0, + "d_ff": 10240, + "d_kv": 64, + "d_model": 4096, + "decoder_start_token_id": 0, + "dense_act_fn": "gelu_new", + "dropout_rate": 0.1, + "eos_token_id": 1, + "feed_forward_proj": "gated-gelu", + "initializer_factor": 1.0, + "is_encoder_decoder": true, + "is_gated_act": true, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "num_decoder_layers": 24, + "num_heads": 64, + "num_layers": 24, + "output_past": true, + "pad_token_id": 0, + "relative_attention_max_distance": 128, + "relative_attention_num_buckets": 32, + "tie_word_embeddings": false, + "torch_dtype": "float16", + "transformers_version": "4.41.2", + "use_cache": true, + "vocab_size": 32128 +} +"" + config = json.loads(T5_CONFIG_JSON) + config = T5Config(**config) + + # model = T5EncoderModel.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="text_encoder_3") + # print(model.config) + # # model(**load_model.config) + + # with accelerate.init_empty_weights(): + model = T5EncoderModel._from_config(config) # , torch_dtype=dtype) + for key in list(state_dict.keys()): + if key.startswith("transformer."): + new_key = key[len("transformer.") :] + state_dict[new_key] = state_dict.pop(key) + + info = model.load_state_dict(state_dict) + print(info) + model.set_attn_mode = lambda x: None + # model.to("cpu") + + _self = model + + def enc(list_of_token_weight_pairs): + has_batch = isinstance(list_of_token_weight_pairs[0][0], list) + + if has_batch: + list_of_tokens = [] + for pairs in list_of_token_weight_pairs: + tokens = [a[0] for a in pairs[0]] # I'm not sure why this is [0] + list_of_tokens.append(tokens) + else: + list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]] + + list_of_tokens = np.array(list_of_tokens) + list_of_tokens = torch.from_numpy(list_of_tokens).to("cuda", dtype=torch.long) + out = _self(list_of_tokens) + pooled = None + if has_batch: + return out, pooled + else: + if pooled is not None: + first_pooled = pooled[0:1] + else: + first_pooled = pooled + return out[0], first_pooled + # output = [out[0:1]] + # return torch.cat(output, dim=-2), first_pooled + + model.encode_token_weights = enc + + return model +""" + # endregion From c9de7c4e9a3d02ab6f18f105c880a9ba88b667ab Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 8 Jul 2024 19:48:28 +0900 Subject: [PATCH 038/348] WIP: new latents caching --- library/sd3_train_utils.py | 94 +++++++++++++++++++++++- library/train_util.py | 147 ++++++++++++++++++++++++++++++++++++- sd3_train.py | 37 +++++++++- 3 files changed, 270 insertions(+), 8 deletions(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index c8d52e1c8..9309ee30c 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -1,7 +1,7 @@ import argparse import math import os -from typing import Optional, Tuple +from typing import List, Optional, Tuple import torch from safetensors.torch import save_file @@ -283,6 +283,98 @@ def sample_images(*args, **kwargs): return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs) +class Sd3LatensCachingStrategy(train_util.LatentsCachingStrategy): + SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz" + + def __init__(self, vae: sd3_models.SDVAE, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + self.vae = vae + + def get_latents_npz_path(self, absolute_path: str): + return os.path.splitext(absolute_path)[0] + self.SD3_LATENTS_NPZ_SUFFIX + + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + if not self.cache_to_disk: + return False + if not os.path.exists(npz_path): + return False + if self.skip_disk_cache_validity_check: + return True + + expected_latents_size = (bucket_reso[1] // 8, bucket_reso[0] // 8) # bucket_reso is (W, H) + + try: + npz = np.load(npz_path) + if npz["latents"].shape[1:3] != expected_latents_size: + return False + + if flip_aug: + if "latents_flipped" not in npz: + return False + if npz["latents_flipped"].shape[1:3] != expected_latents_size: + return False + + if alpha_mask: + if "alpha_mask" not in npz: + return False + if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]): + return False + else: + if "alpha_mask" in npz: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + + return True + + def cache_batch_latents(self, image_infos: List[train_util.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool): + img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching( + image_infos, alpha_mask, random_crop + ) + img_tensor = img_tensor.to(device=self.vae.device, dtype=self.vae.dtype) + + with torch.no_grad(): + latents = self.vae.encode(img_tensor).to("cpu") + if flip_aug: + img_tensor = torch.flip(img_tensor, dims=[3]) + with torch.no_grad(): + flipped_latents = self.vae.encode(img_tensor).to("cpu") + else: + flipped_latents = [None] * len(latents) + + for info, latent, flipped_latent, alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks): + if self.cache_to_disk: + # save_latents_to_disk( + # info.latents_npz, + # latent, + # info.latents_original_size, + # info.latents_crop_ltrb, + # flipped_latent, + # alpha_mask, + # ) + kwargs = {} + if flipped_latent is not None: + kwargs["latents_flipped"] = flipped_latent.float().cpu().numpy() + if alpha_mask is not None: + kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() + np.savez( + info.latents_npz, + latents=latents.float().cpu().numpy(), + original_size=np.array(original_sizes), + crop_ltrb=np.array(crop_ltrbs), + **kwargs, + ) + else: + info.latents = latent + if flip_aug: + info.latents_flipped = flipped_latent + info.alpha_mask = alpha_mask + + if not train_util.HIGH_VRAM: + clean_memory_on_device(self.vae.device) + + # region Diffusers diff --git a/library/train_util.py b/library/train_util.py index 96d32e3bc..8444827df 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -359,6 +359,30 @@ def get_augmentor(self, use_color_aug: bool): # -> Optional[Callable[[np.ndarra return self.color_aug if use_color_aug else None +class LatentsCachingStrategy: + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + self._cache_to_disk = cache_to_disk + self._batch_size = batch_size + self.skip_disk_cache_validity_check = skip_disk_cache_validity_check + + @property + def cache_to_disk(self): + return self._cache_to_disk + + @property + def batch_size(self): + return self._batch_size + + def get_latents_npz_path(self, absolute_path: str): + raise NotImplementedError + + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + raise NotImplementedError + + def cache_batch_latents(self, batch: List[ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool): + raise NotImplementedError + + class BaseSubset: def __init__( self, @@ -986,6 +1010,69 @@ def is_text_encoder_output_cacheable(self): ] ) + def new_cache_latents(self, is_main_process: bool, caching_strategy: LatentsCachingStrategy): + r""" + a brand new method to cache latents. This method caches latents with caching strategy. + normal cache_latents method is used by default, but this method is used when caching strategy is specified. + """ + logger.info("caching latents with caching strategy.") + image_infos = list(self.image_data.values()) + + # sort by resolution + image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1]) + + # split by resolution + batches = [] + batch = [] + logger.info("checking cache validity...") + for info in tqdm(image_infos): + subset = self.image_to_subset[info.image_key] + + if info.latents_npz is not None: # fine tuning dataset + continue + + # check disk cache exists and size of latents + if caching_strategy.cache_to_disk: + # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix + info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path) + if not is_main_process: # prepare for multi-gpu, only store to info + continue + + cache_available = caching_strategy.is_disk_cached_latents_expected( + info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask + ) + if cache_available: # do not add to batch + continue + + # if last member of batch has different resolution, flush the batch + if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso: + batches.append(batch) + batch = [] + + batch.append(info) + + # if number of data in batch is enough, flush the batch + if len(batch) >= caching_strategy.batch_size: + batches.append(batch) + batch = [] + + if len(batch) > 0: + batches.append(batch) + + # if cache to disk, don't cache latents in non-main process, set to info only + if caching_strategy.cache_to_disk and not is_main_process: + return + + if len(batches) == 0: + logger.info("no latents to cache") + return + + # iterate batches: batch doesn't have image here. image will be loaded in cache_batch_latents and discarded + logger.info("caching latents...") + for batch in tqdm(batches, smoothing=1, total=len(batches)): + # cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + caching_strategy.cache_batch_latents(batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと logger.info("caching latents.") @@ -1086,7 +1173,7 @@ def cache_text_encoder_outputs_common( if batch_size is None: batch_size = self.batch_size - + image_infos = list(self.image_data.values()) logger.info("checking cache existence...") @@ -2207,6 +2294,11 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc logger.info(f"[Dataset {i}]") dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix) + def new_cache_latents(self, is_main_process: bool, strategy: LatentsCachingStrategy): + for i, dataset in enumerate(self.datasets): + logger.info(f"[Dataset {i}]") + dataset.new_cache_latents(is_main_process, strategy) + def cache_text_encoder_outputs( self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True ): @@ -2550,6 +2642,51 @@ def trim_and_resize_if_required( return image, original_size, crop_ltrb +# for new_cache_latents +def load_images_and_masks_for_caching( + image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool +) -> Tuple[torch.Tensor, List[np.ndarray], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]: + r""" + requires image_infos to have: [absolute_path or image], bucket_reso, resized_size + + returns: image_tensor, alpha_masks, original_sizes, crop_ltrbs + + image_tensor: torch.Tensor = torch.Size([B, 3, H, W]), ...], normalized to [-1, 1] + alpha_masks: List[np.ndarray] = [np.ndarray([H, W]), ...], normalized to [0, 1] + original_sizes: List[Tuple[int, int]] = [(W, H), ...] + crop_ltrbs: List[Tuple[int, int, int, int]] = [(L, T, R, B), ...] + """ + images: List[torch.Tensor] = [] + alpha_masks: List[np.ndarray] = [] + original_sizes: List[Tuple[int, int]] = [] + crop_ltrbs: List[Tuple[int, int, int, int]] = [] + for info in image_infos: + image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) + # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 + image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) + + original_sizes.append(original_size) + crop_ltrbs.append(crop_ltrb) + + if use_alpha_mask: + if image.shape[2] == 4: + alpha_mask = image[:, :, 3] # [H,W] + alpha_mask = alpha_mask.astype(np.float32) / 255.0 + alpha_mask = torch.FloatTensor(alpha_mask) # [H,W] + else: + alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W] + else: + alpha_mask = None + alpha_masks.append(alpha_mask) + + image = image[:, :, :3] # remove alpha channel if exists + image = IMAGE_TRANSFORMS(image) + images.append(image) + + img_tensor = torch.stack(images, dim=0) + return img_tensor, alpha_masks, original_sizes, crop_ltrbs + + def cache_batch_latents( vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, use_alpha_mask: bool, random_crop: bool ) -> None: @@ -2661,7 +2798,7 @@ def cache_batch_text_encoder_outputs_sd3( ): # make input_ids for each text encoder l_tokens, g_tokens, t5_tokens = input_ids - + clip_l, clip_g, t5xxl = text_encoders with torch.no_grad(): b_lg_out, b_t5_out, b_pool = sd3_utils.get_cond_from_tokens( @@ -2670,8 +2807,12 @@ def cache_batch_text_encoder_outputs_sd3( b_lg_out = b_lg_out.detach() b_t5_out = b_t5_out.detach() b_pool = b_pool.detach() - + for info, lg_out, t5_out, pool in zip(image_infos, b_lg_out, b_t5_out, b_pool): + # debug: NaN check + if torch.isnan(lg_out).any() or torch.isnan(t5_out).any() or torch.isnan(pool).any(): + raise RuntimeError(f"NaN detected in text encoder outputs: {info.absolute_path}") + if cache_to_disk: save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, lg_out, t5_out, pool) else: diff --git a/sd3_train.py b/sd3_train.py index de763ac6d..c073ec0e2 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -204,11 +204,22 @@ def train(args): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - vae_wrapper = sd3_models.VAEWrapper(vae) # make SD/SDXL compatible - with torch.no_grad(): - train_dataset_group.cache_latents( - vae_wrapper, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process, file_suffix="_sd3.npz" + + if not args.new_caching: + vae_wrapper = sd3_models.VAEWrapper(vae) # make SD/SDXL compatible + with torch.no_grad(): + train_dataset_group.cache_latents( + vae_wrapper, + args.vae_batch_size, + args.cache_latents_to_disk, + accelerator.is_main_process, + file_suffix="_sd3.npz", + ) + else: + strategy = sd3_train_utils.Sd3LatensCachingStrategy( + vae, args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check ) + train_dataset_group.new_cache_latents(accelerator.is_main_process, strategy) vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -699,6 +710,17 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype) noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + # debug: NaN check for all inputs + if torch.any(torch.isnan(noisy_model_input)): + accelerator.print("NaN found in noisy_model_input, replacing with zeros") + noisy_model_input = torch.nan_to_num(noisy_model_input, 0, out=noisy_model_input) + if torch.any(torch.isnan(context)): + accelerator.print("NaN found in context, replacing with zeros") + context = torch.nan_to_num(context, 0, out=context) + if torch.any(torch.isnan(pool)): + accelerator.print("NaN found in pool, replacing with zeros") + pool = torch.nan_to_num(pool, 0, out=pool) + # call model with accelerator.autocast(): model_pred = mmdit(noisy_model_input, timesteps, context=context, y=pool) @@ -908,6 +930,13 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", ) + + parser.add_argument("--new_caching", action="store_true", help="use new caching method / 新しいキャッシング方法を使う") + parser.add_argument( + "--skip_latents_validity_check", + action="store_true", + help="skip latents validity check / latentsの正当性チェックをスキップする", + ) return parser From 3ea4fce5e0f3d1a9c2718d77f49c3b304d25e565 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 8 Jul 2024 22:04:43 +0900 Subject: [PATCH 039/348] load models one by one --- library/sd3_train_utils.py | 56 ++++++------ library/sd3_utils.py | 169 +++++++++++++++++++++++++++++++++++++ sd3_train.py | 58 +++++++++---- 3 files changed, 236 insertions(+), 47 deletions(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 9309ee30c..98ee66bf8 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -1,19 +1,17 @@ import argparse import math import os -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import torch from safetensors.torch import save_file +from accelerate import Accelerator from library import sd3_models, sd3_utils, train_util from library.device_utils import init_ipex, clean_memory_on_device init_ipex() -from accelerate import init_empty_weights -from tqdm import tqdm - # from transformers import CLIPTokenizer # from library import model_util # , sdxl_model_util, train_util, sdxl_original_unet @@ -28,50 +26,48 @@ from .sdxl_train_util import match_mixed_precision -def load_target_model(args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype) -> Tuple[ +def load_target_model( + model_type: str, + args: argparse.Namespace, + state_dict: dict, + accelerator: Accelerator, + attn_mode: str, + model_dtype: Optional[torch.dtype], + device: Optional[torch.device], +) -> Union[ sd3_models.MMDiT, Optional[sd3_models.SDClipModel], Optional[sd3_models.SDXLClipG], Optional[sd3_models.T5XXLModel], sd3_models.SDVAE, ]: - model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16, None or fp16/bf16 + loading_device = device if device is not None else (accelerator.device if args.lowram else "cpu") for pi in range(accelerator.state.num_processes): if pi == accelerator.state.local_process_index: logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") - mmdit, clip_l, clip_g, t5xxl, vae = sd3_utils.load_models( - args.pretrained_model_name_or_path, - args.clip_l, - args.clip_g, - args.t5xxl, - args.vae, - attn_mode, - accelerator.device if args.lowram else "cpu", - model_dtype, - args.disable_mmap_load_safetensors, - clip_dtype, - t5xxl_device, - t5xxl_dtype, - vae_dtype, - ) + if model_type == "mmdit": + model = sd3_utils.load_mmdit(state_dict, attn_mode, model_dtype, loading_device) + elif model_type == "clip_l": + model = sd3_utils.load_clip_l(state_dict, args.clip_l, attn_mode, model_dtype, loading_device) + elif model_type == "clip_g": + model = sd3_utils.load_clip_g(state_dict, args.clip_g, attn_mode, model_dtype, loading_device) + elif model_type == "t5xxl": + model = sd3_utils.load_t5xxl(state_dict, args.t5xxl, attn_mode, model_dtype, loading_device) + elif model_type == "vae": + model = sd3_utils.load_vae(state_dict, args.vae, model_dtype, loading_device) + else: + raise ValueError(f"Unknown model type: {model_type}") # work on low-ram device: models are already loaded on accelerator.device, but we ensure they are on device if args.lowram: - if clip_l is not None: - clip_l.to(accelerator.device) - if clip_g is not None: - clip_g.to(accelerator.device) - if t5xxl is not None: - t5xxl.to(accelerator.device) - vae.to(accelerator.device) - mmdit.to(accelerator.device) + model = model.to(accelerator.device) clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() - return mmdit, clip_l, clip_g, t5xxl, vae + return model def save_models( diff --git a/library/sd3_utils.py b/library/sd3_utils.py index 9dc9e7967..16f80c60d 100644 --- a/library/sd3_utils.py +++ b/library/sd3_utils.py @@ -20,6 +20,175 @@ # region models +def load_safetensors(path: str, dvc: Union[str, torch.device], disable_mmap: bool = False): + if disable_mmap: + return safetensors.torch.load(open(path, "rb").read()) + else: + try: + return load_file(path, device=dvc) + except: + return load_file(path) # prevent device invalid Error + + +def load_mmdit(state_dict: Dict, attn_mode: str, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device]): + mmdit_sd = {} + + mmdit_prefix = "model.diffusion_model." + for k in list(state_dict.keys()): + if k.startswith(mmdit_prefix): + mmdit_sd[k[len(mmdit_prefix) :]] = state_dict.pop(k) + + # load MMDiT + logger.info("Building MMDit") + with init_empty_weights(): + mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode) + + logger.info("Loading state dict...") + info = sdxl_model_util._load_state_dict_on_device(mmdit, mmdit_sd, device, dtype) + logger.info(f"Loaded MMDiT: {info}") + return mmdit + + +def load_clip_l( + state_dict: Dict, + clip_l_path: Optional[str], + attn_mode: str, + clip_dtype: Optional[Union[str, torch.dtype]], + device: Union[str, torch.device], + disable_mmap: bool = False, +): + clip_l_sd = None + if clip_l_path: + logger.info(f"Loading clip_l from {clip_l_path}...") + clip_l_sd = load_safetensors(clip_l_path, device, disable_mmap) + for key in list(clip_l_sd.keys()): + clip_l_sd["transformer." + key] = clip_l_sd.pop(key) + else: + if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_l: remove prefix "text_encoders.clip_l." + logger.info("clip_l is included in the checkpoint") + clip_l_sd = {} + prefix = "text_encoders.clip_l." + for k in list(state_dict.keys()): + if k.startswith(prefix): + clip_l_sd[k[len(prefix) :]] = state_dict.pop(k) + + if clip_l_sd is None: + clip_l = None + else: + logger.info("Building ClipL") + clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd) + logger.info("Loading state dict...") + info = clip_l.load_state_dict(clip_l_sd) + logger.info(f"Loaded ClipL: {info}") + clip_l.set_attn_mode(attn_mode) + return clip_l + + +def load_clip_g( + state_dict: Dict, + clip_g_path: Optional[str], + attn_mode: str, + clip_dtype: Optional[Union[str, torch.dtype]], + device: Union[str, torch.device], + disable_mmap: bool = False, +): + clip_g_sd = None + if clip_g_path: + logger.info(f"Loading clip_g from {clip_g_path}...") + clip_g_sd = load_safetensors(clip_g_path, device, disable_mmap) + for key in list(clip_g_sd.keys()): + clip_g_sd["transformer." + key] = clip_g_sd.pop(key) + else: + if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_g: remove prefix "text_encoders.clip_g." + logger.info("clip_g is included in the checkpoint") + clip_g_sd = {} + prefix = "text_encoders.clip_g." + for k in list(state_dict.keys()): + if k.startswith(prefix): + clip_g_sd[k[len(prefix) :]] = state_dict.pop(k) + + if clip_g_sd is None: + clip_g = None + else: + logger.info("Building ClipG") + clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd) + logger.info("Loading state dict...") + info = clip_g.load_state_dict(clip_g_sd) + logger.info(f"Loaded ClipG: {info}") + clip_g.set_attn_mode(attn_mode) + return clip_g + + +def load_t5xxl( + state_dict: Dict, + t5xxl_path: Optional[str], + attn_mode: str, + dtype: Optional[Union[str, torch.dtype]], + device: Union[str, torch.device], + disable_mmap: bool = False, +): + t5xxl_sd = None + if t5xxl_path: + logger.info(f"Loading t5xxl from {t5xxl_path}...") + t5xxl_sd = load_safetensors(t5xxl_path, device, disable_mmap) + for key in list(t5xxl_sd.keys()): + t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key) + else: + if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict: + # found t5xxl: remove prefix "text_encoders.t5xxl." + logger.info("t5xxl is included in the checkpoint") + t5xxl_sd = {} + prefix = "text_encoders.t5xxl." + for k in list(state_dict.keys()): + if k.startswith(prefix): + t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k) + + if t5xxl_sd is None: + t5xxl = None + else: + logger.info("Building T5XXL") + + # workaround for T5XXL model creation: create with fp16 takes too long TODO support virtual device + t5xxl = sd3_models.create_t5xxl(device, torch.float32, t5xxl_sd) + t5xxl.to(dtype=dtype) + + logger.info("Loading state dict...") + info = t5xxl.load_state_dict(t5xxl_sd) + logger.info(f"Loaded T5XXL: {info}") + t5xxl.set_attn_mode(attn_mode) + return t5xxl + + +def load_vae( + state_dict: Dict, + vae_path: Optional[str], + vae_dtype: Optional[Union[str, torch.dtype]], + device: Optional[Union[str, torch.device]], + disable_mmap: bool = False, +): + vae_sd = {} + if vae_path: + logger.info(f"Loading VAE from {vae_path}...") + vae_sd = load_safetensors(vae_path, device, disable_mmap) + else: + # remove prefix "first_stage_model." + vae_sd = {} + vae_prefix = "first_stage_model." + for k in list(state_dict.keys()): + if k.startswith(vae_prefix): + vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k) + + logger.info("Building VAE") + vae = sd3_models.SDVAE() + logger.info("Loading state dict...") + info = vae.load_state_dict(vae_sd) + logger.info(f"Loaded VAE: {info}") + vae.to(device=device, dtype=vae_dtype) + return vae + + def load_models( ckpt_path: str, clip_l_path: str, diff --git a/sd3_train.py b/sd3_train.py index c073ec0e2..10cc5d57f 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -13,12 +13,12 @@ import torch from library.device_utils import init_ipex, clean_memory_on_device - init_ipex() from accelerate.utils import set_seed from diffusers import DDPMScheduler from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils +from library.sdxl_train_util import match_mixed_precision # , sdxl_model_util @@ -189,18 +189,19 @@ def train(args): assert ( attn_mode == "torch" - ), f"attn_mode {attn_mode} is not supported. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" + ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" - # models are usually loaded on CPU and moved to GPU later. This is to avoid OOM on GPU0. - mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model( - args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype + # SD3 state dict may contain multiple models, so we need to load it and extract one by one. annoying. + logger.info(f"Loading SD3 models from {args.pretrained_model_name_or_path}") + device_to_load = accelerator.device if args.lowram else "cpu" + sd3_state_dict = sd3_utils.load_safetensors( + args.pretrained_model_name_or_path, device_to_load, args.disable_mmap_load_safetensors ) - assert clip_l is not None, "clip_l is required / clip_lは必須です" - assert clip_g is not None, "clip_g is required / clip_gは必須です" - # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype) - # 学習を準備する + # load VAE for caching latents + vae: sd3_models.SDVAE = None if cache_latents: + vae = sd3_train_utils.load_target_model("vae", args, sd3_state_dict, accelerator, attn_mode, vae_dtype, device_to_load) vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() @@ -220,15 +221,25 @@ def train(args): vae, args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check ) train_dataset_group.new_cache_latents(accelerator.is_main_process, strategy) - vae.to("cpu") + vae.to("cpu") # if no sampling, vae can be deleted clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() + # load clip_l, clip_g, t5xxl for caching text encoder outputs + # # models are usually loaded on CPU and moved to GPU later. This is to avoid OOM on GPU0. + # mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model( + # args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype + # ) + clip_l = sd3_train_utils.load_target_model("clip_l", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load) + clip_g = sd3_train_utils.load_target_model("clip_g", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load) + assert clip_l is not None, "clip_l is required / clip_lは必須です" + assert clip_g is not None, "clip_g is required / clip_gは必須です" + + t5xxl = sd3_train_utils.load_target_model("t5xxl", args, sd3_state_dict, accelerator, attn_mode, t5xxl_dtype, device_to_load) + # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype) + # 学習を準備する:モデルを適切な状態にする - if args.gradient_checkpointing: - mmdit.enable_gradient_checkpointing() - train_mmdit = args.learning_rate != 0 train_clip_l = False train_clip_g = False train_t5xxl = False @@ -280,17 +291,30 @@ def train(args): accelerator.is_main_process, args.text_encoder_batch_size, ) + + # TODO we can delete text encoders after caching accelerator.wait_for_everyone() + # load MMDIT + # if full_fp16/bf16, model_dtype is casted to fp16/bf16. If not, model_dtype is None (float32). + # by loading with model_dtype, we can reduce memory usage. + model_dtype = match_mixed_precision(args, weight_dtype) # None (default) or fp16/bf16 (full_xxxx) + mmdit = sd3_train_utils.load_target_model("mmdit", args, sd3_state_dict, accelerator, attn_mode, model_dtype, device_to_load) + if args.gradient_checkpointing: + mmdit.enable_gradient_checkpointing() + + train_mmdit = args.learning_rate != 0 + mmdit.requires_grad_(train_mmdit) + if not train_mmdit: + mmdit.to(accelerator.device, dtype=weight_dtype) # because of mmdie will not be prepared + if not cache_latents: + # load VAE here if not cached + vae = sd3_train_utils.load_target_model("vae", args, sd3_state_dict, accelerator, attn_mode, vae_dtype, device_to_load) vae.requires_grad_(False) vae.eval() vae.to(accelerator.device, dtype=vae_dtype) - mmdit.requires_grad_(train_mmdit) - if not train_mmdit: - mmdit.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared - training_models = [] params_to_optimize = [] # if train_unet: From 9dc7997803d70c718969526352e88908e827f091 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 9 Jul 2024 20:37:00 +0900 Subject: [PATCH 040/348] fix typo --- library/sd3_models.py | 2 +- library/sd3_train_utils.py | 2 +- sd3_train.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index e4c0790d9..a1ff1e75a 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -1643,7 +1643,7 @@ def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None): # copy from transformers' T5LayerNorm def forward(self, hidden_states): # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean - # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for # half-precision inputs is done in fp32 variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 98ee66bf8..660342108 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -279,7 +279,7 @@ def sample_images(*args, **kwargs): return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs) -class Sd3LatensCachingStrategy(train_util.LatentsCachingStrategy): +class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy): SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz" def __init__(self, vae: sd3_models.SDVAE, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: diff --git a/sd3_train.py b/sd3_train.py index 10cc5d57f..30d994c78 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -217,7 +217,7 @@ def train(args): file_suffix="_sd3.npz", ) else: - strategy = sd3_train_utils.Sd3LatensCachingStrategy( + strategy = sd3_train_utils.Sd3LatentsCachingStrategy( vae, args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check ) train_dataset_group.new_cache_latents(accelerator.is_main_process, strategy) From 3d402927efb2d396f8f33fe6a1747e43f7a5f0f3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 9 Jul 2024 23:15:38 +0900 Subject: [PATCH 041/348] WIP: update new latents caching --- library/sd3_train_utils.py | 49 +++++++++++++++++++++++++------------- library/train_util.py | 39 ++++++++++++++++++++++++++---- sd3_train.py | 15 ++++++++---- 3 files changed, 77 insertions(+), 26 deletions(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 660342108..245912199 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -1,4 +1,5 @@ import argparse +import glob import math import os from typing import List, Optional, Tuple, Union @@ -282,12 +283,26 @@ def sample_images(*args, **kwargs): class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy): SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz" - def __init__(self, vae: sd3_models.SDVAE, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + self.vae = None + + def set_vae(self, vae: sd3_models.SDVAE): self.vae = vae - def get_latents_npz_path(self, absolute_path: str): - return os.path.splitext(absolute_path)[0] + self.SD3_LATENTS_NPZ_SUFFIX + def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: + npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX) + if len(npz_file) == 0: + return None, None + w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x") + return int(w), int(h) + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + return ( + os.path.splitext(absolute_path)[0] + + f"_{image_size[0]:04d}x{image_size[1]:04d}" + + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX + ) def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): if not self.cache_to_disk: @@ -331,24 +346,24 @@ def cache_batch_latents(self, image_infos: List[train_util.ImageInfo], flip_aug: img_tensor = img_tensor.to(device=self.vae.device, dtype=self.vae.dtype) with torch.no_grad(): - latents = self.vae.encode(img_tensor).to("cpu") + latents_tensors = self.vae.encode(img_tensor).to("cpu") if flip_aug: img_tensor = torch.flip(img_tensor, dims=[3]) with torch.no_grad(): flipped_latents = self.vae.encode(img_tensor).to("cpu") else: - flipped_latents = [None] * len(latents) + flipped_latents = [None] * len(latents_tensors) + + # for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks): + for i in range(len(image_infos)): + info = image_infos[i] + latents = latents_tensors[i] + flipped_latent = flipped_latents[i] + alpha_mask = alpha_masks[i] + original_size = original_sizes[i] + crop_ltrb = crop_ltrbs[i] - for info, latent, flipped_latent, alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks): if self.cache_to_disk: - # save_latents_to_disk( - # info.latents_npz, - # latent, - # info.latents_original_size, - # info.latents_crop_ltrb, - # flipped_latent, - # alpha_mask, - # ) kwargs = {} if flipped_latent is not None: kwargs["latents_flipped"] = flipped_latent.float().cpu().numpy() @@ -357,12 +372,12 @@ def cache_batch_latents(self, image_infos: List[train_util.ImageInfo], flip_aug: np.savez( info.latents_npz, latents=latents.float().cpu().numpy(), - original_size=np.array(original_sizes), - crop_ltrb=np.array(crop_ltrbs), + original_size=np.array(original_size), + crop_ltrb=np.array(crop_ltrb), **kwargs, ) else: - info.latents = latent + info.latents = latents if flip_aug: info.latents_flipped = flipped_latent info.alpha_mask = alpha_mask diff --git a/library/train_util.py b/library/train_util.py index 8444827df..9db226ea8 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -360,11 +360,23 @@ def get_augmentor(self, use_color_aug: bool): # -> Optional[Callable[[np.ndarra class LatentsCachingStrategy: + _strategy = None # strategy instance: actual strategy class + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: self._cache_to_disk = cache_to_disk self._batch_size = batch_size self.skip_disk_cache_validity_check = skip_disk_cache_validity_check + @classmethod + def set_strategy(cls, strategy): + if cls._strategy is not None: + raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["LatentsCachingStrategy"]: + return cls._strategy + @property def cache_to_disk(self): return self._cache_to_disk @@ -373,10 +385,15 @@ def cache_to_disk(self): def batch_size(self): return self._batch_size - def get_latents_npz_path(self, absolute_path: str): + def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: + raise NotImplementedError + + def get_latents_npz_path(self, absolute_path: str, bucket_reso: Tuple[int, int]) -> str: raise NotImplementedError - def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + def is_disk_cached_latents_expected( + self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool + ) -> bool: raise NotImplementedError def cache_batch_latents(self, batch: List[ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool): @@ -1034,7 +1051,7 @@ def new_cache_latents(self, is_main_process: bool, caching_strategy: LatentsCach # check disk cache exists and size of latents if caching_strategy.cache_to_disk: # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix - info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path) + info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size) if not is_main_process: # prepare for multi-gpu, only store to info continue @@ -1730,6 +1747,18 @@ def load_dreambooth_dir(subset: DreamBoothSubset): img_paths = glob_images(subset.image_dir, "*") sizes = [None] * len(img_paths) + # new caching: get image size from cache files + strategy = LatentsCachingStrategy.get_strategy() + if strategy is not None: + logger.info("get image size from cache files") + size_set_count = 0 + for i, img_path in enumerate(tqdm(img_paths)): + w, h = strategy.get_image_size_from_image_absolute_path(img_path) + if w is not None and h is not None: + sizes[i] = [w, h] + size_set_count += 1 + logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}") + logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") if use_cached_info_for_subset: @@ -2807,12 +2836,12 @@ def cache_batch_text_encoder_outputs_sd3( b_lg_out = b_lg_out.detach() b_t5_out = b_t5_out.detach() b_pool = b_pool.detach() - + for info, lg_out, t5_out, pool in zip(image_infos, b_lg_out, b_t5_out, b_pool): # debug: NaN check if torch.isnan(lg_out).any() or torch.isnan(t5_out).any() or torch.isnan(pool).any(): raise RuntimeError(f"NaN detected in text encoder outputs: {info.absolute_path}") - + if cache_to_disk: save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, lg_out, t5_out, pool) else: diff --git a/sd3_train.py b/sd3_train.py index 30d994c78..e2f622e47 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -91,6 +91,15 @@ def train(args): # load tokenizer sd3_tokenizer = sd3_models.SD3Tokenizer() + # prepare caching strategy + if args.new_caching: + latents_caching_strategy = sd3_train_utils.Sd3LatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check + ) + else: + latents_caching_strategy = None + train_util.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + # データセットを準備する if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) @@ -217,10 +226,8 @@ def train(args): file_suffix="_sd3.npz", ) else: - strategy = sd3_train_utils.Sd3LatentsCachingStrategy( - vae, args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check - ) - train_dataset_group.new_cache_latents(accelerator.is_main_process, strategy) + latents_caching_strategy.set_vae(vae) + train_dataset_group.new_cache_latents(accelerator.is_main_process, latents_caching_strategy) vae.to("cpu") # if no sampling, vae can be deleted clean_memory_on_device(accelerator.device) From 6f0e235f2cb9a9829bc12280c29e12c0ae66c88f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 11 Jul 2024 08:00:45 +0900 Subject: [PATCH 042/348] Fix shift value in SD3 inference. --- sd3_minimal_inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py index 7f5f28cea..ffa0d46de 100644 --- a/sd3_minimal_inference.py +++ b/sd3_minimal_inference.py @@ -64,7 +64,7 @@ def do_sample( device: str, ): if initial_latent is None: - # latent = torch.ones(1, 16, height // 8, width // 8, device=device) * 0.0609 + # latent = torch.ones(1, 16, height // 8, width // 8, device=device) * 0.0609 # this seems to be a bug in the original code. thanks to furusu for pointing it out latent = torch.zeros(1, 16, height // 8, width // 8, device=device) else: latent = initial_latent @@ -73,7 +73,7 @@ def do_sample( noise = get_noise(seed, latent).to(device) - model_sampling = sd3_utils.ModelSamplingDiscreteFlow() + model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3 sigmas = get_sigmas(model_sampling, steps).to(device) # sigmas = sigmas[int(steps * (1 - denoise)) :] # do not support i2i From b8896aad400222c8c4441b217fda0f9bb0807ffd Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 11 Jul 2024 08:01:23 +0900 Subject: [PATCH 043/348] update README --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 3eed636c5..5d4f9621d 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,9 @@ This repository contains training, generation and utility scripts for Stable Dif SD3 training is done with `sd3_train.py`. -__Jun 29, 2024__: Fixed mixed precision training with fp16 is not working. Fixed the model is in bf16 dtype even without `--full_bf16` option (this could worsen the training result). +__Jul 11, 2024__: Fixed to work t5xxl with `fp16`. If you change the dtype to `fp16` for t5xxl, please remove existing latents cache files (`*_sd3.npz`). The shift in `sd3_minimum_inference.py` is fixed to 3.0. Thanks to araleza! + +Jun 29, 2024: Fixed mixed precision training with fp16 is not working. Fixed the model is in bf16 dtype even without `--full_bf16` option (this could worsen the training result). `fp16` and `bf16` are available for mixed precision training. We are not sure which is better. @@ -12,7 +14,7 @@ __Jun 29, 2024__: Fixed mixed precision training with fp16 is not working. Fixed `clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them. -t5xxl doesn't seem to work with `fp16`, so 1) use`bf16` for mixed precision, or 2) use `bf16` or `float32` for `t5xxl_dtype`. +~~t5xxl doesn't seem to work with `fp16`, so 1) use`bf16` for mixed precision, or 2) use `bf16` or `float32` for `t5xxl_dtype`. ~~ t5xxl works with `fp16` now. There are `t5xxl_device` and `t5xxl_dtype` options for `t5xxl` device and dtype. From 082f13658bdbaed872ede6c0a7a75ab1a5f3712d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 12 Jul 2024 21:28:01 +0900 Subject: [PATCH 044/348] reduce peak GPU memory usage before training --- library/sd3_models.py | 2 +- library/train_util.py | 1 + sd3_train.py | 44 +++++++++++++++++++++---------------------- 3 files changed, 24 insertions(+), 23 deletions(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index a1ff1e75a..ec8e1bbdd 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -471,7 +471,7 @@ def __init__( num_heads: int = 8, qkv_bias: bool = False, pre_only: bool = False, - qk_norm: str = None, + qk_norm: Optional[str] = None, ): super().__init__() self.num_heads = num_heads diff --git a/library/train_util.py b/library/train_util.py index 9db226ea8..7af0070e1 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2410,6 +2410,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) +# TODO update to use CachingStrategy def load_latents_from_disk( npz_path, ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: diff --git a/sd3_train.py b/sd3_train.py index e2f622e47..f34e47124 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -458,6 +458,28 @@ def train(args): # text_encoder1.text_model.encoder.layers[-1].requires_grad_(False) # text_encoder1.text_model.final_layer_norm.requires_grad_(False) + # TextEncoderの出力をキャッシュするときには、すでに出力を取得済みなのでCPUへ移動する + if args.cache_text_encoder_outputs: + # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 + clip_l.to("cpu", dtype=torch.float32) + clip_g.to("cpu", dtype=torch.float32) + if t5xxl is not None: + t5xxl.to("cpu", dtype=torch.float32) + clean_memory_on_device(accelerator.device) + else: + # make sure Text Encoders are on GPU + # TODO support CPU for text encoders + clip_l.to(accelerator.device) + clip_g.to(accelerator.device) + if t5xxl is not None: + t5xxl.to(accelerator.device) + + # TODO cache sample prompt's embeddings to free text encoder's memory + if args.cache_text_encoder_outputs: + if not args.save_t5xxl: + t5xxl = None # free memory + clean_memory_on_device(accelerator.device) + if args.deepspeed: ds_model = deepspeed_utils.prepare_deepspeed_model( args, @@ -482,28 +504,6 @@ def train(args): # text_encoder2 = accelerator.prepare(text_encoder2) optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) - # TextEncoderの出力をキャッシュするときには、すでに出力を取得済みなのでCPUへ移動する - if args.cache_text_encoder_outputs: - # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 - clip_l.to("cpu", dtype=torch.float32) - clip_g.to("cpu", dtype=torch.float32) - if t5xxl is not None: - t5xxl.to("cpu", dtype=torch.float32) - clean_memory_on_device(accelerator.device) - else: - # make sure Text Encoders are on GPU - # TODO support CPU for text encoders - clip_l.to(accelerator.device) - clip_g.to(accelerator.device) - if t5xxl is not None: - t5xxl.to(accelerator.device) - - # TODO cache sample prompt's embeddings to free text encoder's memory - if args.cache_text_encoder_outputs: - if not args.save_t5xxl: - t5xxl = None # free memory - clean_memory_on_device(accelerator.device) - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする if args.full_fp16: # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. From 41dee60383a3b88859b80929a2c0d94b12c42068 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 27 Jul 2024 13:50:05 +0900 Subject: [PATCH 045/348] Refactor caching mechanism for latents and text encoder outputs, etc. --- README.md | 21 +- fine_tune.py | 54 +++- library/config_util.py | 2 - library/sd3_models.py | 47 +++- library/sd3_train_utils.py | 105 ------- library/sd3_utils.py | 1 + library/sdxl_train_util.py | 2 +- library/strategy_base.py | 328 ++++++++++++++++++++++ library/strategy_sd.py | 139 ++++++++++ library/strategy_sd3.py | 229 ++++++++++++++++ library/strategy_sdxl.py | 247 +++++++++++++++++ library/train_util.py | 451 +++++++++++++++---------------- sd3_minimal_inference.py | 22 +- sd3_train.py | 272 +++++++++++-------- sdxl_train.py | 108 ++++---- sdxl_train_control_net_lllite.py | 99 ++++--- sdxl_train_network.py | 48 +++- sdxl_train_textual_inversion.py | 49 ++-- train_db.py | 67 +++-- train_network.py | 122 ++++++--- train_textual_inversion.py | 118 ++++---- 21 files changed, 1792 insertions(+), 739 deletions(-) create mode 100644 library/strategy_base.py create mode 100644 library/strategy_sd.py create mode 100644 library/strategy_sd3.py create mode 100644 library/strategy_sdxl.py diff --git a/README.md b/README.md index 5d4f9621d..d406fecde 100644 --- a/README.md +++ b/README.md @@ -4,9 +4,16 @@ This repository contains training, generation and utility scripts for Stable Dif SD3 training is done with `sd3_train.py`. -__Jul 11, 2024__: Fixed to work t5xxl with `fp16`. If you change the dtype to `fp16` for t5xxl, please remove existing latents cache files (`*_sd3.npz`). The shift in `sd3_minimum_inference.py` is fixed to 3.0. Thanks to araleza! +__Jul 27, 2024__: +- Latents and text encoder outputs caching mechanism is refactored significantly. + - Existing cache files for SD3 need to be recreated. Please delete the previous cache files. + - With this change, dataset initialization is significantly faster, especially for large datasets. -Jun 29, 2024: Fixed mixed precision training with fp16 is not working. Fixed the model is in bf16 dtype even without `--full_bf16` option (this could worsen the training result). +- Architecture-dependent parts are extracted from the dataset (`train_util.py`). This is expected to make it easier to add future architectures. + +- Architecture-dependent parts including the cache mechanism for SD1/2/SDXL are also extracted. The basic operation of SD1/2/SDXL training on the sd3 branch has been confirmed, but there may be bugs. Please use the main or dev branch for SD1/2/SDXL training. + +--- `fp16` and `bf16` are available for mixed precision training. We are not sure which is better. @@ -14,7 +21,7 @@ Jun 29, 2024: Fixed mixed precision training with fp16 is not working. Fixed the `clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them. -~~t5xxl doesn't seem to work with `fp16`, so 1) use`bf16` for mixed precision, or 2) use `bf16` or `float32` for `t5xxl_dtype`. ~~ t5xxl works with `fp16` now. +t5xxl works with `fp16` now. There are `t5xxl_device` and `t5xxl_dtype` options for `t5xxl` device and dtype. @@ -32,6 +39,14 @@ cache_latents = true cache_latents_to_disk = true ``` +__2024/7/27:__ + +Latents およびテキストエンコーダ出力のキャッシュの仕組みを大きくリファクタリングしました。SD3 用の既存のキャッシュファイルの再作成が必要になりますが、ご了承ください(以前のキャッシュファイルは削除してください)。これにより、特にデータセットの規模が大きい場合のデータセット初期化が大幅に高速化されます。 + +データセット (`train_util.py`) からアーキテクチャ依存の部分を切り出しました。これにより将来的なアーキテクチャ追加が容易になると期待しています。 + +SD1/2/SDXL のキャッシュ機構を含むアーキテクチャ依存の部分も切り出しました。sd3 ブランチの SD1/2/SDXL 学習について、基本的な動作は確認していますが、不具合があるかもしれません。SD1/2/SDXL の学習には main または dev ブランチをお使いください。 + --- [__Change History__](#change-history) is moved to the bottom of the page. diff --git a/fine_tune.py b/fine_tune.py index d865cd2de..c9102f6c0 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -10,7 +10,7 @@ from tqdm import tqdm import torch -from library import deepspeed_utils +from library import deepspeed_utils, strategy_base from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -39,6 +39,7 @@ scale_v_prediction_loss_like_noise_prediction, apply_debiased_estimation, ) +import library.strategy_sd as strategy_sd def train(args): @@ -52,7 +53,15 @@ def train(args): if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - tokenizer = train_util.load_tokenizer(args) + tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if cache_latents: + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, False + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する if args.dataset_class is None: @@ -81,10 +90,10 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) + train_dataset_group = train_util.load_arbitrary_dataset(args) current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -165,8 +174,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + + train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -192,6 +202,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): else: text_encoder.eval() + text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + if not cache_latents: vae.requires_grad_(False) vae.eval() @@ -214,7 +227,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): accelerator.print("prepare optimizer, data loader etc.") _, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( @@ -317,7 +334,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): ) # For --sample_at_first - train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + train_util.sample_images( + accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) loss_recorder = train_util.LossRecorder() for epoch in range(num_train_epochs): @@ -342,8 +361,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): with torch.set_grad_enabled(args.train_text_encoder): # Get the text embedding for conditioning if args.weighted_captions: + # TODO move to strategy_sd.py encoder_hidden_states = get_weighted_text_embeddings( - tokenizer, + tokenize_strategy.tokenizer, text_encoder, batch["captions"], accelerator.device, @@ -351,10 +371,12 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): clip_skip=args.clip_skip, ) else: - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype - ) + input_ids = batch["input_ids_list"][0].to(accelerator.device) + encoder_hidden_states = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder], [input_ids] + )[0] + if args.full_fp16: + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified @@ -409,7 +431,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): global_step += 1 train_util.sample_images( - accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet + accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet ) # 指定ステップごとにモデルを保存 @@ -472,7 +494,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): vae, ) - train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + train_util.sample_images( + accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) is_main_process = accelerator.is_main_process if is_main_process: diff --git a/library/config_util.py b/library/config_util.py index 10b2457f3..f8cdfe60a 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -104,8 +104,6 @@ class ControlNetSubsetParams(BaseSubsetParams): @dataclass class BaseDatasetParams: - tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None - max_token_length: int = None resolution: Optional[Tuple[int, int]] = None network_multiplier: float = 1.0 debug_dataset: bool = False diff --git a/library/sd3_models.py b/library/sd3_models.py index ec8e1bbdd..28378c73b 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -38,7 +38,7 @@ def __init__( サブクラスで各種の設定を行ってる。このクラスはその設定に基づき重み付きのトークン化を行うようだ。 Some settings are done in subclasses. This class seems to perform tokenization with weights based on those settings. """ - self.tokenizer = tokenizer + self.tokenizer: CLIPTokenizer = tokenizer self.max_length = max_length self.min_length = min_length empty = self.tokenizer("")["input_ids"] @@ -56,6 +56,19 @@ def __init__( self.inv_vocab = {v: k for k, v in vocab.items()} self.max_word_length = 8 + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + """ + Tokenize the text without weights. + """ + if type(text) == str: + text = [text] + batch_tokens = self.tokenizer(text, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt") + # return tokens["input_ids"] + + pad_token = self.end_token if self.pad_with_end else 0 + for tokens in batch_tokens["input_ids"]: + assert tokens[0] == self.start_token, f"tokens[0]: {tokens[0]}, start_token: {self.start_token}" + def tokenize_with_weights(self, text: str, truncate_to_max_length=True, truncate_length=None): """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3.""" @@ -75,13 +88,14 @@ def tokenize_with_weights(self, text: str, truncate_to_max_length=True, truncate for word in to_tokenize: batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start : -1]]) batch.append((self.end_token, 1.0)) + print(len(batch), self.max_length, self.min_length) if self.pad_to_max_length: batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch))) if self.min_length is not None and len(batch) < self.min_length: batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch))) # truncate to max_length - # print(f"batch: {batch}, truncate: {truncate}, len(batch): {len(batch)}, max_length: {self.max_length}") + print(f"batch: {batch}, max_length: {self.max_length}, truncate: {truncate_to_max_length}, truncate_length: {truncate_length}") if truncate_to_max_length and len(batch) > self.max_length: batch = batch[: self.max_length] if truncate_length is not None and len(batch) > truncate_length: @@ -110,27 +124,38 @@ def __init__(self, tokenizer): class SD3Tokenizer: - def __init__(self, t5xxl=True): + def __init__(self, t5xxl=True, t5xxl_max_length: Optional[int] = 256): + if t5xxl_max_length is None: + t5xxl_max_length = 256 + # TODO cache tokenizer settings locally or hold them in the repo like ComfyUI clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") self.clip_l = SDTokenizer(tokenizer=clip_tokenizer) self.clip_g = SDXLClipGTokenizer(clip_tokenizer) + # self.clip_l = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + # self.clip_g = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") self.t5xxl = T5XXLTokenizer() if t5xxl else None # t5xxl has 99999999 max length, clip has 77 - self.model_max_length = self.clip_l.max_length # 77 + self.t5xxl_max_length = t5xxl_max_length def tokenize_with_weights(self, text: str): - # temporary truncate to max_length even for t5xxl return ( self.clip_l.tokenize_with_weights(text), self.clip_g.tokenize_with_weights(text), ( - self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.model_max_length) + self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.t5xxl_max_length) if self.t5xxl is not None else None ), ) + def tokenize(self, text: str): + return ( + self.clip_l.tokenize(text), + self.clip_g.tokenize(text), + (self.t5xxl.tokenize(text) if self.t5xxl is not None else None), + ) + # endregion @@ -1474,7 +1499,10 @@ def encode_token_weights(self, list_of_token_weight_pairs): tokens = [a[0] for a in pairs[0]] # I'm not sure why this is [0] list_of_tokens.append(tokens) else: - list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]] + if isinstance(list_of_token_weight_pairs[0], torch.Tensor): + list_of_tokens = [list(list_of_token_weight_pairs[0])] + else: + list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]] out, pooled = self(list_of_tokens) if has_batch: @@ -1614,9 +1642,9 @@ def set_attn_mode(self, mode): ### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl ################################################################################################# - +""" class T5XXLTokenizer(SDTokenizer): - """Wraps the T5 Tokenizer from HF into the SDTokenizer interface""" + ""Wraps the T5 Tokenizer from HF into the SDTokenizer interface"" def __init__(self): super().__init__( @@ -1627,6 +1655,7 @@ def __init__(self): max_length=99999999, min_length=77, ) +""" class T5LayerNorm(torch.nn.Module): diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 245912199..8f99d9474 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -280,111 +280,6 @@ def sample_images(*args, **kwargs): return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs) -class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy): - SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz" - - def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: - super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) - self.vae = None - - def set_vae(self, vae: sd3_models.SDVAE): - self.vae = vae - - def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: - npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX) - if len(npz_file) == 0: - return None, None - w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x") - return int(w), int(h) - - def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: - return ( - os.path.splitext(absolute_path)[0] - + f"_{image_size[0]:04d}x{image_size[1]:04d}" - + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX - ) - - def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): - if not self.cache_to_disk: - return False - if not os.path.exists(npz_path): - return False - if self.skip_disk_cache_validity_check: - return True - - expected_latents_size = (bucket_reso[1] // 8, bucket_reso[0] // 8) # bucket_reso is (W, H) - - try: - npz = np.load(npz_path) - if npz["latents"].shape[1:3] != expected_latents_size: - return False - - if flip_aug: - if "latents_flipped" not in npz: - return False - if npz["latents_flipped"].shape[1:3] != expected_latents_size: - return False - - if alpha_mask: - if "alpha_mask" not in npz: - return False - if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]): - return False - else: - if "alpha_mask" in npz: - return False - except Exception as e: - logger.error(f"Error loading file: {npz_path}") - raise e - - return True - - def cache_batch_latents(self, image_infos: List[train_util.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool): - img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching( - image_infos, alpha_mask, random_crop - ) - img_tensor = img_tensor.to(device=self.vae.device, dtype=self.vae.dtype) - - with torch.no_grad(): - latents_tensors = self.vae.encode(img_tensor).to("cpu") - if flip_aug: - img_tensor = torch.flip(img_tensor, dims=[3]) - with torch.no_grad(): - flipped_latents = self.vae.encode(img_tensor).to("cpu") - else: - flipped_latents = [None] * len(latents_tensors) - - # for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks): - for i in range(len(image_infos)): - info = image_infos[i] - latents = latents_tensors[i] - flipped_latent = flipped_latents[i] - alpha_mask = alpha_masks[i] - original_size = original_sizes[i] - crop_ltrb = crop_ltrbs[i] - - if self.cache_to_disk: - kwargs = {} - if flipped_latent is not None: - kwargs["latents_flipped"] = flipped_latent.float().cpu().numpy() - if alpha_mask is not None: - kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() - np.savez( - info.latents_npz, - latents=latents.float().cpu().numpy(), - original_size=np.array(original_size), - crop_ltrb=np.array(crop_ltrb), - **kwargs, - ) - else: - info.latents = latents - if flip_aug: - info.latents_flipped = flipped_latent - info.alpha_mask = alpha_mask - - if not train_util.HIGH_VRAM: - clean_memory_on_device(self.vae.device) - # region Diffusers diff --git a/library/sd3_utils.py b/library/sd3_utils.py index 16f80c60d..5849518fb 100644 --- a/library/sd3_utils.py +++ b/library/sd3_utils.py @@ -384,6 +384,7 @@ def get_cond( dtype: Optional[torch.dtype] = None, ): l_tokens, g_tokens, t5_tokens = tokenizer.tokenize_with_weights(prompt) + print(t5_tokens) return get_cond_from_tokens(l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, device=device, dtype=dtype) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index b74bea91a..f009b5779 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -327,7 +327,7 @@ def diffusers_saver(out_dir): ) -def add_sdxl_training_arguments(parser: argparse.ArgumentParser): +def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_encoder_caching: bool = True): parser.add_argument( "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" ) diff --git a/library/strategy_base.py b/library/strategy_base.py new file mode 100644 index 000000000..594cca5eb --- /dev/null +++ b/library/strategy_base.py @@ -0,0 +1,328 @@ +# base class for platform strategies. this file defines the interface for strategies + +import os +from typing import Any, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection + + +# TODO remove circular import by moving ImageInfo to a separate file +# from library.train_util import ImageInfo + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class TokenizeStrategy: + _strategy = None # strategy instance: actual strategy class + + @classmethod + def set_strategy(cls, strategy): + if cls._strategy is not None: + raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["TokenizeStrategy"]: + return cls._strategy + + def _load_tokenizer( + self, model_class: Any, model_id: str, subfolder: Optional[str] = None, tokenizer_cache_dir: Optional[str] = None + ) -> Any: + tokenizer = None + if tokenizer_cache_dir: + local_tokenizer_path = os.path.join(tokenizer_cache_dir, model_id.replace("/", "_")) + if os.path.exists(local_tokenizer_path): + logger.info(f"load tokenizer from cache: {local_tokenizer_path}") + tokenizer = model_class.from_pretrained(local_tokenizer_path) # same for v1 and v2 + + if tokenizer is None: + tokenizer = model_class.from_pretrained(model_id, subfolder=subfolder) + + if tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): + logger.info(f"save Tokenizer to cache: {local_tokenizer_path}") + tokenizer.save_pretrained(local_tokenizer_path) + + return tokenizer + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + raise NotImplementedError + + def _get_input_ids(self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None) -> torch.Tensor: + """ + for SD1.5/2.0/SDXL + TODO support batch input + """ + if max_length is None: + max_length = tokenizer.model_max_length - 2 + + input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids + + if max_length > tokenizer.model_max_length: + input_ids = input_ids.squeeze(0) + iids_list = [] + if tokenizer.pad_token_id == tokenizer.eos_token_id: + # v1 + # 77以上の時は " .... " でトータル227とかになっているので、"..."の三連に変換する + # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に + for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): # (1, 152, 75) + ids_chunk = ( + input_ids[0].unsqueeze(0), + input_ids[i : i + tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0), + ) + ids_chunk = torch.cat(ids_chunk) + iids_list.append(ids_chunk) + else: + # v2 or SDXL + # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する + for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): + ids_chunk = ( + input_ids[0].unsqueeze(0), # BOS + input_ids[i : i + tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0), + ) # PAD or EOS + ids_chunk = torch.cat(ids_chunk) + + # 末尾が または の場合は、何もしなくてよい + # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) + if ids_chunk[-2] != tokenizer.eos_token_id and ids_chunk[-2] != tokenizer.pad_token_id: + ids_chunk[-1] = tokenizer.eos_token_id + # 先頭が ... の場合は ... に変える + if ids_chunk[1] == tokenizer.pad_token_id: + ids_chunk[1] = tokenizer.eos_token_id + + iids_list.append(ids_chunk) + + input_ids = torch.stack(iids_list) # 3,77 + return input_ids + + +class TextEncodingStrategy: + _strategy = None # strategy instance: actual strategy class + + @classmethod + def set_strategy(cls, strategy): + if cls._strategy is not None: + raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["TextEncodingStrategy"]: + return cls._strategy + + def encode_tokens( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] + ) -> List[torch.Tensor]: + """ + Encode tokens into embeddings and outputs. + :param tokens: list of token tensors for each TextModel + :return: list of output embeddings for each architecture + """ + raise NotImplementedError + + +class TextEncoderOutputsCachingStrategy: + _strategy = None # strategy instance: actual strategy class + + def __init__( + self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False + ) -> None: + self._cache_to_disk = cache_to_disk + self._batch_size = batch_size + self.skip_disk_cache_validity_check = skip_disk_cache_validity_check + self._is_partial = is_partial + + @classmethod + def set_strategy(cls, strategy): + if cls._strategy is not None: + raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["TextEncoderOutputsCachingStrategy"]: + return cls._strategy + + @property + def cache_to_disk(self): + return self._cache_to_disk + + @property + def batch_size(self): + return self._batch_size + + @property + def is_partial(self): + return self._is_partial + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + raise NotImplementedError + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + raise NotImplementedError + + def is_disk_cached_outputs_expected(self, npz_path: str) -> bool: + raise NotImplementedError + + def cache_batch_outputs( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, batch: List + ): + raise NotImplementedError + + +class LatentsCachingStrategy: + # TODO commonize utillity functions to this class, such as npz handling etc. + + _strategy = None # strategy instance: actual strategy class + + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + self._cache_to_disk = cache_to_disk + self._batch_size = batch_size + self.skip_disk_cache_validity_check = skip_disk_cache_validity_check + + @classmethod + def set_strategy(cls, strategy): + if cls._strategy is not None: + raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["LatentsCachingStrategy"]: + return cls._strategy + + @property + def cache_to_disk(self): + return self._cache_to_disk + + @property + def batch_size(self): + return self._batch_size + + def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: + raise NotImplementedError + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + raise NotImplementedError + + def is_disk_cached_latents_expected( + self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool + ) -> bool: + raise NotImplementedError + + def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + raise NotImplementedError + + def _defualt_is_disk_cached_latents_expected( + self, latents_stride: int, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool + ): + if not self.cache_to_disk: + return False + if not os.path.exists(npz_path): + return False + if self.skip_disk_cache_validity_check: + return True + + expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H) + + try: + npz = np.load(npz_path) + if npz["latents"].shape[1:3] != expected_latents_size: + return False + + if flip_aug: + if "latents_flipped" not in npz: + return False + if npz["latents_flipped"].shape[1:3] != expected_latents_size: + return False + + if alpha_mask: + if "alpha_mask" not in npz: + return False + if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]): + return False + else: + if "alpha_mask" in npz: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + + return True + + # TODO remove circular dependency for ImageInfo + def _default_cache_batch_latents( + self, encode_by_vae, vae_device, vae_dtype, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool + ): + """ + Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common. + """ + from library import train_util # import here to avoid circular import + + img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching( + image_infos, alpha_mask, random_crop + ) + img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype) + + with torch.no_grad(): + latents_tensors = encode_by_vae(img_tensor).to("cpu") + if flip_aug: + img_tensor = torch.flip(img_tensor, dims=[3]) + with torch.no_grad(): + flipped_latents = encode_by_vae(img_tensor).to("cpu") + else: + flipped_latents = [None] * len(latents_tensors) + + # for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks): + for i in range(len(image_infos)): + info = image_infos[i] + latents = latents_tensors[i] + flipped_latent = flipped_latents[i] + alpha_mask = alpha_masks[i] + original_size = original_sizes[i] + crop_ltrb = crop_ltrbs[i] + + if self.cache_to_disk: + self.save_latents_to_disk(info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask) + else: + info.latents_original_size = original_size + info.latents_crop_ltrb = crop_ltrb + info.latents = latents + if flip_aug: + info.latents_flipped = flipped_latent + info.alpha_mask = alpha_mask + + def load_latents_from_disk( + self, npz_path: str + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + npz = np.load(npz_path) + if "latents" not in npz: + raise ValueError(f"error: npz is old format. please re-generate {npz_path}") + + latents = npz["latents"] + original_size = npz["original_size"].tolist() + crop_ltrb = npz["crop_ltrb"].tolist() + flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None + alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None + return latents, original_size, crop_ltrb, flipped_latents, alpha_mask + + def save_latents_to_disk( + self, npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None + ): + kwargs = {} + if flipped_latents_tensor is not None: + kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() + if alpha_mask is not None: + kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() + np.savez( + npz_path, + latents=latents_tensor.float().cpu().numpy(), + original_size=np.array(original_size), + crop_ltrb=np.array(crop_ltrb), + **kwargs, + ) diff --git a/library/strategy_sd.py b/library/strategy_sd.py new file mode 100644 index 000000000..105816145 --- /dev/null +++ b/library/strategy_sd.py @@ -0,0 +1,139 @@ +import glob +import os +from typing import Any, List, Optional, Tuple, Union + +import torch +from transformers import CLIPTokenizer +from library import train_util +from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +TOKENIZER_ID = "openai/clip-vit-large-patch14" +V2_STABLE_DIFFUSION_ID = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ + + +class SdTokenizeStrategy(TokenizeStrategy): + def __init__(self, v2: bool, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None: + """ + max_length does not include and (None, 75, 150, 225) + """ + logger.info(f"Using {'v2' if v2 else 'v1'} tokenizer") + if v2: + self.tokenizer = self._load_tokenizer( + CLIPTokenizer, V2_STABLE_DIFFUSION_ID, subfolder="tokenizer", tokenizer_cache_dir=tokenizer_cache_dir + ) + else: + self.tokenizer = self._load_tokenizer(CLIPTokenizer, TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + + if max_length is None: + self.max_length = self.tokenizer.model_max_length + else: + self.max_length = max_length + 2 + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)] + + +class SdTextEncodingStrategy(TextEncodingStrategy): + def __init__(self, clip_skip: Optional[int] = None) -> None: + self.clip_skip = clip_skip + + def encode_tokens( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] + ) -> List[torch.Tensor]: + text_encoder = models[0] + tokens = tokens[0] + sd_tokenize_strategy = tokenize_strategy # type: SdTokenizeStrategy + + # tokens: b,n,77 + b_size = tokens.size()[0] + max_token_length = tokens.size()[1] * tokens.size()[2] + model_max_length = sd_tokenize_strategy.tokenizer.model_max_length + tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77 + + if self.clip_skip is None: + encoder_hidden_states = text_encoder(tokens)[0] + else: + enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True) + encoder_hidden_states = enc_out["hidden_states"][-self.clip_skip] + encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) + + # bs*3, 77, 768 or 1024 + encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) + + if max_token_length != model_max_length: + v1 = sd_tokenize_strategy.tokenizer.pad_token_id == sd_tokenize_strategy.tokenizer.eos_token_id + if not v1: + # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, model_max_length): + chunk = encoder_hidden_states[:, i : i + model_max_length - 2] # の後から 最後の前まで + if i > 0: + for j in range(len(chunk)): + if tokens[j, 1] == sd_tokenize_strategy.tokenizer.eos_token: + # 空、つまり ...のパターン + chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする + states_list.append(chunk) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # のどちらか + encoder_hidden_states = torch.cat(states_list, dim=1) + else: + # v1: ... の三連を ... へ戻す + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, model_max_length): + states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2]) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # + encoder_hidden_states = torch.cat(states_list, dim=1) + + return [encoder_hidden_states] + + +class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): + # sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix. + # and we keep the old npz for the backward compatibility. + + SD_OLD_LATENTS_NPZ_SUFFIX = ".npz" + SD_LATENTS_NPZ_SUFFIX = "_sd.npz" + SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz" + + def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + self.sd = sd + self.suffix = ( + SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX + ) + + def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: + # does not include old npz + npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + self.suffix) + if len(npz_file) == 0: + return None, None + w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x") + return int(w), int(h) + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + # support old .npz + old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX + if os.path.exists(old_npz_file): + return old_npz_file + return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix + + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + return self._defualt_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + + # TODO remove circular dependency for ImageInfo + def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + encode_by_vae = lambda img_tensor: vae.encode(img_tensor).latent_dist.sample() + vae_device = vae.device + vae_dtype = vae.dtype + + self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop) + + if not train_util.HIGH_VRAM: + train_util.clean_memory_on_device(vae.device) diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py new file mode 100644 index 000000000..42630ab22 --- /dev/null +++ b/library/strategy_sd3.py @@ -0,0 +1,229 @@ +import os +import glob +from typing import Any, List, Optional, Tuple, Union +import torch +import numpy as np +from transformers import CLIPTokenizer, T5TokenizerFast + +from library import sd3_utils, train_util +from library import sd3_models +from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14" +CLIP_G_TOKENIZER_ID = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" +T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl" + + +class Sd3TokenizeStrategy(TokenizeStrategy): + def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None: + self.t5xxl_max_length = t5xxl_max_length + self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + self.clip_g = self._load_tokenizer(CLIPTokenizer, CLIP_G_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + self.clip_g.pad_token_id = 0 # use 0 as pad token for clip_g + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + + l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + g_tokens = self.clip_g(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt") + + l_tokens = l_tokens["input_ids"] + g_tokens = g_tokens["input_ids"] + t5_tokens = t5_tokens["input_ids"] + + return [l_tokens, g_tokens, t5_tokens] + + +class Sd3TextEncodingStrategy(TextEncodingStrategy): + def __init__(self) -> None: + pass + + def encode_tokens( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] + ) -> List[torch.Tensor]: + clip_l, clip_g, t5xxl = models + + l_tokens, g_tokens, t5_tokens = tokens + if l_tokens is None: + assert g_tokens is None, "g_tokens must be None if l_tokens is None" + lg_out = None + else: + assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" + l_out, l_pooled = clip_l(l_tokens) + g_out, g_pooled = clip_g(g_tokens) + lg_out = torch.cat([l_out, g_out], dim=-1) + + if t5xxl is not None and t5_tokens is not None: + t5_out, _ = t5xxl(t5_tokens) # t5_out is [1, max length, 4096] + else: + t5_out = None + + lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None + return [lg_out, t5_out, lg_pooled] + + def concat_encodings( + self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) + if t5_out is None: + t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype) + return torch.cat([lg_out, t5_out], dim=-2), lg_pooled + + +class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): + SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz" + + def __init__( + self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False + ) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + + def is_disk_cached_outputs_expected(self, abs_path: str): + if not self.cache_to_disk: + return False + if not os.path.exists(self.get_outputs_npz_path(abs_path)): + return False + if self.skip_disk_cache_validity_check: + return True + + try: + npz = np.load(self.get_outputs_npz_path(abs_path)) + if "clip_l" not in npz or "clip_g" not in npz: + return False + if "clip_l_pool" not in npz or "clip_g_pool" not in npz: + return False + # t5xxl is optional + except Exception as e: + logger.error(f"Error loading file: {self.get_outputs_npz_path(abs_path)}") + raise e + + return True + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + data = np.load(npz_path) + lg_out = data["lg_out"] + lg_pooled = data["lg_pooled"] + t5_out = data["t5_out"] if "t5_out" in data else None + return [lg_out, t5_out, lg_pooled] + + def cache_batch_outputs( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List + ): + captions = [info.caption for info in infos] + + clip_l_tokens, clip_g_tokens, t5xxl_tokens = tokenize_strategy.tokenize(captions) + with torch.no_grad(): + lg_out, t5_out, lg_pooled = text_encoding_strategy.encode_tokens( + tokenize_strategy, models, [clip_l_tokens, clip_g_tokens, t5xxl_tokens] + ) + + if lg_out.dtype == torch.bfloat16: + lg_out = lg_out.float() + if lg_pooled.dtype == torch.bfloat16: + lg_pooled = lg_pooled.float() + if t5_out is not None and t5_out.dtype == torch.bfloat16: + t5_out = t5_out.float() + + lg_out = lg_out.cpu().numpy() + lg_pooled = lg_pooled.cpu().numpy() + if t5_out is not None: + t5_out = t5_out.cpu().numpy() + + for i, info in enumerate(infos): + lg_out_i = lg_out[i] + t5_out_i = t5_out[i] if t5_out is not None else None + lg_pooled_i = lg_pooled[i] + + if self.cache_to_disk: + kwargs = {} + if t5_out is not None: + kwargs["t5_out"] = t5_out_i + np.savez(info.text_encoder_outputs_npz, lg_out=lg_out_i, lg_pooled=lg_pooled_i, **kwargs) + else: + info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i) + + +class Sd3LatentsCachingStrategy(LatentsCachingStrategy): + SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz" + + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + + def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: + npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX) + if len(npz_file) == 0: + return None, None + w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x") + return int(w), int(h) + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + return ( + os.path.splitext(absolute_path)[0] + + f"_{image_size[0]:04d}x{image_size[1]:04d}" + + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX + ) + + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + return self._defualt_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + + # TODO remove circular dependency for ImageInfo + def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu") + vae_device = vae.device + vae_dtype = vae.dtype + + self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop) + + if not train_util.HIGH_VRAM: + train_util.clean_memory_on_device(vae.device) + + +if __name__ == "__main__": + # test code for Sd3TokenizeStrategy + # tokenizer = sd3_models.SD3Tokenizer() + strategy = Sd3TokenizeStrategy(256) + text = "hello world" + + l_tokens, g_tokens, t5_tokens = strategy.tokenize(text) + # print(l_tokens.shape) + print(l_tokens) + print(g_tokens) + print(t5_tokens) + + texts = ["hello world", "the quick brown fox jumps over the lazy dog"] + l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + g_tokens_2 = strategy.clip_g(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + t5_tokens_2 = strategy.t5xxl( + texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + print(l_tokens_2) + print(g_tokens_2) + print(t5_tokens_2) + + # compare + print(torch.allclose(l_tokens, l_tokens_2["input_ids"][0])) + print(torch.allclose(g_tokens, g_tokens_2["input_ids"][0])) + print(torch.allclose(t5_tokens, t5_tokens_2["input_ids"][0])) + + text = ",".join(["hello world! this is long text"] * 50) + l_tokens, g_tokens, t5_tokens = strategy.tokenize(text) + print(l_tokens) + print(g_tokens) + print(t5_tokens) + + print(f"model max length l: {strategy.clip_l.model_max_length}") + print(f"model max length g: {strategy.clip_g.model_max_length}") + print(f"model max length t5: {strategy.t5xxl.model_max_length}") diff --git a/library/strategy_sdxl.py b/library/strategy_sdxl.py new file mode 100644 index 000000000..a4513336d --- /dev/null +++ b/library/strategy_sdxl.py @@ -0,0 +1,247 @@ +import os +from typing import Any, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection +from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy + + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +TOKENIZER1_PATH = "openai/clip-vit-large-patch14" +TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + + +class SdxlTokenizeStrategy(TokenizeStrategy): + def __init__(self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None: + self.tokenizer1 = self._load_tokenizer(CLIPTokenizer, TOKENIZER1_PATH, tokenizer_cache_dir=tokenizer_cache_dir) + self.tokenizer2 = self._load_tokenizer(CLIPTokenizer, TOKENIZER2_PATH, tokenizer_cache_dir=tokenizer_cache_dir) + self.tokenizer2.pad_token_id = 0 # use 0 as pad token for tokenizer2 + + if max_length is None: + self.max_length = self.tokenizer1.model_max_length + else: + self.max_length = max_length + 2 + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + return ( + torch.stack([self._get_input_ids(self.tokenizer1, t, self.max_length) for t in text], dim=0), + torch.stack([self._get_input_ids(self.tokenizer2, t, self.max_length) for t in text], dim=0), + ) + + +class SdxlTextEncodingStrategy(TextEncodingStrategy): + def __init__(self) -> None: + pass + + def _pool_workaround( + self, text_encoder: CLIPTextModelWithProjection, last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int + ): + r""" + workaround for CLIP's pooling bug: it returns the hidden states for the max token id as the pooled output + instead of the hidden states for the EOS token + If we use Textual Inversion, we need to use the hidden states for the EOS token as the pooled output + + Original code from CLIP's pooling function: + + \# text_embeds.shape = [batch_size, sequence_length, transformer.width] + \# take features from the eot embedding (eot_token is the highest number in each sequence) + \# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1), + ] + """ + + # input_ids: b*n,77 + # find index for EOS token + + # Following code is not working if one of the input_ids has multiple EOS tokens (very odd case) + # eos_token_index = torch.where(input_ids == eos_token_id)[1] + # eos_token_index = eos_token_index.to(device=last_hidden_state.device) + + # Create a mask where the EOS tokens are + eos_token_mask = (input_ids == eos_token_id).int() + + # Use argmax to find the last index of the EOS token for each element in the batch + eos_token_index = torch.argmax(eos_token_mask, dim=1) # this will be 0 if there is no EOS token, it's fine + eos_token_index = eos_token_index.to(device=last_hidden_state.device) + + # get hidden states for EOS token + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index + ] + + # apply projection: projection may be of different dtype than last_hidden_state + pooled_output = text_encoder.text_projection(pooled_output.to(text_encoder.text_projection.weight.dtype)) + pooled_output = pooled_output.to(last_hidden_state.dtype) + + return pooled_output + + def _get_hidden_states_sdxl( + self, + input_ids1: torch.Tensor, + input_ids2: torch.Tensor, + tokenizer1: CLIPTokenizer, + tokenizer2: CLIPTokenizer, + text_encoder1: Union[CLIPTextModel, torch.nn.Module], + text_encoder2: Union[CLIPTextModelWithProjection, torch.nn.Module], + unwrapped_text_encoder2: Optional[CLIPTextModelWithProjection] = None, + ): + # input_ids: b,n,77 -> b*n, 77 + b_size = input_ids1.size()[0] + max_token_length = input_ids1.size()[1] * input_ids1.size()[2] + input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77 + input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77 + input_ids1 = input_ids1.to(text_encoder1.device) + input_ids2 = input_ids2.to(text_encoder2.device) + + # text_encoder1 + enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True) + hidden_states1 = enc_out["hidden_states"][11] + + # text_encoder2 + enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True) + hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer + + # pool2 = enc_out["text_embeds"] + unwrapped_text_encoder2 = unwrapped_text_encoder2 or text_encoder2 + pool2 = self._pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id) + + # b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280 + n_size = 1 if max_token_length is None else max_token_length // 75 + hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1])) + hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1])) + + if max_token_length is not None: + # bs*3, 77, 768 or 1024 + # encoder1: ... の三連を ... へ戻す + states_list = [hidden_states1[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, tokenizer1.model_max_length): + states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2]) # の後から の前まで + states_list.append(hidden_states1[:, -1].unsqueeze(1)) # + hidden_states1 = torch.cat(states_list, dim=1) + + # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん + states_list = [hidden_states2[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, tokenizer2.model_max_length): + chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2] # の後から 最後の前まで + # this causes an error: + # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation + # if i > 1: + # for j in range(len(chunk)): # batch_size + # if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id: # 空、つまり ...のパターン + # chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする + states_list.append(chunk) # の後から の前まで + states_list.append(hidden_states2[:, -1].unsqueeze(1)) # のどちらか + hidden_states2 = torch.cat(states_list, dim=1) + + # pool はnの最初のものを使う + pool2 = pool2[::n_size] + + return hidden_states1, hidden_states2, pool2 + + def encode_tokens( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] + ) -> List[torch.Tensor]: + """ + Args: + tokenize_strategy: TokenizeStrategy + models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)] + tokens: List of tokens, for text_encoder1 and text_encoder2 + """ + if len(models) == 2: + text_encoder1, text_encoder2 = models + unwrapped_text_encoder2 = None + else: + text_encoder1, text_encoder2, unwrapped_text_encoder2 = models + tokens1, tokens2 = tokens + sdxl_tokenize_strategy = tokenize_strategy # type: SdxlTokenizeStrategy + tokenizer1, tokenizer2 = sdxl_tokenize_strategy.tokenizer1, sdxl_tokenize_strategy.tokenizer2 + + hidden_states1, hidden_states2, pool2 = self._get_hidden_states_sdxl( + tokens1, tokens2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, unwrapped_text_encoder2 + ) + return [hidden_states1, hidden_states2, pool2] + + +class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): + SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz" + + def __init__( + self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False + ) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + + def is_disk_cached_outputs_expected(self, abs_path: str): + if not self.cache_to_disk: + return False + if not os.path.exists(self.get_outputs_npz_path(abs_path)): + return False + if self.skip_disk_cache_validity_check: + return True + + try: + npz = np.load(self.get_outputs_npz_path(abs_path)) + if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz: + return False + except Exception as e: + logger.error(f"Error loading file: {self.get_outputs_npz_path(abs_path)}") + raise e + + return True + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + data = np.load(npz_path) + hidden_state1 = data["hidden_state1"] + hidden_state2 = data["hidden_state2"] + pool2 = data["pool2"] + return [hidden_state1, hidden_state2, pool2] + + def cache_batch_outputs( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List + ): + sdxl_text_encoding_strategy = text_encoding_strategy # type: SdxlTextEncodingStrategy + captions = [info.caption for info in infos] + + tokens1, tokens2 = tokenize_strategy.tokenize(captions) + with torch.no_grad(): + hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens( + tokenize_strategy, models, [tokens1, tokens2] + ) + if hidden_state1.dtype == torch.bfloat16: + hidden_state1 = hidden_state1.float() + if hidden_state2.dtype == torch.bfloat16: + hidden_state2 = hidden_state2.float() + if pool2.dtype == torch.bfloat16: + pool2 = pool2.float() + + hidden_state1 = hidden_state1.cpu().numpy() + hidden_state2 = hidden_state2.cpu().numpy() + pool2 = pool2.cpu().numpy() + + for i, info in enumerate(infos): + hidden_state1_i = hidden_state1[i] + hidden_state2_i = hidden_state2[i] + pool2_i = pool2[i] + + if self.cache_to_disk: + np.savez( + info.text_encoder_outputs_npz, + hidden_state1=hidden_state1_i, + hidden_state2=hidden_state2_i, + pool2=pool2_i, + ) + else: + info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i] diff --git a/library/train_util.py b/library/train_util.py index 7af0070e1..a747e0478 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -12,6 +12,7 @@ import shutil import time from typing import ( + Any, Dict, List, NamedTuple, @@ -34,6 +35,7 @@ import torch from library.device_utils import init_ipex, clean_memory_on_device +from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy, TextEncodingStrategy init_ipex() @@ -81,10 +83,6 @@ # from library.hypernetwork import replace_attentions_for_hypernetwork from library.original_unet import UNet2DConditionModel -# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う -TOKENIZER_PATH = "openai/clip-vit-large-patch14" -V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ - HIGH_VRAM = False # checkpointファイル名 @@ -148,18 +146,24 @@ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, self.image_size: Tuple[int, int] = None self.resized_size: Tuple[int, int] = None self.bucket_reso: Tuple[int, int] = None - self.latents: torch.Tensor = None - self.latents_flipped: torch.Tensor = None - self.latents_npz: str = None - self.latents_original_size: Tuple[int, int] = None # original image size, not latents size - self.latents_crop_ltrb: Tuple[int, int] = None # crop left top right bottom in original pixel size, not latents size - self.cond_img_path: str = None + self.latents: Optional[torch.Tensor] = None + self.latents_flipped: Optional[torch.Tensor] = None + self.latents_npz: Optional[str] = None # set in cache_latents + self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size + self.latents_crop_ltrb: Optional[Tuple[int, int]] = ( + None # crop left top right bottom in original pixel size, not latents size + ) + self.cond_img_path: Optional[str] = None self.image: Optional[Image.Image] = None # optional, original PIL Image - # SDXL, optional - self.text_encoder_outputs_npz: Optional[str] = None + self.text_encoder_outputs_npz: Optional[str] = None # set in cache_text_encoder_outputs + + # new + self.text_encoder_outputs: Optional[List[torch.Tensor]] = None + # old self.text_encoder_outputs1: Optional[torch.Tensor] = None self.text_encoder_outputs2: Optional[torch.Tensor] = None self.text_encoder_pool2: Optional[torch.Tensor] = None + self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime @@ -359,47 +363,6 @@ def get_augmentor(self, use_color_aug: bool): # -> Optional[Callable[[np.ndarra return self.color_aug if use_color_aug else None -class LatentsCachingStrategy: - _strategy = None # strategy instance: actual strategy class - - def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: - self._cache_to_disk = cache_to_disk - self._batch_size = batch_size - self.skip_disk_cache_validity_check = skip_disk_cache_validity_check - - @classmethod - def set_strategy(cls, strategy): - if cls._strategy is not None: - raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") - cls._strategy = strategy - - @classmethod - def get_strategy(cls) -> Optional["LatentsCachingStrategy"]: - return cls._strategy - - @property - def cache_to_disk(self): - return self._cache_to_disk - - @property - def batch_size(self): - return self._batch_size - - def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: - raise NotImplementedError - - def get_latents_npz_path(self, absolute_path: str, bucket_reso: Tuple[int, int]) -> str: - raise NotImplementedError - - def is_disk_cached_latents_expected( - self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool - ) -> bool: - raise NotImplementedError - - def cache_batch_latents(self, batch: List[ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool): - raise NotImplementedError - - class BaseSubset: def __init__( self, @@ -639,17 +602,12 @@ def __eq__(self, other) -> bool: class BaseDataset(torch.utils.data.Dataset): def __init__( self, - tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]], - max_token_length: int, resolution: Optional[Tuple[int, int]], network_multiplier: float, debug_dataset: bool, ) -> None: super().__init__() - self.tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer] - - self.max_token_length = max_token_length # width/height is used when enable_bucket==False self.width, self.height = (None, None) if resolution is None else resolution self.network_multiplier = network_multiplier @@ -670,8 +628,6 @@ def __init__( self.bucket_no_upscale = None self.bucket_info = None # for metadata - self.tokenizer_max_length = self.tokenizers[0].model_max_length if max_token_length is None else max_token_length + 2 - self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ self.current_step: int = 0 @@ -690,6 +646,15 @@ def __init__( # caching self.caching_mode = None # None, 'latents', 'text' + + self.tokenize_strategy = None + self.text_encoder_output_caching_strategy = None + self.latents_caching_strategy = None + + def set_current_strategies(self): + self.tokenize_strategy = TokenizeStrategy.get_strategy() + self.text_encoder_output_caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy() + self.latents_caching_strategy = LatentsCachingStrategy.get_strategy() def set_seed(self, seed): self.seed = seed @@ -979,22 +944,6 @@ def make_buckets(self): for batch_index in range(batch_count): self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index)) - # ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す - #  学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる - # - # # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは - # # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう - # # そのためバッチサイズを画像種類までに制限する - # # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない? - # # TO DO 正則化画像をepochまたがりで利用する仕組み - # num_of_image_types = len(set(bucket)) - # bucket_batch_size = min(self.batch_size, num_of_image_types) - # batch_count = int(math.ceil(len(bucket) / bucket_batch_size)) - # # logger.info(bucket_index, num_of_image_types, bucket_batch_size, batch_count) - # for batch_index in range(batch_count): - # self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index)) - # ↑ここまで - self.shuffle_buckets() self._length = len(self.buckets_indices) @@ -1027,12 +976,13 @@ def is_text_encoder_output_cacheable(self): ] ) - def new_cache_latents(self, is_main_process: bool, caching_strategy: LatentsCachingStrategy): + def new_cache_latents(self, model: Any, is_main_process: bool): r""" a brand new method to cache latents. This method caches latents with caching strategy. normal cache_latents method is used by default, but this method is used when caching strategy is specified. """ logger.info("caching latents with caching strategy.") + caching_strategy = LatentsCachingStrategy.get_strategy() image_infos = list(self.image_data.values()) # sort by resolution @@ -1088,7 +1038,7 @@ def new_cache_latents(self, is_main_process: bool, caching_strategy: LatentsCach logger.info("caching latents...") for batch in tqdm(batches, smoothing=1, total=len(batches)): # cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) - caching_strategy.cache_batch_latents(batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + caching_strategy.cache_batch_latents(model, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと @@ -1145,6 +1095,56 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc for batch in tqdm(batches, smoothing=1, total=len(batches)): cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): + r""" + a brand new method to cache text encoder outputs. This method caches text encoder outputs with caching strategy. + """ + tokenize_strategy = TokenizeStrategy.get_strategy() + text_encoding_strategy = TextEncodingStrategy.get_strategy() + caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy() + batch_size = caching_strategy.batch_size or self.batch_size + + # if cache to disk, don't cache TE outputs in non-main process + if caching_strategy.cache_to_disk and not is_main_process: + return + + logger.info("caching Text Encoder outputs with caching strategy.") + image_infos = list(self.image_data.values()) + + # split by resolution + batches = [] + batch = [] + logger.info("checking cache validity...") + for info in tqdm(image_infos): + te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path) + + # check disk cache exists and size of latents + if caching_strategy.cache_to_disk: + info.text_encoder_outputs_npz = te_out_npz + cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz) + if cache_available: # do not add to batch + continue + + batch.append(info) + + # if number of data in batch is enough, flush the batch + if len(batch) >= batch_size: + batches.append(batch) + batch = [] + + if len(batch) > 0: + batches.append(batch) + + if len(batches) == 0: + logger.info("no Text Encoder outputs to cache") + return + + # iterate batches + logger.info("caching Text Encoder outputs...") + for batch in tqdm(batches, smoothing=1, total=len(batches)): + # cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + caching_strategy.cache_batch_outputs(tokenize_strategy, models, text_encoding_strategy, batch) + # if weight_dtype is specified, Text Encoder itself and output will be converted to the dtype # this method is only for SDXL, but it should be implemented here because it needs to be a method of dataset # to support SD1/2, it needs a flag for v2, but it is postponed @@ -1188,6 +1188,8 @@ def cache_text_encoder_outputs_common( # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと logger.info("caching text encoder outputs.") + tokenize_strategy = TokenizeStrategy.get_strategy() + if batch_size is None: batch_size = self.batch_size @@ -1229,7 +1231,7 @@ def cache_text_encoder_outputs_common( input_ids2 = self.get_input_ids(info.caption, tokenizers[1]) batch.append((info, input_ids1, input_ids2)) else: - l_tokens, g_tokens, t5_tokens = tokenizers[0].tokenize_with_weights(info.caption) + l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(info.caption) batch.append((info, l_tokens, g_tokens, t5_tokens)) if len(batch) >= batch_size: @@ -1347,7 +1349,6 @@ def __getitem__(self, index): loss_weights = [] captions = [] input_ids_list = [] - input_ids2_list = [] latents_list = [] alpha_mask_list = [] images = [] @@ -1355,16 +1356,14 @@ def __getitem__(self, index): crop_top_lefts = [] target_sizes_hw = [] flippeds = [] # 変数名が微妙 - text_encoder_outputs1_list = [] - text_encoder_outputs2_list = [] - text_encoder_pool2_list = [] + text_encoder_outputs_list = [] for image_key in bucket[image_index : image_index + bucket_batch_size]: image_info = self.image_data[image_key] subset = self.image_to_subset[image_key] - loss_weights.append( - self.prior_loss_weight if image_info.is_reg else 1.0 - ) # in case of fine tuning, is_reg is always False + + # in case of fine tuning, is_reg is always False + loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance @@ -1381,7 +1380,9 @@ def __getitem__(self, index): image = None elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 - latents, original_size, crop_ltrb, flipped_latents, alpha_mask = load_latents_from_disk(image_info.latents_npz) + latents, original_size, crop_ltrb, flipped_latents, alpha_mask = ( + self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz) + ) if flipped: latents = flipped_latents alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1].copy() # copy to avoid negative stride problem @@ -1470,75 +1471,67 @@ def __getitem__(self, index): # captionとtext encoder outputを処理する caption = image_info.caption # default - if image_info.text_encoder_outputs1 is not None: - text_encoder_outputs1_list.append(image_info.text_encoder_outputs1) - text_encoder_outputs2_list.append(image_info.text_encoder_outputs2) - text_encoder_pool2_list.append(image_info.text_encoder_pool2) - captions.append(caption) + + tokenization_required = ( + self.text_encoder_output_caching_strategy is None or self.text_encoder_output_caching_strategy.is_partial + ) + text_encoder_outputs = None + input_ids = None + + if image_info.text_encoder_outputs is not None: + # cached + text_encoder_outputs = image_info.text_encoder_outputs elif image_info.text_encoder_outputs_npz is not None: - text_encoder_outputs1, text_encoder_outputs2, text_encoder_pool2 = load_text_encoder_outputs_from_disk( + # on disk + text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz( image_info.text_encoder_outputs_npz ) - text_encoder_outputs1_list.append(text_encoder_outputs1) - text_encoder_outputs2_list.append(text_encoder_outputs2) - text_encoder_pool2_list.append(text_encoder_pool2) - captions.append(caption) else: - caption = self.process_caption(subset, image_info.caption) - if self.XTI_layers: - caption_layer = [] - for layer in self.XTI_layers: - token_strings_from = " ".join(self.token_strings) - token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) - caption_ = caption.replace(token_strings_from, token_strings_to) - caption_layer.append(caption_) - captions.append(caption_layer) - else: - captions.append(caption) + tokenization_required = True + text_encoder_outputs_list.append(text_encoder_outputs) - if not self.token_padding_disabled: # this option might be omitted in future - # TODO get_input_ids must support SD3 - if self.XTI_layers: - token_caption = self.get_input_ids(caption_layer, self.tokenizers[0]) - else: - token_caption = self.get_input_ids(caption, self.tokenizers[0]) - input_ids_list.append(token_caption) + if tokenization_required: + caption = self.process_caption(subset, image_info.caption) + input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension + # if self.XTI_layers: + # caption_layer = [] + # for layer in self.XTI_layers: + # token_strings_from = " ".join(self.token_strings) + # token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) + # caption_ = caption.replace(token_strings_from, token_strings_to) + # caption_layer.append(caption_) + # captions.append(caption_layer) + # else: + # captions.append(caption) + + # if not self.token_padding_disabled: # this option might be omitted in future + # # TODO get_input_ids must support SD3 + # if self.XTI_layers: + # token_caption = self.get_input_ids(caption_layer, self.tokenizers[0]) + # else: + # token_caption = self.get_input_ids(caption, self.tokenizers[0]) + # input_ids_list.append(token_caption) + + # if len(self.tokenizers) > 1: + # if self.XTI_layers: + # token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1]) + # else: + # token_caption2 = self.get_input_ids(caption, self.tokenizers[1]) + # input_ids2_list.append(token_caption2) + + input_ids_list.append(input_ids) + captions.append(caption) - if len(self.tokenizers) > 1: - if self.XTI_layers: - token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1]) - else: - token_caption2 = self.get_input_ids(caption, self.tokenizers[1]) - input_ids2_list.append(token_caption2) + def none_or_stack_elements(tensors_list, converter): + # [[clip_l, clip_g, t5xxl], [clip_l, clip_g, t5xxl], ...] -> [torch.stack(clip_l), torch.stack(clip_g), torch.stack(t5xxl)] + if len(tensors_list) == 0 or tensors_list[0] == None or len(tensors_list[0]) == 0 or tensors_list[0][0] is None: + return None + return [torch.stack([converter(x[i]) for x in tensors_list]) for i in range(len(tensors_list[0]))] example = {} example["loss_weights"] = torch.FloatTensor(loss_weights) - - if len(text_encoder_outputs1_list) == 0: - if self.token_padding_disabled: - # padding=True means pad in the batch - example["input_ids"] = self.tokenizer[0](captions, padding=True, truncation=True, return_tensors="pt").input_ids - if len(self.tokenizers) > 1: - example["input_ids2"] = self.tokenizer[1]( - captions, padding=True, truncation=True, return_tensors="pt" - ).input_ids - else: - example["input_ids2"] = None - else: - example["input_ids"] = torch.stack(input_ids_list) - example["input_ids2"] = torch.stack(input_ids2_list) if len(self.tokenizers) > 1 else None - example["text_encoder_outputs1_list"] = None - example["text_encoder_outputs2_list"] = None - example["text_encoder_pool2_list"] = None - else: - example["input_ids"] = None - example["input_ids2"] = None - # # for assertion - # example["input_ids"] = torch.stack([self.get_input_ids(cap, self.tokenizers[0]) for cap in captions]) - # example["input_ids2"] = torch.stack([self.get_input_ids(cap, self.tokenizers[1]) for cap in captions]) - example["text_encoder_outputs1_list"] = torch.stack(text_encoder_outputs1_list) - example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list) - example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list) + example["text_encoder_outputs_list"] = none_or_stack_elements(text_encoder_outputs_list, torch.FloatTensor) + example["input_ids_list"] = none_or_stack_elements(input_ids_list, lambda x: x) # if one of alpha_masks is not None, we need to replace None with ones none_or_not = [x is None for x in alpha_mask_list] @@ -1652,8 +1645,6 @@ def __init__( self, subsets: Sequence[DreamBoothSubset], batch_size: int, - tokenizer, - max_token_length, resolution, network_multiplier: float, enable_bucket: bool, @@ -1664,7 +1655,7 @@ def __init__( prior_loss_weight: float, debug_dataset: bool, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) + super().__init__(resolution, network_multiplier, debug_dataset) assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" @@ -1750,10 +1741,10 @@ def load_dreambooth_dir(subset: DreamBoothSubset): # new caching: get image size from cache files strategy = LatentsCachingStrategy.get_strategy() if strategy is not None: - logger.info("get image size from cache files") + logger.info("get image size from name of cache files") size_set_count = 0 for i, img_path in enumerate(tqdm(img_paths)): - w, h = strategy.get_image_size_from_image_absolute_path(img_path) + w, h = strategy.get_image_size_from_disk_cache_path(img_path) if w is not None and h is not None: sizes[i] = [w, h] size_set_count += 1 @@ -1886,8 +1877,6 @@ def __init__( self, subsets: Sequence[FineTuningSubset], batch_size: int, - tokenizer, - max_token_length, resolution, network_multiplier: float, enable_bucket: bool, @@ -1897,7 +1886,7 @@ def __init__( bucket_no_upscale: bool, debug_dataset: bool, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) + super().__init__(resolution, network_multiplier, debug_dataset) self.batch_size = batch_size @@ -2111,8 +2100,6 @@ def __init__( self, subsets: Sequence[ControlNetSubset], batch_size: int, - tokenizer, - max_token_length, resolution, network_multiplier: float, enable_bucket: bool, @@ -2122,7 +2109,7 @@ def __init__( bucket_no_upscale: bool, debug_dataset: float, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) + super().__init__(resolution, network_multiplier, debug_dataset) db_subsets = [] for subset in subsets: @@ -2160,8 +2147,6 @@ def __init__( self.dreambooth_dataset_delegate = DreamBoothDataset( db_subsets, batch_size, - tokenizer, - max_token_length, resolution, network_multiplier, enable_bucket, @@ -2221,6 +2206,9 @@ def __init__( self.conditioning_image_transforms = IMAGE_TRANSFORMS + def set_current_strategies(self): + return self.dreambooth_dataset_delegate.set_current_strategies() + def make_buckets(self): self.dreambooth_dataset_delegate.make_buckets() self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager @@ -2229,6 +2217,12 @@ def make_buckets(self): def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) + def new_cache_latents(self, model: Any, is_main_process: bool): + return self.dreambooth_dataset_delegate.new_cache_latents(model, is_main_process) + + def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): + return self.dreambooth_dataset_delegate.new_cache_text_encoder_outputs(models, is_main_process) + def __len__(self): return self.dreambooth_dataset_delegate.__len__() @@ -2314,6 +2308,13 @@ def add_replacement(self, str_from, str_to): # for dataset in self.datasets: # dataset.make_buckets() + def set_text_encoder_output_caching_strategy(self, strategy: TextEncoderOutputsCachingStrategy): + """ + DataLoader is run in multiple processes, so we need to set the strategy manually. + """ + for dataset in self.datasets: + dataset.set_text_encoder_output_caching_strategy(strategy) + def enable_XTI(self, *args, **kwargs): for dataset in self.datasets: dataset.enable_XTI(*args, **kwargs) @@ -2323,10 +2324,10 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc logger.info(f"[Dataset {i}]") dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix) - def new_cache_latents(self, is_main_process: bool, strategy: LatentsCachingStrategy): + def new_cache_latents(self, model: Any, is_main_process: bool): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") - dataset.new_cache_latents(is_main_process, strategy) + dataset.new_cache_latents(model, is_main_process) def cache_text_encoder_outputs( self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True @@ -2344,6 +2345,11 @@ def cache_text_encoder_outputs_sd3( tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size ) + def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): + for i, dataset in enumerate(self.datasets): + logger.info(f"[Dataset {i}]") + dataset.new_cache_text_encoder_outputs(models, is_main_process) + def set_caching_mode(self, caching_mode): for dataset in self.datasets: dataset.set_caching_mode(caching_mode) @@ -2358,6 +2364,10 @@ def is_latent_cacheable(self) -> bool: def is_text_encoder_output_cacheable(self) -> bool: return all([dataset.is_text_encoder_output_cacheable() for dataset in self.datasets]) + def set_current_strategies(self): + for dataset in self.datasets: + dataset.set_current_strategies() + def set_current_epoch(self, epoch): for dataset in self.datasets: dataset.set_current_epoch(epoch) @@ -2411,34 +2421,34 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) # TODO update to use CachingStrategy -def load_latents_from_disk( - npz_path, -) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: - npz = np.load(npz_path) - if "latents" not in npz: - raise ValueError(f"error: npz is old format. please re-generate {npz_path}") - - latents = npz["latents"] - original_size = npz["original_size"].tolist() - crop_ltrb = npz["crop_ltrb"].tolist() - flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None - alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None - return latents, original_size, crop_ltrb, flipped_latents, alpha_mask - - -def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None): - kwargs = {} - if flipped_latents_tensor is not None: - kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() - if alpha_mask is not None: - kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() - np.savez( - npz_path, - latents=latents_tensor.float().cpu().numpy(), - original_size=np.array(original_size), - crop_ltrb=np.array(crop_ltrb), - **kwargs, - ) +# def load_latents_from_disk( +# npz_path, +# ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: +# npz = np.load(npz_path) +# if "latents" not in npz: +# raise ValueError(f"error: npz is old format. please re-generate {npz_path}") + +# latents = npz["latents"] +# original_size = npz["original_size"].tolist() +# crop_ltrb = npz["crop_ltrb"].tolist() +# flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None +# alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None +# return latents, original_size, crop_ltrb, flipped_latents, alpha_mask + + +# def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None): +# kwargs = {} +# if flipped_latents_tensor is not None: +# kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() +# if alpha_mask is not None: +# kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() +# np.savez( +# npz_path, +# latents=latents_tensor.float().cpu().numpy(), +# original_size=np.array(original_size), +# crop_ltrb=np.array(crop_ltrb), +# **kwargs, +# ) def debug_dataset(train_dataset, show_input_ids=False): @@ -2465,12 +2475,12 @@ def debug_dataset(train_dataset, show_input_ids=False): example = train_dataset[idx] if example["latents"] is not None: logger.info(f"sample has latents from npz file: {example['latents'].size()}") - for j, (ik, cap, lw, iid, orgsz, crptl, trgsz, flpdz) in enumerate( + for j, (ik, cap, lw, orgsz, crptl, trgsz, flpdz) in enumerate( zip( example["image_keys"], example["captions"], example["loss_weights"], - example["input_ids"], + # example["input_ids"], example["original_sizes_hw"], example["crop_top_lefts"], example["target_sizes_hw"], @@ -2483,10 +2493,10 @@ def debug_dataset(train_dataset, show_input_ids=False): if "network_multipliers" in example: print(f"network multiplier: {example['network_multipliers'][j]}") - if show_input_ids: - logger.info(f"input ids: {iid}") - if "input_ids2" in example: - logger.info(f"input ids2: {example['input_ids2'][j]}") + # if show_input_ids: + # logger.info(f"input ids: {iid}") + # if "input_ids2" in example: + # logger.info(f"input ids2: {example['input_ids2'][j]}") if example["images"] is not None: im = example["images"][j] logger.info(f"image size: {im.size()}") @@ -2555,8 +2565,8 @@ def glob_images_pathlib(dir_path, recursive): class MinimalDataset(BaseDataset): - def __init__(self, tokenizer, max_token_length, resolution, network_multiplier, debug_dataset=False): - super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) + def __init__(self, resolution, network_multiplier, debug_dataset=False): + super().__init__(resolution, network_multiplier, debug_dataset) self.num_train_images = 0 # update in subclass self.num_reg_images = 0 # update in subclass @@ -2773,14 +2783,15 @@ def cache_batch_latents( raise RuntimeError(f"NaN detected in latents: {info.absolute_path}") if cache_to_disk: - save_latents_to_disk( - info.latents_npz, - latent, - info.latents_original_size, - info.latents_crop_ltrb, - flipped_latent, - alpha_mask, - ) + # save_latents_to_disk( + # info.latents_npz, + # latent, + # info.latents_original_size, + # info.latents_crop_ltrb, + # flipped_latent, + # alpha_mask, + # ) + pass else: info.latents = latent if flip_aug: @@ -4662,33 +4673,6 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): ) -def load_tokenizer(args: argparse.Namespace): - logger.info("prepare tokenizer") - original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH - - tokenizer: CLIPTokenizer = None - if args.tokenizer_cache_dir: - local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_")) - if os.path.exists(local_tokenizer_path): - logger.info(f"load tokenizer from cache: {local_tokenizer_path}") - tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2 - - if tokenizer is None: - if args.v2: - tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer") - else: - tokenizer = CLIPTokenizer.from_pretrained(original_path) - - if hasattr(args, "max_token_length") and args.max_token_length is not None: - logger.info(f"update token length: {args.max_token_length}") - - if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): - logger.info(f"save Tokenizer to cache: {local_tokenizer_path}") - tokenizer.save_pretrained(local_tokenizer_path) - - return tokenizer - - def prepare_accelerator(args: argparse.Namespace): """ this function also prepares deepspeed plugin @@ -5550,6 +5534,7 @@ def sample_images_common( ): """ StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した + TODO Use strategies here """ if steps == 0: diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py index ffa0d46de..e9e61af1b 100644 --- a/sd3_minimal_inference.py +++ b/sd3_minimal_inference.py @@ -24,7 +24,7 @@ logger = logging.getLogger(__name__) -from library import sd3_models, sd3_utils +from library import sd3_models, sd3_utils, strategy_sd3 def get_noise(seed, latent): @@ -145,6 +145,7 @@ def do_sample( parser.add_argument("--clip_g", type=str, required=False) parser.add_argument("--clip_l", type=str, required=False) parser.add_argument("--t5xxl", type=str, required=False) + parser.add_argument("--t5xxl_token_length", type=int, default=77, help="t5xxl token length, default: 77") parser.add_argument("--prompt", type=str, default="A photo of a cat") # parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders parser.add_argument("--negative_prompt", type=str, default="") @@ -247,7 +248,7 @@ def do_sample( # load tokenizers logger.info("Loading tokenizers...") - tokenizer = sd3_models.SD3Tokenizer(use_t5xxl) # combined tokenizer + tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length) # load models # logger.info("Create MMDiT from SD3 checkpoint...") @@ -320,12 +321,19 @@ def do_sample( # prepare embeddings logger.info("Encoding prompts...") - # embeds, pooled_embed - lg_out, t5_out, pooled = sd3_utils.get_cond(args.prompt, tokenizer, clip_l, clip_g, t5xxl) - cond = torch.cat([lg_out, t5_out], dim=-2), pooled + encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy() - lg_out, t5_out, pooled = sd3_utils.get_cond(args.negative_prompt, tokenizer, clip_l, clip_g, t5xxl) - neg_cond = torch.cat([lg_out, t5_out], dim=-2), pooled + l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.prompt) + lg_out, t5_out, pooled = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens] + ) + cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + + l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.negative_prompt) + lg_out, t5_out, pooled = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens] + ) + neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) # generate image logger.info("Generating image...") diff --git a/sd3_train.py b/sd3_train.py index f34e47124..617e30271 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -17,7 +17,7 @@ from accelerate.utils import set_seed from diffusers import DDPMScheduler -from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils +from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils, strategy_base, strategy_sd3 from library.sdxl_train_util import match_mixed_precision # , sdxl_model_util @@ -69,10 +69,22 @@ def train(args): # not args.train_text_encoder # ), "training text encoder is not supported currently / text encoderの学習は現在サポートされていません" - # training without text encoder cache is not supported - assert ( - args.cache_text_encoder_outputs - ), "training without text encoder cache is not supported currently / text encoderのキャッシュなしの学習は現在サポートされていません" + # # training without text encoder cache is not supported: because T5XXL must be cached + # assert ( + # args.cache_text_encoder_outputs + # ), "training without text encoder cache is not supported currently / text encoderのキャッシュなしの学習は現在サポートされていません" + + assert not args.train_text_encoder or (args.use_t5xxl_cache_only or not args.cache_text_encoder_outputs), ( + "when training text encoder, text encoder outputs must not be cached (except for T5XXL)" + + " / text encoderの学習時はtext encoderの出力はキャッシュできません(t5xxlのみキャッシュすることは可能です)" + ) + + if args.use_t5xxl_cache_only and not args.cache_text_encoder_outputs: + logger.warning( + "use_t5xxl_cache_only is enabled, so cache_text_encoder_outputs is automatically enabled." + + " / use_t5xxl_cache_onlyが有効なため、cache_text_encoder_outputsも自動的に有効になります" + ) + args.cache_text_encoder_outputs = True # if args.block_lr: # block_lrs = [float(lr) for lr in args.block_lr.split(",")] @@ -88,17 +100,17 @@ def train(args): if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - # load tokenizer - sd3_tokenizer = sd3_models.SD3Tokenizer() - - # prepare caching strategy - if args.new_caching: - latents_caching_strategy = sd3_train_utils.Sd3LatentsCachingStrategy( + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if args.cache_latents: + latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy( args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check ) - else: - latents_caching_strategy = None - train_util.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # load tokenizer and prepare tokenize strategy + sd3_tokenizer = sd3_models.SD3Tokenizer(t5xxl_max_length=args.t5xxl_max_token_length) + sd3_tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length) + strategy_base.TokenizeStrategy.set_strategy(sd3_tokenize_strategy) # データセットを準備する if args.dataset_class is None: @@ -153,6 +165,16 @@ def train(args): train_dataset_group.verify_bucket_reso_steps(8) # TODO これでいいか確認 if args.debug_dataset: + if args.cache_text_encoder_outputs: + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( + strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + False, + False, + ) + ) + train_dataset_group.set_current_strategies() train_util.debug_dataset(train_dataset_group, True) return if len(train_dataset_group) == 0: @@ -215,19 +237,8 @@ def train(args): vae.requires_grad_(False) vae.eval() - if not args.new_caching: - vae_wrapper = sd3_models.VAEWrapper(vae) # make SD/SDXL compatible - with torch.no_grad(): - train_dataset_group.cache_latents( - vae_wrapper, - args.vae_batch_size, - args.cache_latents_to_disk, - accelerator.is_main_process, - file_suffix="_sd3.npz", - ) - else: - latents_caching_strategy.set_vae(vae) - train_dataset_group.new_cache_latents(accelerator.is_main_process, latents_caching_strategy) + train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + vae.to("cpu") # if no sampling, vae can be deleted clean_memory_on_device(accelerator.device) @@ -246,60 +257,70 @@ def train(args): t5xxl = sd3_train_utils.load_target_model("t5xxl", args, sd3_state_dict, accelerator, attn_mode, t5xxl_dtype, device_to_load) # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype) + # should be deleted after caching text encoder outputs when not training text encoder + # this strategy should not be used other than this process + text_encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy() + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + # 学習を準備する:モデルを適切な状態にする train_clip_l = False train_clip_g = False train_t5xxl = False - # if args.train_text_encoder: - # # TODO each option for two text encoders? - # accelerator.print("enable text encoder training") - # if args.gradient_checkpointing: - # text_encoder1.gradient_checkpointing_enable() - # text_encoder2.gradient_checkpointing_enable() - # lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train - # lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train - # train_clip_l = lr_te1 != 0 - # train_clip_g = lr_te2 != 0 - - # # caching one text encoder output is not supported - # if not train_clip_l: - # text_encoder1.to(weight_dtype) - # if not train_clip_g: - # text_encoder2.to(weight_dtype) - # text_encoder1.requires_grad_(train_clip_l) - # text_encoder2.requires_grad_(train_clip_g) - # text_encoder1.train(train_clip_l) - # text_encoder2.train(train_clip_g) - # else: - clip_l.to(weight_dtype) - clip_g.to(weight_dtype) - clip_l.requires_grad_(False) - clip_g.requires_grad_(False) - clip_l.eval() - clip_g.eval() + if args.train_text_encoder: + accelerator.print("enable text encoder training") + if args.gradient_checkpointing: + clip_l.gradient_checkpointing_enable() + clip_g.gradient_checkpointing_enable() + lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train + lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train + train_clip_l = lr_te1 != 0 + train_clip_g = lr_te2 != 0 + + if not train_clip_l: + clip_l.to(weight_dtype) + if not train_clip_g: + clip_g.to(weight_dtype) + clip_l.requires_grad_(train_clip_l) + clip_g.requires_grad_(train_clip_g) + clip_l.train(train_clip_l) + clip_g.train(train_clip_g) + else: + clip_l.to(weight_dtype) + clip_g.to(weight_dtype) + clip_l.requires_grad_(False) + clip_g.requires_grad_(False) + clip_l.eval() + clip_g.eval() + if t5xxl is not None: t5xxl.to(t5xxl_dtype) t5xxl.requires_grad_(False) t5xxl.eval() - # TextEncoderの出力をキャッシュする + # cache text encoder outputs if args.cache_text_encoder_outputs: - # Text Encodes are eval and no grad - - with torch.no_grad(), accelerator.autocast(): - train_dataset_group.cache_text_encoder_outputs_sd3( - sd3_tokenizer, - (clip_l, clip_g, t5xxl), - (accelerator.device, accelerator.device, t5xxl_device), - None, - (None, None, None), - args.cache_text_encoder_outputs_to_disk, - accelerator.is_main_process, - args.text_encoder_batch_size, - ) + # Text Encodes are eval and no grad here + clip_l.to(accelerator.device) + clip_g.to(accelerator.device) + if t5xxl is not None: + t5xxl.to(t5xxl_device) + + text_encoder_caching_strategy = strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + False, + train_clip_g or train_clip_l or args.use_t5xxl_cache_only, + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) + + clip_l.to(accelerator.device, dtype=weight_dtype) + clip_g.to(accelerator.device, dtype=weight_dtype) + if t5xxl is not None: + t5xxl.to(t5xxl_device, dtype=t5xxl_dtype) - # TODO we can delete text encoders after caching + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator.is_main_process) accelerator.wait_for_everyone() # load MMDIT @@ -332,11 +353,11 @@ def train(args): # params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs)) # if train_clip_l: - # training_models.append(text_encoder1) - # params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate}) + # training_models.append(clip_l) + # params_to_optimize.append({"params": list(clip_l.parameters()), "lr": args.learning_rate_te1 or args.learning_rate}) # if train_clip_g: - # training_models.append(text_encoder2) - # params_to_optimize.append({"params": list(text_encoder2.parameters()), "lr": args.learning_rate_te2 or args.learning_rate}) + # training_models.append(clip_g) + # params_to_optimize.append({"params": list(clip_g.parameters()), "lr": args.learning_rate_te2 or args.learning_rate}) # calculate number of trainable parameters n_params = 0 @@ -344,7 +365,7 @@ def train(args): for p in group["params"]: n_params += p.numel() - accelerator.print(f"train mmdit: {train_mmdit}") # , text_encoder1: {train_clip_l}, text_encoder2: {train_clip_g}") + accelerator.print(f"train mmdit: {train_mmdit}") # , clip_l: {train_clip_l}, clip_g: {train_clip_g}") accelerator.print(f"number of models: {len(training_models)}") accelerator.print(f"number of trainable parameters: {n_params}") @@ -398,7 +419,11 @@ def train(args): else: _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( @@ -455,8 +480,8 @@ def train(args): # TODO check if this is necessary. SD3 uses pool for clip_l and clip_g # # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer # if train_clip_l: - # text_encoder1.text_model.encoder.layers[-1].requires_grad_(False) - # text_encoder1.text_model.final_layer_norm.requires_grad_(False) + # clip_l.text_model.encoder.layers[-1].requires_grad_(False) + # clip_l.text_model.final_layer_norm.requires_grad_(False) # TextEncoderの出力をキャッシュするときには、すでに出力を取得済みなのでCPUへ移動する if args.cache_text_encoder_outputs: @@ -484,9 +509,8 @@ def train(args): ds_model = deepspeed_utils.prepare_deepspeed_model( args, mmdit=mmdit, - # mmdie=mmdit if train_mmdit else None, - # text_encoder1=text_encoder1 if train_clip_l else None, - # text_encoder2=text_encoder2 if train_clip_g else None, + clip_l=clip_l if train_clip_l else None, + clip_g=clip_g if train_clip_g else None, ) # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( @@ -498,10 +522,10 @@ def train(args): # acceleratorがなんかよろしくやってくれるらしい if train_mmdit: mmdit = accelerator.prepare(mmdit) - # if train_clip_l: - # text_encoder1 = accelerator.prepare(text_encoder1) - # if train_clip_g: - # text_encoder2 = accelerator.prepare(text_encoder2) + if train_clip_l: + clip_l = accelerator.prepare(clip_l) + if train_clip_g: + clip_g = accelerator.prepare(clip_g) optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする @@ -613,7 +637,7 @@ def optimizer_hook(parameter: torch.Tensor): # # For --sample_at_first # sd3_train_utils.sample_images( - # accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], mmdit + # accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [clip_l, clip_g], mmdit # ) # following function will be moved to sd3_train_utils @@ -666,6 +690,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): return weighting loss_recorder = train_util.LossRecorder() + epoch = 0 # avoid error when max_train_steps is 0 for epoch in range(num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 @@ -687,37 +712,45 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): # encode images to latents. images are [-1, 1] latents = vae.encode(batch["images"].to(vae_dtype)).to(weight_dtype) - # NaNが含まれていれば警告を表示し0に置き換える - if torch.any(torch.isnan(latents)): - accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.nan_to_num(latents, 0, out=latents) + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + # latents = latents * sdxl_model_util.VAE_SCALE_FACTOR latents = sd3_models.SDVAE.process_in(latents) - if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: - # not cached, get text encoder outputs - # XXX This does not work yet - input_ids_clip_l, input_ids_clip_g, input_ids_t5xxl = batch["input_ids"] + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + lg_out, t5_out, lg_pooled = text_encoder_outputs_list + if args.use_t5xxl_cache_only: + lg_out = None + lg_pooled = None + else: + lg_out = None + t5_out = None + lg_pooled = None + + if lg_out is None or (train_clip_l or train_clip_g): + # not cached or training, so get from text encoders + input_ids_clip_l, input_ids_clip_g, _ = batch["input_ids_list"] with torch.set_grad_enabled(args.train_text_encoder): # TODO support weighted captions - # TODO support length > 75 input_ids_clip_l = input_ids_clip_l.to(accelerator.device) input_ids_clip_g = input_ids_clip_g.to(accelerator.device) - input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device) + lg_out, _, lg_pooled = text_encoding_strategy.encode_tokens( + sd3_tokenize_strategy, [clip_l, clip_g, None], [input_ids_clip_l, input_ids_clip_g, None] + ) - # get text encoder outputs: outputs are concatenated - context, pool = sd3_utils.get_cond_from_tokens( - input_ids_clip_l, input_ids_clip_g, input_ids_t5xxl, clip_l, clip_g, t5xxl + if t5_out is None: + _, _, input_ids_t5xxl = batch["input_ids_list"] + with torch.no_grad(): + input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device) if t5_out is None else None + _, t5_out, _ = text_encoding_strategy.encode_tokens( + sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl] ) - else: - # encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) - # encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) - # pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) - # TODO this reuses SDXL keys, it should be fixed - lg_out = batch["text_encoder_outputs1_list"] - t5_out = batch["text_encoder_outputs2_list"] - pool = batch["text_encoder_pool2_list"] - context = torch.cat([lg_out, t5_out], dim=-2) + + context, lg_pooled = text_encoding_strategy.concat_encodings(lg_out, t5_out, lg_pooled) # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps @@ -748,13 +781,13 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): if torch.any(torch.isnan(context)): accelerator.print("NaN found in context, replacing with zeros") context = torch.nan_to_num(context, 0, out=context) - if torch.any(torch.isnan(pool)): + if torch.any(torch.isnan(lg_pooled)): accelerator.print("NaN found in pool, replacing with zeros") - pool = torch.nan_to_num(pool, 0, out=pool) + lg_pooled = torch.nan_to_num(lg_pooled, 0, out=lg_pooled) # call model with accelerator.autocast(): - model_pred = mmdit(noisy_model_input, timesteps, context=context, y=pool) + model_pred = mmdit(noisy_model_input, timesteps, context=context, y=lg_pooled) # Follow: Section 5 of https://arxiv.org/abs/2206.00364. # Preconditioning of the model outputs. @@ -806,7 +839,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): # accelerator.device, # vae, # [tokenizer1, tokenizer2], - # [text_encoder1, text_encoder2], + # [clip_l, clip_g], # mmdit, # ) @@ -875,7 +908,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): # accelerator.device, # vae, # [tokenizer1, tokenizer2], - # [text_encoder1, text_encoder2], + # [clip_l, clip_g], # mmdit, # ) @@ -924,7 +957,19 @@ def setup_parser() -> argparse.ArgumentParser: custom_train_functions.add_custom_train_arguments(parser) sd3_train_utils.add_sd3_training_arguments(parser) - # parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") + parser.add_argument( + "--train_text_encoder", action="store_true", help="train text encoder (CLIP-L and G) / text encoderも学習する" + ) + # parser.add_argument("--train_t5xxl", action="store_true", help="train T5-XXL / T5-XXLも学習する") + parser.add_argument( + "--use_t5xxl_cache_only", action="store_true", help="cache T5-XXL outputs only / T5-XXLの出力のみキャッシュする" + ) + parser.add_argument( + "--t5xxl_max_token_length", + type=int, + default=None, + help="maximum token length for T5-XXL. 256 if omitted / T5-XXLの最大トークン数。省略時は256", + ) # TE training is disabled temporarily # parser.add_argument( @@ -962,7 +1007,6 @@ def setup_parser() -> argparse.ArgumentParser: help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", ) - parser.add_argument("--new_caching", action="store_true", help="use new caching method / 新しいキャッシング方法を使う") parser.add_argument( "--skip_latents_validity_check", action="store_true", diff --git a/sdxl_train.py b/sdxl_train.py index ae92d6a3d..b6d4afd6a 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -17,7 +17,7 @@ from accelerate.utils import set_seed from diffusers import DDPMScheduler -from library import deepspeed_utils, sdxl_model_util +from library import deepspeed_utils, sdxl_model_util, strategy_base, strategy_sd, strategy_sdxl import library.train_util as train_util @@ -124,7 +124,16 @@ def train(args): if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + tokenizers = [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2] # will be removed in the future + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if args.cache_latents: + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, False + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する if args.dataset_class is None: @@ -166,10 +175,10 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer1, tokenizer2]) + train_dataset_group = train_util.load_arbitrary_dataset(args) current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -262,8 +271,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + + train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -276,6 +286,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): train_text_encoder1 = False train_text_encoder2 = False + text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy() + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + if args.train_text_encoder: # TODO each option for two text encoders? accelerator.print("enable text encoder training") @@ -307,16 +320,17 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # TextEncoderの出力をキャッシュする if args.cache_text_encoder_outputs: # Text Encodes are eval and no grad - with torch.no_grad(), accelerator.autocast(): - train_dataset_group.cache_text_encoder_outputs( - (tokenizer1, tokenizer2), - (text_encoder1, text_encoder2), - accelerator.device, - None, - args.cache_text_encoder_outputs_to_disk, - accelerator.is_main_process, - ) - accelerator.wait_for_everyone() + text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, False + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy) + + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) + + accelerator.wait_for_everyone() if not cache_latents: vae.requires_grad_(False) @@ -403,7 +417,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): else: _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( @@ -597,7 +615,7 @@ def optimizer_hook(parameter: torch.Tensor): # For --sample_at_first sdxl_train_util.sample_images( - accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet + accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, [text_encoder1, text_encoder2], unet ) loss_recorder = train_util.LossRecorder() @@ -628,9 +646,15 @@ def optimizer_hook(parameter: torch.Tensor): latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR - if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: - input_ids1 = batch["input_ids"] - input_ids2 = batch["input_ids2"] + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + # Text Encoder outputs are cached + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list + encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype) + pool2 = pool2.to(accelerator.device, dtype=weight_dtype) + else: + input_ids1, input_ids2 = batch["input_ids_list"] with torch.set_grad_enabled(args.train_text_encoder): # Get the text embedding for conditioning # TODO support weighted captions @@ -646,39 +670,13 @@ def optimizer_hook(parameter: torch.Tensor): # else: input_ids1 = input_ids1.to(accelerator.device) input_ids2 = input_ids2.to(accelerator.device) - # unwrap_model is fine for models not wrapped by accelerator - encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( - args.max_token_length, - input_ids1, - input_ids2, - tokenizer1, - tokenizer2, - text_encoder1, - text_encoder2, - None if not args.full_fp16 else weight_dtype, - accelerator=accelerator, + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2] ) - else: - encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) - encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) - pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) - - # # verify that the text encoder outputs are correct - # ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl( - # args.max_token_length, - # batch["input_ids"].to(text_encoder1.device), - # batch["input_ids2"].to(text_encoder1.device), - # tokenizer1, - # tokenizer2, - # text_encoder1, - # text_encoder2, - # None if not args.full_fp16 else weight_dtype, - # ) - # b_size = encoder_hidden_states1.shape[0] - # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # logger.info("text encoder outputs verified") + if args.full_fp16: + encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype) + pool2 = pool2.to(weight_dtype) # get size embeddings orig_size = batch["original_sizes_hw"] @@ -765,7 +763,7 @@ def optimizer_hook(parameter: torch.Tensor): global_step, accelerator.device, vae, - [tokenizer1, tokenizer2], + tokenizers, [text_encoder1, text_encoder2], unet, ) @@ -847,7 +845,7 @@ def optimizer_hook(parameter: torch.Tensor): global_step, accelerator.device, vae, - [tokenizer1, tokenizer2], + tokenizers, [text_encoder1, text_encoder2], unet, ) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 5ff060a9f..0eaec29b8 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -23,7 +23,16 @@ import accelerate from diffusers import DDPMScheduler, ControlNetModel from safetensors.torch import load_file -from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util +from library import ( + deepspeed_utils, + sai_model_spec, + sdxl_model_util, + sdxl_original_unet, + sdxl_train_util, + strategy_base, + strategy_sd, + strategy_sdxl, +) import library.model_util as model_util import library.train_util as train_util @@ -79,7 +88,14 @@ def train(args): args.seed = random.randint(0, 2**32) set_seed(args.seed) - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, False + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) @@ -106,7 +122,7 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) @@ -164,30 +180,30 @@ def train(args): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents( - vae, - args.vae_batch_size, - args.cache_latents_to_disk, - accelerator.is_main_process, - ) + + train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + vae.to("cpu") clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() + text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy() + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + # TextEncoderの出力をキャッシュする if args.cache_text_encoder_outputs: # Text Encodes are eval and no grad - with torch.no_grad(): - train_dataset_group.cache_text_encoder_outputs( - (tokenizer1, tokenizer2), - (text_encoder1, text_encoder2), - accelerator.device, - None, - args.cache_text_encoder_outputs_to_disk, - accelerator.is_main_process, - ) + text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, False + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy) + + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) + accelerator.wait_for_everyone() # prepare ControlNet-LLLite @@ -242,7 +258,11 @@ def train(args): _, _, optimizer = train_util.get_optimizer(args, trainable_params) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers @@ -290,7 +310,7 @@ def train(args): unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) if isinstance(unet, DDP): - unet._set_static_graph() # avoid error for multiple use of the parameter + unet._set_static_graph() # avoid error for multiple use of the parameter if args.gradient_checkpointing: unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる @@ -357,7 +377,9 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs + "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) loss_recorder = train_util.LossRecorder() @@ -409,27 +431,26 @@ def remove_model(old_ckpt_name): latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR - if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: - input_ids1 = batch["input_ids"] - input_ids2 = batch["input_ids2"] + + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + # Text Encoder outputs are cached + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list + encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype) + pool2 = pool2.to(accelerator.device, dtype=weight_dtype) + else: + input_ids1, input_ids2 = batch["input_ids_list"] with torch.no_grad(): - # Get the text embedding for conditioning input_ids1 = input_ids1.to(accelerator.device) input_ids2 = input_ids2.to(accelerator.device) - encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( - args.max_token_length, - input_ids1, - input_ids2, - tokenizer1, - tokenizer2, - text_encoder1, - text_encoder2, - None if not args.full_fp16 else weight_dtype, + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2] ) - else: - encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) - encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) - pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) + if args.full_fp16: + encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype) + pool2 = pool2.to(weight_dtype) # get size embeddings orig_size = batch["original_sizes_hw"] diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 83969bb1d..67ccae62c 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -1,16 +1,21 @@ import argparse import torch +from accelerate import Accelerator from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() -from library import sdxl_model_util, sdxl_train_util, train_util +from library import sdxl_model_util, sdxl_train_util, strategy_base, strategy_sd, strategy_sdxl, train_util import train_network from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) + class SdxlNetworkTrainer(train_network.NetworkTrainer): def __init__(self): super().__init__() @@ -49,15 +54,32 @@ def load_target_model(self, args, weight_dtype, accelerator): return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet - def load_tokenizer(self, args): - tokenizer = sdxl_train_util.load_tokenizers(args) - return tokenizer + def get_tokenize_strategy(self, args): + return strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy): + return [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, False + ) + return latents_caching_strategy - def is_text_encoder_outputs_cached(self, args): - return args.cache_text_encoder_outputs + def get_text_encoding_strategy(self, args): + return strategy_sdxl.SdxlTextEncodingStrategy() + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + return text_encoders + [accelerator.unwrap_model(text_encoders[-1])] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(args.cache_text_encoder_outputs_to_disk, None, False) + else: + return None def cache_text_encoder_outputs_if_needed( - self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset: train_util.DatasetGroup, weight_dtype + self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype ): if args.cache_text_encoder_outputs: if not args.lowram: @@ -70,15 +92,13 @@ def cache_text_encoder_outputs_if_needed( clean_memory_on_device(accelerator.device) # When TE is not be trained, it will not be prepared so we need to use explicit autocast + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device, dtype=weight_dtype) with accelerator.autocast(): - dataset.cache_text_encoder_outputs( - tokenizers, - text_encoders, - accelerator.device, - weight_dtype, - args.cache_text_encoder_outputs_to_disk, - accelerator.is_main_process, + dataset.new_cache_text_encoder_outputs( + text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator.is_main_process ) + accelerator.wait_for_everyone() text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU text_encoders[1].to("cpu", dtype=torch.float32) diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index 5df739e28..cbfcef554 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -5,10 +5,10 @@ import torch from library.device_utils import init_ipex -init_ipex() -from library import sdxl_model_util, sdxl_train_util, train_util +init_ipex() +from library import sdxl_model_util, sdxl_train_util, strategy_sd, strategy_sdxl, train_util import train_textual_inversion @@ -41,28 +41,20 @@ def load_target_model(self, args, weight_dtype, accelerator): return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet - def load_tokenizer(self, args): - tokenizer = sdxl_train_util.load_tokenizers(args) - return tokenizer - - def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): - input_ids1 = batch["input_ids"] - input_ids2 = batch["input_ids2"] - with torch.enable_grad(): - input_ids1 = input_ids1.to(accelerator.device) - input_ids2 = input_ids2.to(accelerator.device) - encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( - args.max_token_length, - input_ids1, - input_ids2, - tokenizers[0], - tokenizers[1], - text_encoders[0], - text_encoders[1], - None if not args.full_fp16 else weight_dtype, - accelerator=accelerator, - ) - return encoder_hidden_states1, encoder_hidden_states2, pool2 + def get_tokenize_strategy(self, args): + return strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy): + return [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, False + ) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_sdxl.SdxlTextEncodingStrategy() def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -81,9 +73,11 @@ def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_cond noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) return noise_pred - def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement): + def sample_images( + self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement + ): sdxl_train_util.sample_images( - accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement + accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement ) def save_weights(self, file, updated_embs, save_dtype, metadata): @@ -122,8 +116,7 @@ def load_weights(self, file): def setup_parser() -> argparse.ArgumentParser: parser = train_textual_inversion.setup_parser() - # don't add sdxl_train_util.add_sdxl_training_arguments(parser): because it only adds text encoder caching - # sdxl_train_util.add_sdxl_training_arguments(parser) + sdxl_train_util.add_sdxl_training_arguments(parser, support_text_encoder_caching=False) return parser diff --git a/train_db.py b/train_db.py index 39d8ea6ed..7caee6647 100644 --- a/train_db.py +++ b/train_db.py @@ -11,7 +11,7 @@ from tqdm import tqdm import torch -from library import deepspeed_utils +from library import deepspeed_utils, strategy_base from library.device_utils import init_ipex, clean_memory_on_device @@ -38,6 +38,7 @@ apply_masked_loss, ) from library.utils import setup_logging, add_logging_arguments +import library.strategy_sd as strategy_sd setup_logging() import logging @@ -58,7 +59,14 @@ def train(args): if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - tokenizer = train_util.load_tokenizer(args) + tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, False + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する if args.dataset_class is None: @@ -80,10 +88,10 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) + train_dataset_group = train_util.load_arbitrary_dataset(args) current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -145,13 +153,17 @@ def train(args): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + + train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + vae.to("cpu") clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() + text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + # 学習を準備する:モデルを適切な状態にする train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0 unet.requires_grad_(True) # 念のため追加 @@ -184,8 +196,11 @@ def train(args): _, _, optimizer = train_util.get_optimizer(args, trainable_params) - # dataloaderを準備する - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, @@ -290,10 +305,16 @@ def train(args): init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs) + accelerator.init_trackers( + "dreambooth" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) # For --sample_at_first - train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + train_util.sample_images( + accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) loss_recorder = train_util.LossRecorder() for epoch in range(num_train_epochs): @@ -331,7 +352,7 @@ def train(args): with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): if args.weighted_captions: encoder_hidden_states = get_weighted_text_embeddings( - tokenizer, + tokenize_strategy.tokenizer, text_encoder, batch["captions"], accelerator.device, @@ -339,14 +360,18 @@ def train(args): clip_skip=args.clip_skip, ) else: - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype - ) + input_ids = batch["input_ids_list"][0].to(accelerator.device) + encoder_hidden_states = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder], [input_ids] + )[0] + if args.full_fp16: + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) # Predict the noise residual with accelerator.autocast(): @@ -358,7 +383,9 @@ def train(args): else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) @@ -393,7 +420,7 @@ def train(args): global_step += 1 train_util.sample_images( - accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet + accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet ) # 指定ステップごとにモデルを保存 @@ -457,7 +484,9 @@ def train(args): vae, ) - train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + train_util.sample_images( + accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) is_main_process = accelerator.is_main_process if is_main_process: diff --git a/train_network.py b/train_network.py index 7ba073855..3828fed19 100644 --- a/train_network.py +++ b/train_network.py @@ -7,6 +7,7 @@ import time import json from multiprocessing import Value +from typing import Any, List import toml from tqdm import tqdm @@ -18,7 +19,7 @@ from accelerate.utils import set_seed from diffusers import DDPMScheduler -from library import deepspeed_utils, model_util +from library import deepspeed_utils, model_util, strategy_base, strategy_sd import library.train_util as train_util from library.train_util import DreamBoothDataset @@ -101,19 +102,31 @@ def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet - def load_tokenizer(self, args): - tokenizer = train_util.load_tokenizer(args) - return tokenizer + def get_tokenize_strategy(self, args): + return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) - def is_text_encoder_outputs_cached(self, args): - return False + def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> List[Any]: + return [tokenize_strategy.tokenizer] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + True, args.cache_latents_to_disk, args.vae_batch_size, False + ) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_sd.SdTextEncodingStrategy(args.clip_skip) + + def get_text_encoder_outputs_caching_strategy(self, args): + return None + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + return text_encoders def is_train_text_encoder(self, args): - return not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args) + return not args.network_train_unet_only - def cache_text_encoder_outputs_if_needed( - self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype - ): + def cache_text_encoder_outputs_if_needed(self, args, accelerator, unet, vae, text_encoders, dataset, weight_dtype): for t_enc in text_encoders: t_enc.to(accelerator.device, dtype=weight_dtype) @@ -123,7 +136,7 @@ def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, wei return encoder_hidden_states def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): - noise_pred = unet(noisy_latents, timesteps, text_conds).sample + noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample return noise_pred def all_reduce_network(self, accelerator, network): @@ -131,8 +144,8 @@ def all_reduce_network(self, accelerator, network): if param.grad is not None: param.grad = accelerator.reduce(param.grad, reduction="mean") - def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): - train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoder, unet): + train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoder, unet) def train(self, args): session_id = random.randint(0, 2**32) @@ -150,9 +163,13 @@ def train(self, args): args.seed = random.randint(0, 2**32) set_seed(args.seed) - # tokenizerは単体またはリスト、tokenizersは必ずリスト:既存のコードとの互換性のため - tokenizer = self.load_tokenizer(args) - tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer] + tokenize_strategy = self.get_tokenize_strategy(args) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = self.get_latents_caching_strategy(args) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する if args.dataset_class is None: @@ -194,11 +211,11 @@ def train(self, args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary dataset class - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) + train_dataset_group = train_util.load_arbitrary_dataset(args) current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -268,8 +285,9 @@ def train(self, args): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + + train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -277,9 +295,13 @@ def train(self, args): # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される # cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu - self.cache_text_encoder_outputs_if_needed( - args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype - ) + text_encoding_strategy = self.get_text_encoding_strategy(args) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + text_encoder_outputs_caching_strategy = self.get_text_encoder_outputs_caching_strategy(args) + if text_encoder_outputs_caching_strategy is not None: + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy) + self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, train_dataset_group, weight_dtype) # prepare network net_kwargs = {} @@ -366,7 +388,11 @@ def train(self, args): optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers @@ -878,7 +904,7 @@ def remove_model(old_ckpt_name): os.remove(old_ckpt_file) # For --sample_at_first - self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) # training loop if initial_step > 0: # only if skip_until_initial_step is specified @@ -933,21 +959,31 @@ def remove_model(old_ckpt_name): # print(f"set multiplier: {multipliers}") accelerator.unwrap_model(network).set_multiplier(multipliers) - with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): - # Get the text embedding for conditioning - if args.weighted_captions: - text_encoder_conds = get_weighted_text_embeddings( - tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) - else: - text_encoder_conds = self.get_text_cond( - args, accelerator, batch, tokenizers, text_encoders, weight_dtype - ) + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs + else: + with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): + # Get the text embedding for conditioning + if args.weighted_captions: + # SD only + text_encoder_conds = get_weighted_text_embeddings( + tokenizers[0], + text_encoder, + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, + ) + else: + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + text_encoder_conds = text_encoding_strategy.encode_tokens( + tokenize_strategy, + self.get_models_for_text_encoding(args, accelerator, text_encoders), + input_ids, + ) + if args.full_fp16: + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified @@ -1026,7 +1062,9 @@ def remove_model(old_ckpt_name): progress_bar.update(1) global_step += 1 - self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + self.sample_images( + accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet + ) # 指定ステップごとにモデルを保存 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: @@ -1082,7 +1120,7 @@ def remove_model(old_ckpt_name): if args.save_state: train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) - self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) # end of epoch diff --git a/train_textual_inversion.py b/train_textual_inversion.py index ade077c36..9044f50df 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -2,6 +2,7 @@ import math import os from multiprocessing import Value +from typing import Any, List import toml from tqdm import tqdm @@ -15,7 +16,7 @@ from accelerate.utils import set_seed from diffusers import DDPMScheduler from transformers import CLIPTokenizer -from library import deepspeed_utils, model_util +from library import deepspeed_utils, model_util, strategy_base, strategy_sd import library.train_util as train_util import library.huggingface_util as huggingface_util @@ -103,28 +104,38 @@ def assert_extra_args(self, args, train_dataset_group): def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) - return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet + return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), [text_encoder], vae, unet - def load_tokenizer(self, args): - tokenizer = train_util.load_tokenizer(args) - return tokenizer + def get_tokenize_strategy(self, args): + return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> List[Any]: + return [tokenize_strategy.tokenizer] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + True, args.cache_latents_to_disk, args.vae_batch_size, False + ) + return latents_caching_strategy def assert_token_string(self, token_string, tokenizers: CLIPTokenizer): pass - def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): - with torch.enable_grad(): - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], None) - return encoder_hidden_states + def get_text_encoding_strategy(self, args): + return strategy_sd.SdTextEncodingStrategy(args.clip_skip) + + def get_models_for_text_encoding(self, args, accelerator, text_encoders) -> List[Any]: + return text_encoders def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): - noise_pred = unet(noisy_latents, timesteps, text_conds).sample + noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample return noise_pred - def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement): + def sample_images( + self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement + ): train_util.sample_images( - accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement + accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoders[0], unet, prompt_replacement ) def save_weights(self, file, updated_embs, save_dtype, metadata): @@ -182,8 +193,13 @@ def train(self, args): if args.seed is not None: set_seed(args.seed) - tokenizer_or_list = self.load_tokenizer(args) # list of tokenizer or tokenizer - tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list] + tokenize_strategy = self.get_tokenize_strategy(args) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = self.get_latents_caching_strategy(args) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # acceleratorを準備する logger.info("prepare accelerator") @@ -194,14 +210,7 @@ def train(self, args): vae_dtype = torch.float32 if args.no_half_vae else weight_dtype # モデルを読み込む - model_version, text_encoder_or_list, vae, unet = self.load_target_model(args, weight_dtype, accelerator) - text_encoders = [text_encoder_or_list] if not isinstance(text_encoder_or_list, list) else text_encoder_or_list - - if len(text_encoders) > 1 and args.gradient_accumulation_steps > 1: - accelerator.print( - "accelerate doesn't seem to support gradient_accumulation_steps for multiple models (text encoders) / " - + "accelerateでは複数のモデル(テキストエンコーダー)のgradient_accumulation_stepsはサポートされていないようです" - ) + model_version, text_encoders, vae, unet = self.load_target_model(args, weight_dtype, accelerator) # Convert the init_word to token_id init_token_ids_list = [] @@ -310,10 +319,10 @@ def train(self, args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer_or_list) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer_or_list) + train_dataset_group = train_util.load_arbitrary_dataset(args) self.assert_extra_args(args, train_dataset_group) @@ -368,11 +377,10 @@ def train(self, args): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) - vae.to("cpu") - clean_memory_on_device(accelerator.device) + train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() if args.gradient_checkpointing: @@ -387,7 +395,11 @@ def train(self, args): trainable_params += text_encoder.get_input_embeddings().parameters() _, _, optimizer = train_util.get_optimizer(args, trainable_params) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( @@ -415,20 +427,8 @@ def train(self, args): lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) # acceleratorがなんかよろしくやってくれるらしい - if len(text_encoders) == 1: - text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoder_or_list, optimizer, train_dataloader, lr_scheduler - ) - - elif len(text_encoders) == 2: - text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler - ) - - text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2] - - else: - raise NotImplementedError() + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + text_encoders = [accelerator.prepare(text_encoder) for text_encoder in text_encoders] index_no_updates_list = [] orig_embeds_params_list = [] @@ -456,6 +456,9 @@ def train(self, args): else: unet.eval() + text_encoding_strategy = self.get_text_encoding_strategy(args) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する vae.requires_grad_(False) vae.eval() @@ -510,7 +513,9 @@ def train(self, args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs + "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) # function for saving/removing @@ -540,8 +545,8 @@ def remove_model(old_ckpt_name): global_step, accelerator.device, vae, - tokenizer_or_list, - text_encoder_or_list, + tokenizers, + text_encoders, unet, prompt_replacement, ) @@ -568,7 +573,12 @@ def remove_model(old_ckpt_name): latents = latents * self.vae_scale_factor # Get the text embedding for conditioning - text_encoder_conds = self.get_text_cond(args, accelerator, batch, tokenizers, text_encoders, weight_dtype) + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + text_encoder_conds = text_encoding_strategy.encode_tokens( + tokenize_strategy, self.get_models_for_text_encoding(args, accelerator, text_encoders), input_ids + ) + if args.full_fp16: + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified @@ -588,7 +598,9 @@ def remove_model(old_ckpt_name): else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) @@ -639,8 +651,8 @@ def remove_model(old_ckpt_name): global_step, accelerator.device, vae, - tokenizer_or_list, - text_encoder_or_list, + tokenizers, + text_encoders, unet, prompt_replacement, ) @@ -722,8 +734,8 @@ def remove_model(old_ckpt_name): global_step, accelerator.device, vae, - tokenizer_or_list, - text_encoder_or_list, + tokenizers, + text_encoders, unet, prompt_replacement, ) From 1a977e847a10975c042c0fdacd871a33c9e93900 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 27 Jul 2024 13:51:50 +0900 Subject: [PATCH 046/348] fix typos --- library/strategy_base.py | 2 +- library/strategy_sd.py | 2 +- library/strategy_sd3.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/library/strategy_base.py b/library/strategy_base.py index 594cca5eb..a99a08290 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -218,7 +218,7 @@ def is_disk_cached_latents_expected( def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): raise NotImplementedError - def _defualt_is_disk_cached_latents_expected( + def _default_is_disk_cached_latents_expected( self, latents_stride: int, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool ): if not self.cache_to_disk: diff --git a/library/strategy_sd.py b/library/strategy_sd.py index 105816145..83ffaa31b 100644 --- a/library/strategy_sd.py +++ b/library/strategy_sd.py @@ -125,7 +125,7 @@ def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): - return self._defualt_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) # TODO remove circular dependency for ImageInfo def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py index 42630ab22..7491e814f 100644 --- a/library/strategy_sd3.py +++ b/library/strategy_sd3.py @@ -177,7 +177,7 @@ def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) ) def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): - return self._defualt_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) # TODO remove circular dependency for ImageInfo def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): From 002d75179ae5a3b165a65c5cf49c00bf8f98e2df Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 29 Jul 2024 23:18:34 +0900 Subject: [PATCH 047/348] sample images for training --- library/sd3_train_utils.py | 348 ++++++++++++++++++++++++++++++++++++- sd3_train.py | 51 +++--- 2 files changed, 367 insertions(+), 32 deletions(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 8f99d9474..da0729506 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -1,14 +1,18 @@ import argparse -import glob import math import os -from typing import List, Optional, Tuple, Union +import toml +import json +import time +from typing import Dict, List, Optional, Tuple, Union import torch from safetensors.torch import save_file -from accelerate import Accelerator +from accelerate import Accelerator, PartialState +from tqdm import tqdm +from PIL import Image -from library import sd3_models, sd3_utils, train_util +from library import sd3_models, sd3_utils, strategy_base, train_util from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -276,10 +280,342 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin ) -def sample_images(*args, **kwargs): - return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs) +# temporary copied from sd3_minimal_inferece.py +def get_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps): + start = sampling.timestep(sampling.sigma_max) + end = sampling.timestep(sampling.sigma_min) + timesteps = torch.linspace(start, end, steps) + sigs = [] + for x in range(len(timesteps)): + ts = timesteps[x] + sigs.append(sampling.sigma(ts)) + sigs += [0.0] + return torch.FloatTensor(sigs) + + +def max_denoise(model_sampling, sigmas): + max_sigma = float(model_sampling.sigma_max) + sigma = float(sigmas[0]) + return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma + + +def do_sample( + height: int, + width: int, + seed: int, + cond: Tuple[torch.Tensor, torch.Tensor], + neg_cond: Tuple[torch.Tensor, torch.Tensor], + mmdit: sd3_models.MMDiT, + steps: int, + guidance_scale: float, + dtype: torch.dtype, + device: str, +): + latent = torch.zeros(1, 16, height // 8, width // 8, device=device) + latent = latent.to(dtype).to(device) + + # noise = get_noise(seed, latent).to(device) + if seed is not None: + generator = torch.manual_seed(seed) + noise = ( + torch.randn(latent.size(), dtype=torch.float32, layout=latent.layout, generator=generator, device="cpu") + .to(latent.dtype) + .to(device) + ) + + model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3 + + sigmas = get_sigmas(model_sampling, steps).to(device) + + noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas)) + + c_crossattn = torch.cat([cond[0], neg_cond[0]]).to(device).to(dtype) + y = torch.cat([cond[1], neg_cond[1]]).to(device).to(dtype) + + x = noise_scaled.to(device).to(dtype) + # print(x.shape) + + with torch.no_grad(): + for i in tqdm(range(len(sigmas) - 1)): + sigma_hat = sigmas[i] + + timestep = model_sampling.timestep(sigma_hat).float() + timestep = torch.FloatTensor([timestep, timestep]).to(device) + + x_c_nc = torch.cat([x, x], dim=0) + # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape) + + model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) + model_output = model_output.float() + batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) + + pos_out, neg_out = batched.chunk(2) + denoised = neg_out + (pos_out - neg_out) * guidance_scale + # print(denoised.shape) + + # d = to_d(x, sigma_hat, denoised) + dims_to_append = x.ndim - sigma_hat.ndim + sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append] + # print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape) + """Converts a denoiser output to a Karras ODE derivative.""" + d = (x - denoised) / sigma_hat_dims + + dt = sigmas[i + 1] - sigma_hat + + # Euler method + x = x + d * dt + x = x.to(dtype) + + return x + + +def load_prompts(prompt_file: str) -> List[Dict]: + # read prompts + if prompt_file.endswith(".txt"): + with open(prompt_file, "r", encoding="utf-8") as f: + lines = f.readlines() + prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"] + elif prompt_file.endswith(".toml"): + with open(prompt_file, "r", encoding="utf-8") as f: + data = toml.load(f) + prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]] + elif prompt_file.endswith(".json"): + with open(prompt_file, "r", encoding="utf-8") as f: + prompts = json.load(f) + + # preprocess prompts + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + from library.train_util import line_to_prompt_dict + + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) + + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. + prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + + return prompts + + +def sample_images( + accelerator: Accelerator, + args: argparse.Namespace, + epoch, + steps, + mmdit, + vae, + text_encoders, + sample_prompts_te_outputs, + prompt_replacement=None, +): + if steps == 0: + if not args.sample_at_first: + return + else: + if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: + return + if args.sample_every_n_epochs is not None: + # sample_every_n_steps は無視する + if epoch is None or epoch % args.sample_every_n_epochs != 0: + return + else: + if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + return + + logger.info("") + logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") + if not os.path.isfile(args.sample_prompts): + logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") + return + + distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here + + # unwrap unet and text_encoder(s) + mmdit = accelerator.unwrap_model(mmdit) + text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] + # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) + + prompts = load_prompts(args.sample_prompts) + + save_dir = args.output_dir + "/sample" + os.makedirs(save_dir, exist_ok=True) + + # save random state to restore later + rng_state = torch.get_rng_state() + cuda_rng_state = None + try: + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + except Exception: + pass + + org_vae_device = vae.device # will be on cpu + vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device + + if distributed_state.num_processes <= 1: + # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. + with torch.no_grad(): + for prompt_dict in prompts: + sample_image_inference( + accelerator, + args, + mmdit, + text_encoders, + vae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + ) + else: + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + per_process_prompts = [] # list of lists + for i in range(distributed_state.num_processes): + per_process_prompts.append(prompts[i :: distributed_state.num_processes]) + + with torch.no_grad(): + with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: + for prompt_dict in prompt_dict_lists[0]: + sample_image_inference( + accelerator, + args, + mmdit, + text_encoders, + vae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + ) + + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + + vae.to(org_vae_device) + + clean_memory_on_device(accelerator.device) + + +def sample_image_inference( + accelerator: Accelerator, + args: argparse.Namespace, + mmdit: sd3_models.MMDiT, + text_encoders: List[Union[sd3_models.SDClipModel, sd3_models.SDXLClipG, sd3_models.T5XXLModel]], + vae: sd3_models.SDVAE, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, +): + assert isinstance(prompt_dict, dict) + negative_prompt = prompt_dict.get("negative_prompt") + sample_steps = prompt_dict.get("sample_steps", 30) + width = prompt_dict.get("width", 512) + height = prompt_dict.get("height", 512) + scale = prompt_dict.get("scale", 7.5) + seed = prompt_dict.get("seed") + # controlnet_image = prompt_dict.get("controlnet_image") + prompt: str = prompt_dict.get("prompt", "") + # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + else: + # True random sample image generation + torch.seed() + torch.cuda.seed() + + if negative_prompt is None: + negative_prompt = "" + + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 + logger.info(f"prompt: {prompt}") + logger.info(f"negative_prompt: {negative_prompt}") + logger.info(f"height: {height}") + logger.info(f"width: {width}") + logger.info(f"sample_steps: {sample_steps}") + logger.info(f"scale: {scale}") + # logger.info(f"sample_sampler: {sampler_name}") + if seed is not None: + logger.info(f"seed: {seed}") + + # encode prompts + tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() + encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + + if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs: + te_outputs = sample_prompts_te_outputs[prompt] + else: + l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(prompt) + te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens]) + + lg_out, t5_out, pooled = te_outputs + cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + + # encode negative prompts + if sample_prompts_te_outputs and negative_prompt in sample_prompts_te_outputs: + neg_te_outputs = sample_prompts_te_outputs[negative_prompt] + else: + l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(negative_prompt) + neg_te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens]) + + lg_out, t5_out, pooled = neg_te_outputs + neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + + # sample image + latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, mmdit.dtype, accelerator.device) + latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype)) + + # latent to image + with torch.no_grad(): + image = vae.decode(latents) + image = image.float() + image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] + decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2) + decoded_np = decoded_np.astype(np.uint8) + + image = Image.fromarray(decoded_np) + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + # but adding 'enum' to the filename should be enough + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + i: int = prompt_dict["enum"] + img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" + image.save(os.path.join(save_dir, img_filename)) + + # wandb有効時のみログを送信 + try: + wandb_tracker = accelerator.get_tracker("wandb") + try: + import wandb + except ImportError: # 事前に一度確認するのでここはエラー出ないはず + raise ImportError("No wandb / wandb がインストールされていないようです") + + wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) + except: # wandb 無効時 + pass + # region Diffusers diff --git a/sd3_train.py b/sd3_train.py index 617e30271..2f4ea8cb2 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -299,6 +299,7 @@ def train(args): t5xxl.eval() # cache text encoder outputs + sample_prompts_te_outputs = None if args.cache_text_encoder_outputs: # Text Encodes are eval and no grad here clip_l.to(accelerator.device) @@ -321,6 +322,22 @@ def train(args): with accelerator.autocast(): train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator.is_main_process) + + # cache sample prompt's embeddings to free text encoder's memory + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + prompts = sd3_train_utils.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_list = sd3_tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + sd3_tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_list + ) + accelerator.wait_for_everyone() # load MMDIT @@ -635,10 +652,8 @@ def optimizer_hook(parameter: torch.Tensor): init_kwargs=init_kwargs, ) - # # For --sample_at_first - # sd3_train_utils.sample_images( - # accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [clip_l, clip_g], mmdit - # ) + # For --sample_at_first + sd3_train_utils.sample_images(accelerator, args, 0, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs) # following function will be moved to sd3_train_utils @@ -831,17 +846,9 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): progress_bar.update(1) global_step += 1 - # sdxl_train_util.sample_images( - # accelerator, - # args, - # None, - # global_step, - # accelerator.device, - # vae, - # [tokenizer1, tokenizer2], - # [clip_l, clip_g], - # mmdit, - # ) + sd3_train_utils.sample_images( + accelerator, args, None, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs + ) # 指定ステップごとにモデルを保存 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: @@ -900,17 +907,9 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): vae, ) - # sdxl_train_util.sample_images( - # accelerator, - # args, - # epoch + 1, - # global_step, - # accelerator.device, - # vae, - # [tokenizer1, tokenizer2], - # [clip_l, clip_g], - # mmdit, - # ) + sd3_train_utils.sample_images( + accelerator, args, epoch + 1, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs + ) is_main_process = accelerator.is_main_process # if is_main_process: From 31507b9901d1d9ab65ba79ebd747b7f35c7e0fc1 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Fri, 2 Aug 2024 13:15:21 +0800 Subject: [PATCH 048/348] Remove unnecessary is_train changes and use apply_debiased_estimation to calculate validation loss. Balances the influence of different time steps on training performance (without affecting actual training results) --- train_network.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/train_network.py b/train_network.py index 2a3a44824..4a5940cd5 100644 --- a/train_network.py +++ b/train_network.py @@ -135,7 +135,7 @@ def all_reduce_network(self, accelerator, network): def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) - def process_val_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True): + def process_val_batch(self, batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True): total_loss = 0.0 timesteps_list = [10, 350, 500, 650, 990] @@ -153,7 +153,7 @@ def process_val_batch(self, batch, is_train, tokenizers, text_encoders, unet, va latents = latents * self.vae_scale_factor b_size = latents.shape[0] - with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): + with torch.set_grad_enabled(False), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: text_encoder_conds = get_weighted_text_embeddings( @@ -173,7 +173,7 @@ def process_val_batch(self, batch, is_train, tokenizers, text_encoders, unet, va # with noise offset and/or multires noise if specified for fixed_timesteps in timesteps_list: - with torch.set_grad_enabled(is_train), accelerator.autocast(): + with torch.set_grad_enabled(False), accelerator.autocast(): noise = torch.randn_like(latents, device=latents.device) b_size = latents.shape[0] timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device) @@ -189,6 +189,7 @@ def process_val_batch(self, batch, is_train, tokenizers, text_encoders, unet, va loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = loss.mean([1, 2, 3]) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし total_loss += loss @@ -885,8 +886,7 @@ def remove_model(old_ckpt_name): for step, batch in enumerate(train_dataloader): current_step.value = global_step with accelerator.accumulate(training_model): - on_step_start(text_encoder, unet) - is_train = True + on_step_start(text_encoder, unet) if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) else: @@ -911,7 +911,7 @@ def remove_model(old_ckpt_name): # print(f"set multiplier: {multipliers}") accelerator.unwrap_model(network).set_multiplier(multipliers) - with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): + with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: text_encoder_conds = get_weighted_text_embeddings( @@ -941,7 +941,7 @@ def remove_model(old_ckpt_name): t.requires_grad_(True) # Predict the noise residual - with torch.set_grad_enabled(is_train), accelerator.autocast(): + with accelerator.autocast(): noise_pred = self.call_unet( args, accelerator, @@ -1040,10 +1040,9 @@ def remove_model(old_ckpt_name): total_loss = 0.0 with torch.no_grad(): validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - is_train = False + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) total_loss += loss.detach().item() current_loss = total_loss / validation_steps val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) From 1db495127f25c1b17694780f635a4760b4e345d0 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 4 Aug 2024 14:53:46 +0800 Subject: [PATCH 049/348] Update train_db.py --- train_db.py | 132 +++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 126 insertions(+), 6 deletions(-) diff --git a/train_db.py b/train_db.py index 1de504ed8..9f8ec777c 100644 --- a/train_db.py +++ b/train_db.py @@ -2,7 +2,6 @@ # XXX dropped option: fine_tune import argparse -import itertools import math import os from multiprocessing import Value @@ -41,11 +40,73 @@ setup_logging() import logging +import itertools logger = logging.getLogger(__name__) # perlin_noise, - +def process_val_batch(*training_models, batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args): + total_loss = 0.0 + timesteps_list = [10, 350, 500, 650, 990] + + with accelerator.accumulate(*training_models): + with torch.no_grad(): + # latentに変換 + if cache_latents: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] + + with torch.set_grad_enabled(False), accelerator.autocast(): + if args.weighted_captions: + encoder_hidden_states = get_weighted_text_embeddings( + tokenizer, + text_encoder, + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, + ) + else: + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states( + args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype + ) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + + for fixed_timesteps in timesteps_list: + with torch.set_grad_enabled(False), accelerator.autocast(): + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise = torch.randn_like(latents, device=latents.device) + b_size = latents.shape[0] + timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Predict the noise residual + with accelerator.autocast(): + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + if args.masked_loss: + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + total_loss += loss + + average_loss = total_loss / len(timesteps_list) + return average_loss def train(args): train_util.verify_training_args(args) @@ -81,9 +142,10 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -148,6 +210,9 @@ def train(args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") + if val_dataset_group is not None: + print("Cache validation latents...") + val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() @@ -195,6 +260,15 @@ def train(args): num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) + val_dataloader = torch.utils.data.DataLoader( + val_dataset_group if val_dataset_group is not None else [], + shuffle=False, + batch_size=1, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + cyclic_val_dataloader = itertools.cycle(val_dataloader) # 学習ステップ数を計算する if args.max_train_epochs is not None: @@ -296,6 +370,8 @@ def train(args): train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) loss_recorder = train_util.LossRecorder() + val_loss_recorder = train_util.LossRecorder() + for epoch in range(num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 @@ -427,12 +503,33 @@ def train(args): avr_loss: float = loss_recorder.moving_average logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - + + if len(val_dataloader) > 0: + if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: + accelerator.print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + total_loss += loss.detach().item() + current_loss = total_loss / validation_steps + val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) + + if args.logging_dir is not None: + logs = {"loss/current_val_loss": current_loss} + accelerator.log(logs, step=global_step) + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/average_val_loss": avr_loss} + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: break if args.logging_dir is not None: - logs = {"loss/epoch": loss_recorder.moving_average} + logs = {"loss/epoch_average": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) accelerator.wait_for_everyone() @@ -515,7 +612,30 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) - + parser.add_argument( + "--validation_seed", + type=int, + default=None, + help="Validation seed" + ) + parser.add_argument( + "--validation_split", + type=float, + default=0.0, + help="Split for validation images out of the training dataset" + ) + parser.add_argument( + "--validation_every_n_step", + type=int, + default=None, + help="Number of train steps for counting validation loss. By default, validation per train epoch is performed" + ) + parser.add_argument( + "--max_validation_steps", + type=int, + default=None, + help="Number of max validation steps for counting validation loss. By default, validation will run entire validation dataset" + ) return parser From 68162172ebf9afa21ad526fc833fcc04f74aeb5f Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 4 Aug 2024 15:03:56 +0800 Subject: [PATCH 050/348] Update train_db.py --- train_db.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_db.py b/train_db.py index 9f8ec777c..e98434dba 100644 --- a/train_db.py +++ b/train_db.py @@ -209,10 +209,10 @@ def train(args): vae.eval() with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) - vae.to("cpu") if val_dataset_group is not None: print("Cache validation latents...") - val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + vae.to("cpu") clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() From 96eb74f0cba3253ba29c8e87d7479c355916cca5 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 4 Aug 2024 15:06:05 +0800 Subject: [PATCH 051/348] Update train_db.py --- train_db.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_db.py b/train_db.py index e98434dba..80fdff3e7 100644 --- a/train_db.py +++ b/train_db.py @@ -210,8 +210,8 @@ def train(args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) if val_dataset_group is not None: - print("Cache validation latents...") - val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + print("Cache validation latents...") + val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") clean_memory_on_device(accelerator.device) From b9bdd101296b8dc3c60b25e31d04d39b57eaee71 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 4 Aug 2024 15:11:26 +0800 Subject: [PATCH 052/348] Update train_network.py --- train_network.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/train_network.py b/train_network.py index 4a5940cd5..d7b24dae9 100644 --- a/train_network.py +++ b/train_network.py @@ -1034,25 +1034,25 @@ def remove_model(old_ckpt_name): logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) accelerator.log(logs, step=global_step) - if len(val_dataloader) > 0: - if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: - accelerator.print("Validating バリデーション処理...") - total_loss = 0.0 - with torch.no_grad(): - validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - total_loss += loss.detach().item() - current_loss = total_loss / validation_steps - val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) - - if args.logging_dir is not None: - logs = {"loss/current_val_loss": current_loss} - accelerator.log(logs, step=global_step) - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/average_val_loss": avr_loss} - accelerator.log(logs, step=global_step) + if len(val_dataloader) > 0: + if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: + accelerator.print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + total_loss += loss.detach().item() + current_loss = total_loss / validation_steps + val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) + + if args.logging_dir is not None: + logs = {"loss/current_val_loss": current_loss} + accelerator.log(logs, step=global_step) + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/average_val_loss": avr_loss} + accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break From 3d68754defde57b10f96d9c934dd78bf25c39235 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 4 Aug 2024 15:15:42 +0800 Subject: [PATCH 053/348] Update train_db.py --- train_db.py | 38 ++++++++++++++++++-------------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/train_db.py b/train_db.py index 80fdff3e7..800a157bf 100644 --- a/train_db.py +++ b/train_db.py @@ -503,28 +503,26 @@ def train(args): avr_loss: float = loss_recorder.moving_average logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - if len(val_dataloader) > 0: if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: - accelerator.print("Validating バリデーション処理...") - total_loss = 0.0 - with torch.no_grad(): - validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - total_loss += loss.detach().item() - current_loss = total_loss / validation_steps - val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) - - if args.logging_dir is not None: - logs = {"loss/current_val_loss": current_loss} - accelerator.log(logs, step=global_step) - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/average_val_loss": avr_loss} - accelerator.log(logs, step=global_step) - - + accelerator.print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + total_loss += loss.detach().item() + current_loss = total_loss / validation_steps + val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) + + if args.logging_dir is not None: + logs = {"loss/current_val_loss": current_loss} + accelerator.log(logs, step=global_step) + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/average_val_loss": avr_loss} + accelerator.log(logs, step=global_step) + if global_step >= args.max_train_steps: break From a593e837f36b6299101dc85a367c0986501ecc0a Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 4 Aug 2024 15:17:30 +0800 Subject: [PATCH 054/348] Update train_network.py --- train_network.py | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/train_network.py b/train_network.py index d7b24dae9..7d9134638 100644 --- a/train_network.py +++ b/train_network.py @@ -1034,26 +1034,26 @@ def remove_model(old_ckpt_name): logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) accelerator.log(logs, step=global_step) - if len(val_dataloader) > 0: - if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: - accelerator.print("Validating バリデーション処理...") - total_loss = 0.0 - with torch.no_grad(): - validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - total_loss += loss.detach().item() - current_loss = total_loss / validation_steps - val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) - - if args.logging_dir is not None: - logs = {"loss/current_val_loss": current_loss} - accelerator.log(logs, step=global_step) - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/average_val_loss": avr_loss} - accelerator.log(logs, step=global_step) - + if len(val_dataloader) > 0: + if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: + accelerator.print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + total_loss += loss.detach().item() + current_loss = total_loss / validation_steps + val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) + + if args.logging_dir is not None: + logs = {"loss/current_val_loss": current_loss} + accelerator.log(logs, step=global_step) + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/average_val_loss": avr_loss} + accelerator.log(logs, step=global_step) + if global_step >= args.max_train_steps: break From f6dbf7c419bbcf2e51c82a6bffa8d30cad2e3512 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 4 Aug 2024 15:18:53 +0800 Subject: [PATCH 055/348] Update train_network.py --- train_network.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/train_network.py b/train_network.py index 7d9134638..fa6407eef 100644 --- a/train_network.py +++ b/train_network.py @@ -1034,26 +1034,26 @@ def remove_model(old_ckpt_name): logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) accelerator.log(logs, step=global_step) - if len(val_dataloader) > 0: - if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: - accelerator.print("Validating バリデーション処理...") - total_loss = 0.0 - with torch.no_grad(): - validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - total_loss += loss.detach().item() - current_loss = total_loss / validation_steps - val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) + if len(val_dataloader) > 0: + if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: + accelerator.print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + total_loss += loss.detach().item() + current_loss = total_loss / validation_steps + val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) - if args.logging_dir is not None: - logs = {"loss/current_val_loss": current_loss} - accelerator.log(logs, step=global_step) - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/average_val_loss": avr_loss} - accelerator.log(logs, step=global_step) - + if args.logging_dir is not None: + logs = {"loss/current_val_loss": current_loss} + accelerator.log(logs, step=global_step) + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/average_val_loss": avr_loss} + accelerator.log(logs, step=global_step) + if global_step >= args.max_train_steps: break From aa850aa531b0e396b6f2fbd68cd1e6f1319d1d0b Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 4 Aug 2024 17:34:20 +0800 Subject: [PATCH 056/348] Update train_network.py --- train_network.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/train_network.py b/train_network.py index fa6407eef..938e41938 100644 --- a/train_network.py +++ b/train_network.py @@ -1034,25 +1034,25 @@ def remove_model(old_ckpt_name): logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) accelerator.log(logs, step=global_step) - if len(val_dataloader) > 0: - if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: - accelerator.print("Validating バリデーション処理...") - total_loss = 0.0 - with torch.no_grad(): - validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - total_loss += loss.detach().item() - current_loss = total_loss / validation_steps - val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) - - if args.logging_dir is not None: - logs = {"loss/current_val_loss": current_loss} - accelerator.log(logs, step=global_step) - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/average_val_loss": avr_loss} - accelerator.log(logs, step=global_step) + if len(val_dataloader) > 0: + if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: + accelerator.print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + total_loss += loss.detach().item() + current_loss = total_loss / validation_steps + val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) + + if args.logging_dir is not None: + logs = {"loss/current_val_loss": current_loss} + accelerator.log(logs, step=global_step) + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/average_val_loss": avr_loss} + accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break From cdb2d9c516fbffe0faa9788b8174e5d418fb766b Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 4 Aug 2024 17:36:34 +0800 Subject: [PATCH 057/348] Update train_network.py --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 938e41938..e10c17c0c 100644 --- a/train_network.py +++ b/train_network.py @@ -192,7 +192,7 @@ def process_val_batch(self, batch, tokenizers, text_encoders, unet, vae, noise_s loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし total_loss += loss - + average_loss = total_loss / len(timesteps_list) return average_loss From 231df197ddf4372b3d90751146927f33e1965d1a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 5 Aug 2024 20:26:30 +0900 Subject: [PATCH 058/348] Fix npz path for verification --- library/strategy_sdxl.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/library/strategy_sdxl.py b/library/strategy_sdxl.py index a4513336d..3eb0ab6f6 100644 --- a/library/strategy_sdxl.py +++ b/library/strategy_sdxl.py @@ -184,20 +184,20 @@ def __init__( def get_outputs_npz_path(self, image_abs_path: str) -> str: return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX - def is_disk_cached_outputs_expected(self, abs_path: str): + def is_disk_cached_outputs_expected(self, npz_path: str): if not self.cache_to_disk: return False - if not os.path.exists(self.get_outputs_npz_path(abs_path)): + if not os.path.exists(npz_path): return False if self.skip_disk_cache_validity_check: return True try: - npz = np.load(self.get_outputs_npz_path(abs_path)) + npz = np.load(npz_path) if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz: return False except Exception as e: - logger.error(f"Error loading file: {self.get_outputs_npz_path(abs_path)}") + logger.error(f"Error loading file: {npz_path}") raise e return True From da4d0fe0165b3e0143c237de8cf307d53a9de45a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 5 Aug 2024 20:51:34 +0900 Subject: [PATCH 059/348] support attn mask for l+g/t5 --- library/strategy_sd3.py | 88 +++++++++++++++++++++++++++++++++------- library/train_util.py | 3 +- sd3_minimal_inference.py | 10 +++-- sd3_train.py | 30 +++++++++++--- 4 files changed, 107 insertions(+), 24 deletions(-) diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py index 7491e814f..a22818903 100644 --- a/library/strategy_sd3.py +++ b/library/strategy_sd3.py @@ -37,11 +37,14 @@ def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: g_tokens = self.clip_g(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt") t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt") + l_attn_mask = l_tokens["attention_mask"] + g_attn_mask = g_tokens["attention_mask"] + t5_attn_mask = t5_tokens["attention_mask"] l_tokens = l_tokens["input_ids"] g_tokens = g_tokens["input_ids"] t5_tokens = t5_tokens["input_ids"] - return [l_tokens, g_tokens, t5_tokens] + return [l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask] class Sd3TextEncodingStrategy(TextEncodingStrategy): @@ -49,11 +52,20 @@ def __init__(self) -> None: pass def encode_tokens( - self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens: List[torch.Tensor], + apply_lg_attn_mask: bool = False, + apply_t5_attn_mask: bool = False, ) -> List[torch.Tensor]: + """ + returned embeddings are not masked + """ clip_l, clip_g, t5xxl = models - l_tokens, g_tokens, t5_tokens = tokens + l_tokens, g_tokens, t5_tokens = tokens[:3] + l_attn_mask, g_attn_mask, t5_attn_mask = tokens[3:] if len(tokens) > 3 else [None, None, None] if l_tokens is None: assert g_tokens is None, "g_tokens must be None if l_tokens is None" lg_out = None @@ -61,10 +73,15 @@ def encode_tokens( assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" l_out, l_pooled = clip_l(l_tokens) g_out, g_pooled = clip_g(g_tokens) + if apply_lg_attn_mask: + l_out = l_out * l_attn_mask.to(l_out.device).unsqueeze(-1) + g_out = g_out * g_attn_mask.to(g_out.device).unsqueeze(-1) lg_out = torch.cat([l_out, g_out], dim=-1) if t5xxl is not None and t5_tokens is not None: t5_out, _ = t5xxl(t5_tokens) # t5_out is [1, max length, 4096] + if apply_t5_attn_mask: + t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1) else: t5_out = None @@ -84,50 +101,81 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz" def __init__( - self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False + self, + cache_to_disk: bool, + batch_size: int, + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + apply_lg_attn_mask: bool = False, + apply_t5_attn_mask: bool = False, ) -> None: super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) + self.apply_lg_attn_mask = apply_lg_attn_mask + self.apply_t5_attn_mask = apply_t5_attn_mask def get_outputs_npz_path(self, image_abs_path: str) -> str: return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX - def is_disk_cached_outputs_expected(self, abs_path: str): + def is_disk_cached_outputs_expected(self, npz_path: str): if not self.cache_to_disk: return False - if not os.path.exists(self.get_outputs_npz_path(abs_path)): + if not os.path.exists(npz_path): return False if self.skip_disk_cache_validity_check: return True try: - npz = np.load(self.get_outputs_npz_path(abs_path)) - if "clip_l" not in npz or "clip_g" not in npz: + npz = np.load(npz_path) + if "lg_out" not in npz: return False - if "clip_l_pool" not in npz or "clip_g_pool" not in npz: + if "lg_pooled" not in npz: + return False + if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used return False # t5xxl is optional except Exception as e: - logger.error(f"Error loading file: {self.get_outputs_npz_path(abs_path)}") + logger.error(f"Error loading file: {npz_path}") raise e return True + def mask_lg_attn(self, lg_out: np.ndarray, l_attn_mask: np.ndarray, g_attn_mask: np.ndarray) -> np.ndarray: + l_out = lg_out[..., :768] + g_out = lg_out[..., 768:] # 1280 + l_out = l_out * np.expand_dims(l_attn_mask, -1) # l_out = l_out * l_attn_mask. + g_out = g_out * np.expand_dims(g_attn_mask, -1) # g_out = g_out * g_attn_mask. + return np.concatenate([l_out, g_out], axis=-1) + + def mask_t5_attn(self, t5_out: np.ndarray, t5_attn_mask: np.ndarray) -> np.ndarray: + return t5_out * np.expand_dims(t5_attn_mask, -1) + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: data = np.load(npz_path) lg_out = data["lg_out"] lg_pooled = data["lg_pooled"] t5_out = data["t5_out"] if "t5_out" in data else None + + if self.apply_lg_attn_mask: + l_attn_mask = data["clip_l_attn_mask"] + g_attn_mask = data["clip_g_attn_mask"] + lg_out = self.mask_lg_attn(lg_out, l_attn_mask, g_attn_mask) + + if self.apply_t5_attn_mask and t5_out is not None: + t5_attn_mask = data["t5_attn_mask"] + t5_out = self.mask_t5_attn(t5_out, t5_attn_mask) + return [lg_out, t5_out, lg_pooled] def cache_batch_outputs( self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List ): + sd3_text_encoding_strategy: Sd3TextEncodingStrategy = text_encoding_strategy captions = [info.caption for info in infos] - clip_l_tokens, clip_g_tokens, t5xxl_tokens = tokenize_strategy.tokenize(captions) + tokens_and_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): - lg_out, t5_out, lg_pooled = text_encoding_strategy.encode_tokens( - tokenize_strategy, models, [clip_l_tokens, clip_g_tokens, t5xxl_tokens] + lg_out, t5_out, lg_pooled = sd3_text_encoding_strategy.encode_tokens( + tokenize_strategy, models, tokens_and_masks, self.apply_lg_attn_mask, self.apply_t5_attn_mask ) if lg_out.dtype == torch.bfloat16: @@ -148,10 +196,22 @@ def cache_batch_outputs( lg_pooled_i = lg_pooled[i] if self.cache_to_disk: + clip_l_attn_mask, clip_g_attn_mask, t5_attn_mask = tokens_and_masks[3:6] + clip_l_attn_mask_i = clip_l_attn_mask[i].cpu().numpy() + clip_g_attn_mask_i = clip_g_attn_mask[i].cpu().numpy() + t5_attn_mask_i = t5_attn_mask[i].cpu().numpy() if t5_attn_mask is not None else None # shouldn't be None kwargs = {} if t5_out is not None: kwargs["t5_out"] = t5_out_i - np.savez(info.text_encoder_outputs_npz, lg_out=lg_out_i, lg_pooled=lg_pooled_i, **kwargs) + np.savez( + info.text_encoder_outputs_npz, + lg_out=lg_out_i, + lg_pooled=lg_pooled_i, + clip_l_attn_mask=clip_l_attn_mask_i, + clip_g_attn_mask=clip_g_attn_mask_i, + t5_attn_mask=t5_attn_mask_i, + **kwargs, + ) else: info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i) diff --git a/library/train_util.py b/library/train_util.py index a747e0478..fc458a884 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -646,7 +646,7 @@ def __init__( # caching self.caching_mode = None # None, 'latents', 'text' - + self.tokenize_strategy = None self.text_encoder_output_caching_strategy = None self.latents_caching_strategy = None @@ -1486,6 +1486,7 @@ def __getitem__(self, index): text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz( image_info.text_encoder_outputs_npz ) + text_encoder_outputs = [torch.FloatTensor(x) for x in text_encoder_outputs] else: tokenization_required = True text_encoder_outputs_list.append(text_encoder_outputs) diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py index e9e61af1b..630da7e08 100644 --- a/sd3_minimal_inference.py +++ b/sd3_minimal_inference.py @@ -146,6 +146,8 @@ def do_sample( parser.add_argument("--clip_l", type=str, required=False) parser.add_argument("--t5xxl", type=str, required=False) parser.add_argument("--t5xxl_token_length", type=int, default=77, help="t5xxl token length, default: 77") + parser.add_argument("--apply_lg_attn_mask", action="store_true") + parser.add_argument("--apply_t5_attn_mask", action="store_true") parser.add_argument("--prompt", type=str, default="A photo of a cat") # parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders parser.add_argument("--negative_prompt", type=str, default="") @@ -323,15 +325,15 @@ def do_sample( logger.info("Encoding prompts...") encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy() - l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.prompt) + tokens_and_masks = tokenize_strategy.tokenize(args.prompt) lg_out, t5_out, pooled = encoding_strategy.encode_tokens( - tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens] + tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask ) cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) - l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.negative_prompt) + tokens_and_masks = tokenize_strategy.tokenize(args.negative_prompt) lg_out, t5_out, pooled = encoding_strategy.encode_tokens( - tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens] + tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask ) neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) diff --git a/sd3_train.py b/sd3_train.py index 2f4ea8cb2..9c37cbce6 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -172,6 +172,8 @@ def train(args): args.text_encoder_batch_size, False, False, + False, + False, ) ) train_dataset_group.set_current_strategies() @@ -312,6 +314,8 @@ def train(args): args.text_encoder_batch_size, False, train_clip_g or train_clip_l or args.use_t5xxl_cache_only, + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, ) strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) @@ -335,7 +339,11 @@ def train(args): logger.info(f"cache Text Encoder outputs for prompt: {p}") tokens_list = sd3_tokenize_strategy.tokenize(p) sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( - sd3_tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_list + sd3_tokenize_strategy, + [clip_l, clip_g, t5xxl], + tokens_list, + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, ) accelerator.wait_for_everyone() @@ -748,21 +756,23 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): if lg_out is None or (train_clip_l or train_clip_g): # not cached or training, so get from text encoders - input_ids_clip_l, input_ids_clip_g, _ = batch["input_ids_list"] + input_ids_clip_l, input_ids_clip_g, _, l_attn_mask, g_attn_mask, _ = batch["input_ids_list"] with torch.set_grad_enabled(args.train_text_encoder): # TODO support weighted captions input_ids_clip_l = input_ids_clip_l.to(accelerator.device) input_ids_clip_g = input_ids_clip_g.to(accelerator.device) lg_out, _, lg_pooled = text_encoding_strategy.encode_tokens( - sd3_tokenize_strategy, [clip_l, clip_g, None], [input_ids_clip_l, input_ids_clip_g, None] + sd3_tokenize_strategy, + [clip_l, clip_g, None], + [input_ids_clip_l, input_ids_clip_g, None, l_attn_mask, g_attn_mask, None], ) if t5_out is None: - _, _, input_ids_t5xxl = batch["input_ids_list"] + _, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"] with torch.no_grad(): input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device) if t5_out is None else None _, t5_out, _ = text_encoding_strategy.encode_tokens( - sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl] + sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask] ) context, lg_pooled = text_encoding_strategy.concat_encodings(lg_out, t5_out, lg_pooled) @@ -969,6 +979,16 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="maximum token length for T5-XXL. 256 if omitted / T5-XXLの最大トークン数。省略時は256", ) + parser.add_argument( + "--apply_lg_attn_mask", + action="store_true", + help="apply attention mask (zero embs) to CLIP-L and G / CLIP-LとGにアテンションマスク(ゼロ埋め)を適用する", + ) + parser.add_argument( + "--apply_t5_attn_mask", + action="store_true", + help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する", + ) # TE training is disabled temporarily # parser.add_argument( From 36b2e6fc288c57f496a061e4d638f5641c32c9ea Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 9 Aug 2024 22:56:48 +0900 Subject: [PATCH 060/348] add FLUX.1 LoRA training --- README.md | 20 + flux_minimal_inference.py | 390 ++++++++++++++++ flux_train_network.py | 332 ++++++++++++++ library/flux_models.py | 920 ++++++++++++++++++++++++++++++++++++++ library/flux_utils.py | 215 +++++++++ library/sd3_models.py | 22 +- library/strategy_flux.py | 244 ++++++++++ networks/lora_flux.py | 730 ++++++++++++++++++++++++++++++ sdxl_train_network.py | 5 + train_network.py | 169 ++++--- 10 files changed, 2992 insertions(+), 55 deletions(-) create mode 100644 flux_minimal_inference.py create mode 100644 flux_train_network.py create mode 100644 library/flux_models.py create mode 100644 library/flux_utils.py create mode 100644 library/strategy_flux.py create mode 100644 networks/lora_flux.py diff --git a/README.md b/README.md index d406fecde..a0b02f108 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,25 @@ This repository contains training, generation and utility scripts for Stable Diffusion. +## FLUX.1 LoRA training (WIP) + +__Aug 9, 2024__: + +Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. + +We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. + +``` +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name +``` + +The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options. + +``` +python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors +``` + +Unfortnately the training result is not good. Please let us know if you have any idea to improve the training. + ## SD3 training SD3 training is done with `sd3_train.py`. diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py new file mode 100644 index 000000000..f3affca80 --- /dev/null +++ b/flux_minimal_inference.py @@ -0,0 +1,390 @@ +# Minimum Inference Code for FLUX + +import argparse +import datetime +import math +import os +import random +from typing import Callable, Optional, Tuple +import einops +import numpy as np + +import torch +from safetensors.torch import safe_open, load_file +from tqdm import tqdm +from PIL import Image +import accelerate + +from library import device_utils +from library.device_utils import init_ipex, get_preferred_device + +init_ipex() + + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import networks.lora_flux as lora_flux +from library import flux_models, flux_utils, sd3_utils, strategy_flux + + +def time_shift(mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # eastimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +def denoise( + model: flux_models.Flux, + img: torch.Tensor, + img_ids: torch.Tensor, + txt: torch.Tensor, + txt_ids: torch.Tensor, + vec: torch.Tensor, + timesteps: list[float], + guidance: float = 4.0, +): + # this is ignored for schnell + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + pred = model(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec, guidance=guidance_vec) + + img = img + (t_prev - t_curr) * pred + + return img + + +def do_sample( + accelerator: Optional[accelerate.Accelerator], + model: flux_models.Flux, + img: torch.Tensor, + img_ids: torch.Tensor, + l_pooled: torch.Tensor, + t5_out: torch.Tensor, + txt_ids: torch.Tensor, + num_steps: int, + guidance: float, + is_schnell: bool, + device: torch.device, + flux_dtype: torch.dtype, +): + timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell) + + # denoise initial noise + if accelerator: + with accelerator.autocast(), torch.no_grad(): + x = denoise(model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance) + else: + with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad(): + x = denoise(model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance) + + return x + + +def generate_image( + model, + clip_l, + t5xxl, + ae, + prompt: str, + seed: Optional[int], + image_width: int, + image_height: int, + steps: Optional[int], + guidance: float, +): + # make first noise with packed shape + # original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2 + packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16) + noise = torch.randn( + 1, + packed_latent_height * packed_latent_width, + 16 * 2 * 2, + device=device, + dtype=dtype, + generator=torch.Generator(device=device).manual_seed(seed), + ) + + # prepare img and img ids + + # this is needed only for img2img + # img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + # if img.shape[0] == 1 and bs > 1: + # img = repeat(img, "1 ... -> bs ...", bs=bs) + + # txt2img only needs img_ids + img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width) + + # prepare embeddings + logger.info("Encoding prompts...") + tokens_and_masks = tokenize_strategy.tokenize(prompt) + clip_l = clip_l.to(device) + t5xxl = t5xxl.to(device) + with torch.no_grad(): + if is_fp8(clip_l_dtype) or is_fp8(t5xxl_dtype): + clip_l.to(clip_l_dtype) + t5xxl.to(t5xxl_dtype) + with accelerator.autocast(): + _, t5_out, txt_ids = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + else: + with torch.autocast(device_type=device.type, dtype=clip_l_dtype): + l_pooled, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + with torch.autocast(device_type=device.type, dtype=t5xxl_dtype): + _, t5_out, txt_ids = encoding_strategy.encode_tokens( + tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + + # NaN check + if torch.isnan(l_pooled).any(): + raise ValueError("NaN in l_pooled") + if torch.isnan(t5_out).any(): + raise ValueError("NaN in t5_out") + + if args.offload: + clip_l = clip_l.cpu() + t5xxl = t5xxl.cpu() + # del clip_l, t5xxl + device_utils.clean_memory() + + # generate image + logger.info("Generating image...") + model = model.to(device) + if steps is None: + steps = 4 if is_schnell else 50 + + img_ids = img_ids.to(device) + x = do_sample( + accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance_scale, is_schnell, device, flux_dtype + ) + if args.offload: + model = model.cpu() + # del model + device_utils.clean_memory() + + # unpack + x = x.float() + x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2) + + # decode + logger.info("Decoding image...") + ae = ae.to(device) + with torch.no_grad(): + if is_fp8(ae_dtype): + with accelerator.autocast(): + x = ae.decode(x) + else: + with torch.autocast(device_type=device.type, dtype=ae_dtype): + x = ae.decode(x) + if args.offload: + ae = ae.cpu() + + x = x.clamp(-1, 1) + x = x.permute(0, 2, 3, 1) + img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0]) + + # save image + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png") + img.save(output_path) + + logger.info(f"Saved image to {output_path}") + + +if __name__ == "__main__": + target_height = 768 # 1024 + target_width = 1360 # 1024 + + # steps = 50 # 28 # 50 + # guidance_scale = 5 + # seed = 1 # None # 1 + + device = get_preferred_device() + + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt_path", type=str, required=True) + parser.add_argument("--clip_l", type=str, required=False) + parser.add_argument("--t5xxl", type=str, required=False) + parser.add_argument("--ae", type=str, required=False) + parser.add_argument("--apply_t5_attn_mask", action="store_true") + parser.add_argument("--prompt", type=str, default="A photo of a cat") + parser.add_argument("--output_dir", type=str, default=".") + parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype") + parser.add_argument("--clip_l_dtype", type=str, default=None, help="dtype for clip_l") + parser.add_argument("--ae_dtype", type=str, default=None, help="dtype for ae") + parser.add_argument("--t5xxl_dtype", type=str, default=None, help="dtype for t5xxl") + parser.add_argument("--flux_dtype", type=str, default=None, help="dtype for flux") + parser.add_argument("--seed", type=int, default=None) + parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev") + parser.add_argument("--guidance", type=float, default=3.5) + parser.add_argument("--offload", action="store_true", help="Offload to CPU") + parser.add_argument( + "--lora_weights", + type=str, + nargs="*", + default=[], + help="LoRA weights, only supports networks.lora_flux, each argument is a `path;multiplier` (semi-colon separated)", + ) + parser.add_argument("--width", type=int, default=target_width) + parser.add_argument("--height", type=int, default=target_height) + parser.add_argument("--interactive", action="store_true") + args = parser.parse_args() + + seed = args.seed + steps = args.steps + guidance_scale = args.guidance + + name = "schnell" if "schnell" in args.ckpt_path else "dev" # TODO change this to a more robust way + is_schnell = name == "schnell" + + def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype: + if s is None: + return default_dtype + if s in ["bf16", "bfloat16"]: + return torch.bfloat16 + elif s in ["fp16", "float16"]: + return torch.float16 + elif s in ["fp32", "float32"]: + return torch.float32 + elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]: + return torch.float8_e4m3fn + elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]: + return torch.float8_e4m3fnuz + elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]: + return torch.float8_e5m2 + elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]: + return torch.float8_e5m2fnuz + elif s in ["fp8", "float8"]: + return torch.float8_e4m3fn # default fp8 + else: + raise ValueError(f"Unsupported dtype: {s}") + + def is_fp8(dt): + return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz] + + dtype = str_to_dtype(args.dtype) + clip_l_dtype = str_to_dtype(args.clip_l_dtype, dtype) + t5xxl_dtype = str_to_dtype(args.t5xxl_dtype, dtype) + ae_dtype = str_to_dtype(args.ae_dtype, dtype) + flux_dtype = str_to_dtype(args.flux_dtype, dtype) + + logger.info(f"Dtypes for clip_l, t5xxl, ae, flux: {clip_l_dtype}, {t5xxl_dtype}, {ae_dtype}, {flux_dtype}") + + loading_device = "cpu" if args.offload else device + + use_fp8 = [is_fp8(d) for d in [dtype, clip_l_dtype, t5xxl_dtype, ae_dtype, flux_dtype]] + if any(use_fp8): + accelerator = accelerate.Accelerator(mixed_precision="bf16") + else: + accelerator = None + + # load clip_l + logger.info(f"Loading clip_l from {args.clip_l}...") + clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device) + clip_l.eval() + + logger.info(f"Loading t5xxl from {args.t5xxl}...") + t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device) + t5xxl.eval() + + if is_fp8(clip_l_dtype): + clip_l = accelerator.prepare(clip_l) + if is_fp8(t5xxl_dtype): + t5xxl = accelerator.prepare(t5xxl) + + t5xxl_max_length = 256 if is_schnell else 512 + tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length) + encoding_strategy = strategy_flux.FluxTextEncodingStrategy() + + # DiT + model = flux_utils.load_flow_model(name, args.ckpt_path, flux_dtype, loading_device) + model.eval() + logger.info(f"Casting model to {flux_dtype}") + model.to(flux_dtype) # make sure model is dtype + if is_fp8(flux_dtype): + model = accelerator.prepare(model) + + # AE + ae = flux_utils.load_ae(name, args.ae, ae_dtype, loading_device) + ae.eval() + if is_fp8(ae_dtype): + ae = accelerator.prepare(ae) + + # LoRA + for weights_file in args.lora_weights: + if ";" in weights_file: + weights_file, multiplier = weights_file.split(";") + multiplier = float(multiplier) + else: + multiplier = 1.0 + + lora_model, weights_sd = lora_flux.create_network_from_weights( + multiplier, weights_file, ae, [clip_l, t5xxl], model, None, True + ) + lora_model.merge_to([clip_l, t5xxl], model, weights_sd) + + if not args.interactive: + generate_image(model, clip_l, t5xxl, ae, args.prompt, args.seed, args.width, args.height, args.steps, args.guidance) + else: + # loop for interactive + width = target_width + height = target_height + steps = None + guidance = args.guidance + + while True: + print("Enter prompt (empty to exit). Options: --w --h --s --d --g ") + prompt = input() + if prompt == "": + break + + # parse options + options = prompt.split("--") + prompt = options[0].strip() + seed = None + for opt in options[1:]: + opt = opt.strip() + if opt.startswith("w"): + width = int(opt[1:].strip()) + elif opt.startswith("h"): + height = int(opt[1:].strip()) + elif opt.startswith("s"): + steps = int(opt[1:].strip()) + elif opt.startswith("d"): + seed = int(opt[1:].strip()) + elif opt.startswith("g"): + guidance = float(opt[1:].strip()) + + generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance) + + logger.info("Done!") diff --git a/flux_train_network.py b/flux_train_network.py new file mode 100644 index 000000000..7c762c86d --- /dev/null +++ b/flux_train_network.py @@ -0,0 +1,332 @@ +import argparse +import copy +import math +import random +from typing import Any + +import torch +from accelerate import Accelerator +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from library import flux_models, flux_utils, sd3_train_utils, sd3_utils, sdxl_model_util, sdxl_train_util, strategy_flux, train_util +import train_network +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class FluxNetworkTrainer(train_network.NetworkTrainer): + def __init__(self): + super().__init__() + + def assert_extra_args(self, args, train_dataset_group): + super().assert_extra_args(args, train_dataset_group) + # sdxl_train_util.verify_sdxl_training_args(args) + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + assert ( + args.network_train_unet_only or not args.cache_text_encoder_outputs + ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" + + train_dataset_group.verify_bucket_reso_steps(32) + + def load_target_model(self, args, weight_dtype, accelerator): + # currently offload to cpu for some models + + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu") + clip_l.eval() + + # loading t5xxl to cpu takes a long time, so we should load to gpu in future + t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu") + t5xxl.eval() + + name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" # TODO change this to a more robust way + # if we load to cpu, flux.to(fp8) takes a long time + model = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu") + ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu") + + return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model + + def get_tokenize_strategy(self, args): + return strategy_flux.FluxTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy): + return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_flux.FluxTextEncodingStrategy() + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + return text_encoders # + [accelerator.unwrap_model(text_encoders[-1])] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + return strategy_flux.FluxTextEncoderOutputsCachingStrategy(args.cache_text_encoder_outputs_to_disk, None, False) + else: + return None + + def cache_text_encoder_outputs_if_needed( + self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype + ): + if args.cache_text_encoder_outputs: + if not args.lowram: + # メモリ消費を減らす + logger.info("move vae and unet to cpu to save memory") + org_vae_device = vae.device + org_unet_device = unet.device + vae.to("cpu") + unet.to("cpu") + clean_memory_on_device(accelerator.device) + + # When TE is not be trained, it will not be prepared so we need to use explicit autocast + logger.info("move text encoders to gpu") + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device, dtype=weight_dtype) + with accelerator.autocast(): + dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process) + accelerator.wait_for_everyone() + + logger.info("move text encoders back to cpu") + text_encoders[0].to("cpu") # , dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU + text_encoders[1].to("cpu") # , dtype=torch.float32) + clean_memory_on_device(accelerator.device) + + if not args.lowram: + logger.info("move vae and unet back to original device") + vae.to(org_vae_device) + unet.to(org_unet_device) + else: + # Text Encoderから毎回出力を取得するので、GPUに乗せておく + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device, dtype=weight_dtype) + + # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + + # # get size embeddings + # orig_size = batch["original_sizes_hw"] + # crop_size = batch["crop_top_lefts"] + # target_size = batch["target_sizes_hw"] + # embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # # concat embeddings + # encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds + # vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + # text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) + # return noise_pred + + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): + # logger.warning("Sampling images is not supported for Flux model") + pass + + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) + self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) + return noise_scheduler + + def encode_images_to_latents(self, args, accelerator, vae, images): + return vae.encode(images).latent_dist.sample() + + def shift_scale_latents(self, args, latents): + return latents + + def get_noise_pred_and_target( + self, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet: flux_models.Flux, + network, + weight_dtype, + train_unet, + ): + # copy from sd3_train.py and modified + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = self.noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = self.noise_scheduler_copy.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + def compute_density_for_timestep_sampling( + weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None + ): + """Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): + """Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + elif weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) + return weighting + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=accelerator.device) + + # Add noise according to flow matching. + sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + + # pack latents and get img_ids + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + + # get guidance + guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + noisy_model_input.requires_grad_(True) + for t in text_encoder_conds: + t.requires_grad_(True) + img_ids.requires_grad_(True) + guidance_vec.requires_grad_(True) + + # Predict the noise residual + l_pooled, t5_out, txt_ids = text_encoder_conds + # print( + # f"model_input: {noisy_model_input.shape}, img_ids: {img_ids.shape}, t5_out: {t5_out.shape}, txt_ids: {txt_ids.shape}, l_pooled: {l_pooled.shape}, timesteps: {timesteps.shape}, guidance_vec: {guidance_vec.shape}" + # ) + + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = unet( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + ) + + # unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss: this is different from SD3 + target = noise - latents + + return model_pred, target, timesteps, None, weighting + + def post_process_loss(self, loss, args, timesteps, noise_scheduler): + return loss + + +def setup_parser() -> argparse.ArgumentParser: + parser = train_network.setup_parser() + # sdxl_train_util.add_sdxl_training_arguments(parser) + parser.add_argument("--clip_l", type=str, help="path to clip_l") + parser.add_argument("--t5xxl", type=str, help="path to t5xxl") + parser.add_argument("--ae", type=str, help="path to ae") + parser.add_argument("--apply_t5_attn_mask", action="store_true") + parser.add_argument( + "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" + ) + parser.add_argument( + "--cache_text_encoder_outputs_to_disk", + action="store_true", + help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", + ) + + # copy from Diffusers + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.") + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the FLUX.1 dev variant is a guidance distilled model", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + trainer = FluxNetworkTrainer() + trainer.train(args) diff --git a/library/flux_models.py b/library/flux_models.py new file mode 100644 index 000000000..d0955e375 --- /dev/null +++ b/library/flux_models.py @@ -0,0 +1,920 @@ +# copy from FLUX repo: https://github.com/black-forest-labs/flux +# license: Apache-2.0 License + + +from dataclasses import dataclass +import math + +import torch +from einops import rearrange +from torch import Tensor, nn +from torch.utils.checkpoint import checkpoint + +# USE_REENTRANT = True + + +@dataclass +class FluxParams: + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + + +# region autoencoder + + +@dataclass +class AutoEncoderParams: + resolution: int + in_channels: int + ch: int + out_ch: int + ch_mult: list[int] + num_res_blocks: int + z_channels: int + scale_factor: float + shift_factor: float + + +def swish(x: Tensor) -> Tensor: + return x * torch.sigmoid(x) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) + + def attention(self, h_: Tensor) -> Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() + k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() + v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() + h_ = nn.functional.scaled_dot_product_attention(q, k, v) + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h = x + h = self.norm1(h) + h = swish(h) + h = self.conv1(h) + + h = self.norm2(h) + h = swish(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: Tensor): + pad = (0, 1, 0, 1) + x = nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor): + x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + # downsampling + self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + block_in = self.ch + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor) -> Tensor: + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z: Tensor) -> Tensor: + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class DiagonalGaussian(nn.Module): + def __init__(self, sample: bool = True, chunk_dim: int = 1): + super().__init__() + self.sample = sample + self.chunk_dim = chunk_dim + + def forward(self, z: Tensor) -> Tensor: + mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) + if self.sample: + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + else: + return mean + + +class AutoEncoder(nn.Module): + def __init__(self, params: AutoEncoderParams): + super().__init__() + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.reg = DiagonalGaussian() + + self.scale_factor = params.scale_factor + self.shift_factor = params.shift_factor + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + def encode(self, x: Tensor) -> Tensor: + z = self.reg(self.encoder(x)) + z = self.scale_factor * (z - self.shift_factor) + return z + + def decode(self, z: Tensor) -> Tensor: + z = z / self.scale_factor + self.shift_factor + return self.decoder(z) + + def forward(self, x: Tensor) -> Tensor: + return self.decode(self.encode(x)) + + +# endregion +# region config + + +@dataclass +class ModelSpec: + params: FluxParams + ae_params: AutoEncoderParams + ckpt_path: str | None + ae_path: str | None + # repo_id: str | None + # repo_flow: str | None + # repo_ae: str | None + + +configs = { + "dev": ModelSpec( + # repo_id="black-forest-labs/FLUX.1-dev", + # repo_flow="flux1-dev.sft", + # repo_ae="ae.sft", + ckpt_path=None, # os.getenv("FLUX_DEV"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=None, # os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "schnell": ModelSpec( + # repo_id="black-forest-labs/FLUX.1-schnell", + # repo_flow="flux1-schnell.sft", + # repo_ae="ae.sft", + ckpt_path=None, # os.getenv("FLUX_SCHNELL"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=False, + ), + ae_path=None, # os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), +} + + +# endregion + +# region math + + +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: + q, k = apply_rope(q, k, pe) + + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "B H L D -> B L (H D)") + + return x + + +def rope(pos: Tensor, dim: int, theta: int) -> Tensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + out = torch.einsum("...n,d->...nd", pos, omega) + out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + return out.float() + + +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + +# endregion + + +# region layers +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: Tensor) -> Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + + return emb.unsqueeze(1) + + +def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + t = time_factor * t + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device) + + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(t) + return embedding + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + self.gradient_checkpointing = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + def _forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + def forward(self, *args, **kwargs): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + else: + return self._forward(*args, **kwargs) + + # def forward(self, x): + # if self.training and self.gradient_checkpointing: + # def create_custom_forward(func): + # def custom_forward(*inputs): + # return func(*inputs) + # return custom_forward + # return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, use_reentrant=USE_REENTRANT) + # else: + # return self._forward(x) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor): + x_dtype = x.dtype + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + # return (x * rrms).to(dtype=x_dtype) * self.scale + return ((x * rrms) * self.scale.float()).to(dtype=x_dtype) + + +class QKNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.query_norm = RMSNorm(dim) + self.key_norm = RMSNorm(dim) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: + q = self.query_norm(q) + k = self.key_norm(k) + return q.to(v), k.to(v) + + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.norm = QKNorm(head_dim) + self.proj = nn.Linear(dim, dim) + + # self.gradient_checkpointing = False + + # def enable_gradient_checkpointing(self): + # self.gradient_checkpointing = True + + def forward(self, x: Tensor, pe: Tensor) -> Tensor: + qkv = self.qkv(x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + x = attention(q, k, v, pe=pe) + x = self.proj(x) + return x + + # def forward(self, *args, **kwargs): + # if self.training and self.gradient_checkpointing: + # return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + # else: + # return self._forward(*args, **kwargs) + + +@dataclass +class ModulationOut: + shift: Tensor + scale: Tensor + gate: Tensor + + +class Modulation(nn.Module): + def __init__(self, dim: int, double: bool): + super().__init__() + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) + + def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: + out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) + + return ( + ModulationOut(*out[:3]), + ModulationOut(*out[3:]) if self.is_double else None, + ) + + +class DoubleStreamBlock(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): + super().__init__() + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.img_mod = Modulation(hidden_size, double=True) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.txt_mod = Modulation(hidden_size, double=True) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.gradient_checkpointing = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + # self.img_attn.enable_gradient_checkpointing() + # self.txt_attn.enable_gradient_checkpointing() + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + # self.img_attn.disable_gradient_checkpointing() + # self.txt_attn.disable_gradient_checkpointing() + + def _forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: + img_mod1, img_mod2 = self.img_mod(vec) + txt_mod1, txt_mod2 = self.txt_mod(vec) + + # prepare image for attention + img_modulated = self.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = self.img_attn.qkv(img_modulated) + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = self.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = self.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn = attention(q, k, v, pe=pe) + txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] + + # calculate the img bloks + img = img + img_mod1.gate * self.img_attn.proj(img_attn) + img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) + + # calculate the txt bloks + txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) + txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) + return img, txt + + def forward(self, *args, **kwargs): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + else: + return self._forward(*args, **kwargs) + + # def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor): + # if self.training and self.gradient_checkpointing: + # def create_custom_forward(func): + # def custom_forward(*inputs): + # return func(*inputs) + # return custom_forward + # return torch.utils.checkpoint.checkpoint( + # create_custom_forward(self._forward), img, txt, vec, pe, use_reentrant=USE_REENTRANT + # ) + # else: + # return self._forward(img, txt, vec, pe) + + +class SingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers as described in + https://arxiv.org/abs/2302.05442 and adapted modulation interface. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float | None = None, + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + head_dim = hidden_size // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + + self.norm = QKNorm(head_dim) + + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp_act = nn.GELU(approximate="tanh") + self.modulation = Modulation(hidden_size, double=False) + + self.gradient_checkpointing = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + def _forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: + mod, _ = self.modulation(vec) + x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift + qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + + # compute attention + attn = attention(q, k, v, pe=pe) + # compute activation in mlp stream, cat again and run second linear layer + output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + return x + mod.gate * output + + def forward(self, *args, **kwargs): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + else: + return self._forward(*args, **kwargs) + + # def forward(self, x: Tensor, vec: Tensor, pe: Tensor): + # if self.training and self.gradient_checkpointing: + # def create_custom_forward(func): + # def custom_forward(*inputs): + # return func(*inputs) + # return custom_forward + # return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=USE_REENTRANT) + # else: + # return self._forward(x, vec, pe) + + +class LastLayer(nn.Module): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, x: Tensor, vec: Tensor) -> Tensor: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x + + +# endregion + + +class Flux(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, params: FluxParams): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) + for _ in range(params.depth_single_blocks) + ] + ) + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + + self.gradient_checkpointing = False + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + self.time_in.enable_gradient_checkpointing() + self.vector_in.enable_gradient_checkpointing() + self.guidance_in.enable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.enable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing enabled.") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + self.time_in.disable_gradient_checkpointing() + self.vector_in.disable_gradient_checkpointing() + self.guidance_in.disable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.disable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing disabled.") + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor | None = None, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + for block in self.double_blocks: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe) + img = img[:, txt.shape[1] :, ...] + + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + return img diff --git a/library/flux_utils.py b/library/flux_utils.py new file mode 100644 index 000000000..ba828d508 --- /dev/null +++ b/library/flux_utils.py @@ -0,0 +1,215 @@ +import json +from typing import Union +import einops +import torch + +from safetensors.torch import load_file +from accelerate import init_empty_weights +from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config + +from library import flux_models + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +MODEL_VERSION_FLUX_V1 = "flux1" + + +def load_flow_model(name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> flux_models.Flux: + logger.info(f"Bulding Flux model {name}") + with torch.device("meta"): + model = flux_models.Flux(flux_models.configs[name].params).to(dtype) + + # load_sft doesn't support torch.device + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_file(ckpt_path, device=str(device)) + info = model.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded Flux: {info}") + return model + + +def load_ae(name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> flux_models.AutoEncoder: + logger.info("Building AutoEncoder") + with torch.device("meta"): + ae = flux_models.AutoEncoder(flux_models.configs[name].ae_params).to(dtype) + + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_file(ckpt_path, device=str(device)) + info = ae.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded AE: {info}") + return ae + + +def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> CLIPTextModel: + logger.info("Building CLIP") + CLIPL_CONFIG = { + "_name_or_path": "clip-vit-large-patch14/", + "architectures": ["CLIPModel"], + "initializer_factor": 1.0, + "logit_scale_init_value": 2.6592, + "model_type": "clip", + "projection_dim": 768, + # "text_config": { + "_name_or_path": "", + "add_cross_attention": False, + "architectures": None, + "attention_dropout": 0.0, + "bad_words_ids": None, + "bos_token_id": 0, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": None, + "decoder_start_token_id": None, + "diversity_penalty": 0.0, + "do_sample": False, + "dropout": 0.0, + "early_stopping": False, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": 2, + "finetuning_task": None, + "forced_bos_token_id": None, + "forced_eos_token_id": None, + "hidden_act": "quick_gelu", + "hidden_size": 768, + "id2label": {"0": "LABEL_0", "1": "LABEL_1"}, + "initializer_factor": 1.0, + "initializer_range": 0.02, + "intermediate_size": 3072, + "is_decoder": False, + "is_encoder_decoder": False, + "label2id": {"LABEL_0": 0, "LABEL_1": 1}, + "layer_norm_eps": 1e-05, + "length_penalty": 1.0, + "max_length": 20, + "max_position_embeddings": 77, + "min_length": 0, + "model_type": "clip_text_model", + "no_repeat_ngram_size": 0, + "num_attention_heads": 12, + "num_beam_groups": 1, + "num_beams": 1, + "num_hidden_layers": 12, + "num_return_sequences": 1, + "output_attentions": False, + "output_hidden_states": False, + "output_scores": False, + "pad_token_id": 1, + "prefix": None, + "problem_type": None, + "projection_dim": 768, + "pruned_heads": {}, + "remove_invalid_values": False, + "repetition_penalty": 1.0, + "return_dict": True, + "return_dict_in_generate": False, + "sep_token_id": None, + "task_specific_params": None, + "temperature": 1.0, + "tie_encoder_decoder": False, + "tie_word_embeddings": True, + "tokenizer_class": None, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": None, + "torchscript": False, + "transformers_version": "4.16.0.dev0", + "use_bfloat16": False, + "vocab_size": 49408, + "hidden_act": "gelu", + "hidden_size": 1280, + "intermediate_size": 5120, + "num_attention_heads": 20, + "num_hidden_layers": 32, + # }, + # "text_config_dict": { + "hidden_size": 768, + "intermediate_size": 3072, + "num_attention_heads": 12, + "num_hidden_layers": 12, + "projection_dim": 768, + # }, + # "torch_dtype": "float32", + # "transformers_version": None, + } + config = CLIPConfig(**CLIPL_CONFIG) + with init_empty_weights(): + clip = CLIPTextModel._from_config(config) + + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_file(ckpt_path, device=str(device)) + info = clip.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded CLIP: {info}") + return clip + + +def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> T5EncoderModel: + T5_CONFIG_JSON = """ +{ + "architectures": [ + "T5EncoderModel" + ], + "classifier_dropout": 0.0, + "d_ff": 10240, + "d_kv": 64, + "d_model": 4096, + "decoder_start_token_id": 0, + "dense_act_fn": "gelu_new", + "dropout_rate": 0.1, + "eos_token_id": 1, + "feed_forward_proj": "gated-gelu", + "initializer_factor": 1.0, + "is_encoder_decoder": true, + "is_gated_act": true, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "num_decoder_layers": 24, + "num_heads": 64, + "num_layers": 24, + "output_past": true, + "pad_token_id": 0, + "relative_attention_max_distance": 128, + "relative_attention_num_buckets": 32, + "tie_word_embeddings": false, + "torch_dtype": "float16", + "transformers_version": "4.41.2", + "use_cache": true, + "vocab_size": 32128 +} +""" + config = json.loads(T5_CONFIG_JSON) + config = T5Config(**config) + with init_empty_weights(): + t5xxl = T5EncoderModel._from_config(config) + + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_file(ckpt_path, device=str(device)) + info = t5xxl.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded T5xxl: {info}") + return t5xxl + + +def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int): + img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(packed_latent_width)[None, :] + img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size) + return img_ids + + +def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor: + """ + x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2 + """ + x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2) + return x + + +def pack_latents(x: torch.Tensor) -> torch.Tensor: + """ + x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2 + """ + x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + return x diff --git a/library/sd3_models.py b/library/sd3_models.py index 28378c73b..ec704dcba 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -15,6 +15,12 @@ import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from transformers import CLIPTokenizer, T5TokenizerFast +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) memory_efficient_attention = None @@ -95,7 +101,9 @@ def tokenize_with_weights(self, text: str, truncate_to_max_length=True, truncate batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch))) # truncate to max_length - print(f"batch: {batch}, max_length: {self.max_length}, truncate: {truncate_to_max_length}, truncate_length: {truncate_length}") + print( + f"batch: {batch}, max_length: {self.max_length}, truncate: {truncate_to_max_length}, truncate_length: {truncate_length}" + ) if truncate_to_max_length and len(batch) > self.max_length: batch = batch[: self.max_length] if truncate_length is not None and len(batch) > truncate_length: @@ -1554,6 +1562,17 @@ def __init__( self.set_clip_options({"layer": layer_idx}) self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled) + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def gradient_checkpointing_enable(self): + logger.warning("Gradient checkpointing is not supported for this model") + def set_attn_mode(self, mode): raise NotImplementedError("This model does not support setting the attention mode") @@ -1925,6 +1944,7 @@ def create_clip_l(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[s return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG, ) + clip_l.gradient_checkpointing_enable() if state_dict is not None: # update state_dict if provided to include logit_scale and text_projection.weight avoid errors if "logit_scale" not in state_dict: diff --git a/library/strategy_flux.py b/library/strategy_flux.py new file mode 100644 index 000000000..f194ccf6e --- /dev/null +++ b/library/strategy_flux.py @@ -0,0 +1,244 @@ +import os +import glob +from typing import Any, List, Optional, Tuple, Union +import torch +import numpy as np +from transformers import CLIPTokenizer, T5TokenizerFast + +from library import sd3_utils, train_util +from library import sd3_models +from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14" +T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl" + + +class FluxTokenizeStrategy(TokenizeStrategy): + def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None: + self.t5xxl_max_length = t5xxl_max_length + self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + + l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt") + + t5_attn_mask = t5_tokens["attention_mask"] + l_tokens = l_tokens["input_ids"] + t5_tokens = t5_tokens["input_ids"] + + return [l_tokens, t5_tokens, t5_attn_mask] + + +class FluxTextEncodingStrategy(TextEncodingStrategy): + def __init__(self) -> None: + pass + + def encode_tokens( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens: List[torch.Tensor], + apply_t5_attn_mask: bool = False, + ) -> List[torch.Tensor]: + # supports single model inference only + + clip_l, t5xxl = models + l_tokens, t5_tokens = tokens[:2] + t5_attn_mask = tokens[2] if len(tokens) > 2 else None + + if clip_l is not None and l_tokens is not None: + l_pooled = clip_l(l_tokens.to(clip_l.device))["pooler_output"] + else: + l_pooled = None + + if t5xxl is not None and t5_tokens is not None: + # t5_out is [1, max length, 4096] + t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), return_dict=False, output_hidden_states=True) + if apply_t5_attn_mask: + t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1) + txt_ids = torch.zeros(1, t5_out.shape[1], 3, device=t5_out.device) + else: + t5_out = None + txt_ids = None + + return [l_pooled, t5_out, txt_ids] + + +class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): + FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_flux_te.npz" + + def __init__( + self, + cache_to_disk: bool, + batch_size: int, + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + apply_t5_attn_mask: bool = False, + ) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) + self.apply_t5_attn_mask = apply_t5_attn_mask + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + + def is_disk_cached_outputs_expected(self, npz_path: str): + if not self.cache_to_disk: + return False + if not os.path.exists(npz_path): + return False + if self.skip_disk_cache_validity_check: + return True + + try: + npz = np.load(npz_path) + if "l_pooled" not in npz: + return False + if "t5_out" not in npz: + return False + if "txt_ids" not in npz: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + + return True + + def mask_t5_attn(self, t5_out: np.ndarray, t5_attn_mask: np.ndarray) -> np.ndarray: + return t5_out * np.expand_dims(t5_attn_mask, -1) + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + data = np.load(npz_path) + l_pooled = data["l_pooled"] + t5_out = data["t5_out"] + txt_ids = data["txt_ids"] + + if self.apply_t5_attn_mask: + t5_attn_mask = data["t5_attn_mask"] + t5_out = self.mask_t5_attn(t5_out, t5_attn_mask) + + return [l_pooled, t5_out, txt_ids] + + def cache_batch_outputs( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List + ): + flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy + captions = [info.caption for info in infos] + + tokens_and_masks = tokenize_strategy.tokenize(captions) + with torch.no_grad(): + l_pooled, t5_out, txt_ids = flux_text_encoding_strategy.encode_tokens( + tokenize_strategy, models, tokens_and_masks, self.apply_t5_attn_mask + ) + + if l_pooled.dtype == torch.bfloat16: + l_pooled = l_pooled.float() + if t5_out.dtype == torch.bfloat16: + t5_out = t5_out.float() + if txt_ids.dtype == torch.bfloat16: + txt_ids = txt_ids.float() + + l_pooled = l_pooled.cpu().numpy() + t5_out = t5_out.cpu().numpy() + txt_ids = txt_ids.cpu().numpy() + + for i, info in enumerate(infos): + l_pooled_i = l_pooled[i] + t5_out_i = t5_out[i] + txt_ids_i = txt_ids[i] + + if self.cache_to_disk: + t5_attn_mask = tokens_and_masks[2] + t5_attn_mask_i = t5_attn_mask[i].cpu().numpy() + np.savez( + info.text_encoder_outputs_npz, + l_pooled=l_pooled_i, + t5_out=t5_out_i, + txt_ids=txt_ids_i, + t5_attn_mask=t5_attn_mask_i, + ) + else: + info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i) + + +class FluxLatentsCachingStrategy(LatentsCachingStrategy): + FLUX_LATENTS_NPZ_SUFFIX = "_flux.npz" + + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + + def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: + npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX) + if len(npz_file) == 0: + return None, None + w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x") + return int(w), int(h) + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + return ( + os.path.splitext(absolute_path)[0] + + f"_{image_size[0]:04d}x{image_size[1]:04d}" + + FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX + ) + + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + + # TODO remove circular dependency for ImageInfo + def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu") + vae_device = vae.device + vae_dtype = vae.dtype + + self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop) + + if not train_util.HIGH_VRAM: + train_util.clean_memory_on_device(vae.device) + + +if __name__ == "__main__": + # test code for FluxTokenizeStrategy + # tokenizer = sd3_models.SD3Tokenizer() + strategy = FluxTokenizeStrategy(256) + text = "hello world" + + l_tokens, g_tokens, t5_tokens = strategy.tokenize(text) + # print(l_tokens.shape) + print(l_tokens) + print(g_tokens) + print(t5_tokens) + + texts = ["hello world", "the quick brown fox jumps over the lazy dog"] + l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + g_tokens_2 = strategy.clip_g(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + t5_tokens_2 = strategy.t5xxl( + texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + print(l_tokens_2) + print(g_tokens_2) + print(t5_tokens_2) + + # compare + print(torch.allclose(l_tokens, l_tokens_2["input_ids"][0])) + print(torch.allclose(g_tokens, g_tokens_2["input_ids"][0])) + print(torch.allclose(t5_tokens, t5_tokens_2["input_ids"][0])) + + text = ",".join(["hello world! this is long text"] * 50) + l_tokens, g_tokens, t5_tokens = strategy.tokenize(text) + print(l_tokens) + print(g_tokens) + print(t5_tokens) + + print(f"model max length l: {strategy.clip_l.model_max_length}") + print(f"model max length g: {strategy.clip_g.model_max_length}") + print(f"model max length t5: {strategy.t5xxl.model_max_length}") diff --git a/networks/lora_flux.py b/networks/lora_flux.py new file mode 100644 index 000000000..141137b46 --- /dev/null +++ b/networks/lora_flux.py @@ -0,0 +1,730 @@ +# temporary minimum implementation of LoRA +# FLUX doesn't have Conv2d, so we ignore it +# TODO commonize with the original implementation + +# LoRA network module +# reference: +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py + +import math +import os +from typing import Dict, List, Optional, Tuple, Type, Union +from diffusers import AutoencoderKL +from transformers import CLIPTextModel +import numpy as np +import torch +import re +from library.utils import setup_logging +from library.sdxl_original_unet import SdxlUNet2DConditionModel + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class LoRAModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=None, + rank_dropout=None, + module_dropout=None, + ): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__() + self.lora_name = lora_name + + if org_module.__class__.__name__ == "Conv2d": + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + + self.lora_dim = lora_dim + + if org_module.__class__.__name__ == "Conv2d": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える + + # same as microsoft's + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) + + self.multiplier = multiplier + self.org_module = org_module # remove in applying + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + def forward(self, x): + org_forwarded = self.org_forward(x) + + # module dropout + if self.module_dropout is not None and self.training: + if torch.rand(1) < self.module_dropout: + return org_forwarded + + lx = self.lora_down(x) + + # normal dropout + if self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) + + # rank dropout + if self.rank_dropout is not None and self.training: + mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout + if len(lx.size()) == 3: + mask = mask.unsqueeze(1) # for Text Encoder + elif len(lx.size()) == 4: + mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d + lx = lx * mask + + # scaling for rank dropout: treat as if the rank is changed + # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lx = self.lora_up(lx) + + return org_forwarded + lx * self.multiplier * scale + + +class LoRAInfModule(LoRAModule): + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + **kwargs, + ): + # no dropout for inference + super().__init__(lora_name, org_module, multiplier, lora_dim, alpha) + + self.org_module_ref = [org_module] # 後から参照できるように + self.enabled = True + self.network: LoRANetwork = None + + def set_network(self, network): + self.network = network + + # freezeしてマージする + def merge_to(self, sd, dtype, device): + # extract weight from org_module + org_sd = self.org_module.state_dict() + weight = org_sd["weight"] + org_dtype = weight.dtype + org_device = weight.device + weight = weight.to(torch.float) # calc in float + + if dtype is None: + dtype = org_dtype + if device is None: + device = org_device + + # get up/down weight + up_weight = sd["lora_up.weight"].to(torch.float).to(device) + down_weight = sd["lora_down.weight"].to(torch.float).to(device) + + # merge weight + if len(weight.size()) == 2: + # linear + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + self.multiplier * conved * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + + # 復元できるマージのため、このモジュールのweightを返す + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + # get up/down weight from module + up_weight = self.lora_up.weight.to(torch.float) + down_weight = self.lora_down.weight.to(torch.float) + + # pre-calculated weight + if len(down_weight.size()) == 2: + # linear + weight = self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + weight = self.multiplier * conved * self.scale + + return weight + + def set_region(self, region): + self.region = region + self.region_mask = None + + def default_forward(self, x): + # logger.info(f"default_forward {self.lora_name} {x.size()}") + return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + + def forward(self, x): + if not self.enabled: + return self.org_forward(x) + return self.default_forward(x) + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + ae: AutoencoderKL, + text_encoders: List[CLIPTextModel], + flux, + neuron_dropout: Optional[float] = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + # extract dim/alpha for conv2d, and block dim + conv_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + if conv_dim is not None: + conv_dim = int(conv_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + + # rank/module dropout + rank_dropout = kwargs.get("rank_dropout", None) + if rank_dropout is not None: + rank_dropout = float(rank_dropout) + module_dropout = kwargs.get("module_dropout", None) + if module_dropout is not None: + module_dropout = float(module_dropout) + + # すごく引数が多いな ( ^ω^)・・・ + network = LoRANetwork( + text_encoders, + flux, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + conv_lora_dim=conv_dim, + conv_alpha=conv_alpha, + varbose=True, + ) + + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + + return network + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weights_sd=None, for_inference=False, **kwargs): + # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # get dim/alpha mapping + modules_dim = {} + modules_alpha = {} + for key, value in weights_sd.items(): + if "." not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "lora_down" in key: + dim = value.size()[0] + modules_dim[lora_name] = dim + # logger.info(lora_name, value.size(), dim) + + module_class = LoRAInfModule if for_inference else LoRAModule + + network = LoRANetwork(text_encoders, flux, multiplier=multiplier, module_class=module_class) + return network, weights_sd + + +class LoRANetwork(torch.nn.Module): + FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + LORA_PREFIX_FLUX = "lora_flux" + LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1" + LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2" + + def __init__( + self, + text_encoders: Union[List[CLIPTextModel], CLIPTextModel], + unet, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + module_class: Type[object] = LoRAModule, + varbose: Optional[bool] = False, + ) -> None: + super().__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + if self.conv_lora_dim is not None: + logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + + # create module instances + def create_modules( + is_flux: bool, text_encoder_idx: Optional[int], root_module: torch.nn.Module, target_replace_modules: List[str] + ) -> List[LoRAModule]: + prefix = ( + self.LORA_PREFIX_FLUX + if is_flux + else (self.LORA_PREFIX_TEXT_ENCODER_CLIP if text_encoder_idx == 0 else self.LORA_PREFIX_TEXT_ENCODER_T5) + ) + + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if is_linear or is_conv2d: + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") + + dim = None + alpha = None + + # 通常、すべて対象とする + if is_linear or is_conv2d_1x1: + dim = self.lora_dim + alpha = self.alpha + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha + + if dim is None or dim == 0: + # skipした情報を出力 + if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None): + skipped.append(lora_name) + continue + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + ) + loras.append(lora) + return loras, skipped + + # create LoRA for text encoder + # 毎回すべてのモジュールを作るのは無駄なので要検討 + self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = [] + skipped_te = [] + for i, text_encoder in enumerate(text_encoders): + index = i + logger.info(f"create LoRA for Text Encoder {index+1}:") + + text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped + logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + + self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] + self.unet_loras, skipped_un = create_modules(True, None, unet, LoRANetwork.FLUX_TARGET_REPLACE_MODULE) + logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + + skipped = skipped_te + skipped_un + if varbose and len(skipped) > 0: + logger.warning( + f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" + ) + for name in skipped: + logger.info(f"\t{name}") + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def set_enabled(self, is_enabled): + for lora in self.text_encoder_loras + self.unet_loras: + lora.enabled = is_enabled + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def apply_to(self, text_encoders, flux, apply_text_encoder=True, apply_unet=True): + if apply_text_encoder: + logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + # マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoders, flux, weights_sd, dtype=None, device=None): + apply_text_encoder = apply_unet = False + for key in weights_sd.keys(): + if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP) or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5): + apply_text_encoder = True + elif key.startswith(LoRANetwork.LORA_PREFIX_FLUX): + apply_unet = True + + if apply_text_encoder: + logger.info("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(lora.lora_name): + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] + lora.merge_to(sd_for_lora, dtype, device) + + logger.info(f"weights are merged") + + def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio = loraplus_lr_ratio + self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio + self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + + logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") + logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") + + # 二つのText Encoderに別々の学習率を設定できるようにするといいかも + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + # TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?) + # if ( + # self.loraplus_lr_ratio is not None + # or self.loraplus_text_encoder_lr_ratio is not None + # or self.loraplus_unet_lr_ratio is not None + # ): + # assert ( + # optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower() + # ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません" + + self.requires_grad_(True) + + all_params = [] + lr_descriptions = [] + + def assemble_params(loras, lr, ratio): + param_groups = {"lora": {}, "plus": {}} + for lora in loras: + for name, param in lora.named_parameters(): + if ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + params = [] + descriptions = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + + if lr is not None: + if key == "plus": + param_data["lr"] = lr * ratio + else: + param_data["lr"] = lr + + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + logger.info("NO LR skipping!") + continue + + params.append(param_data) + descriptions.append("plus" if key == "plus" else "") + + return params, descriptions + + if self.text_encoder_loras: + params, descriptions = assemble_params( + self.text_encoder_loras, + text_encoder_lr if text_encoder_lr is not None else default_lr, + self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions]) + + if self.unet_loras: + # if self.block_lr: + # is_sdxl = False + # for lora in self.unet_loras: + # if "input_blocks" in lora.lora_name or "output_blocks" in lora.lora_name: + # is_sdxl = True + # break + + # # 学習率のグラフをblockごとにしたいので、blockごとにloraを分類 + # block_idx_to_lora = {} + # for lora in self.unet_loras: + # idx = get_block_index(lora.lora_name, is_sdxl) + # if idx not in block_idx_to_lora: + # block_idx_to_lora[idx] = [] + # block_idx_to_lora[idx].append(lora) + + # # blockごとにパラメータを設定する + # for idx, block_loras in block_idx_to_lora.items(): + # params, descriptions = assemble_params( + # block_loras, + # (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(idx), + # self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, + # ) + # all_params.extend(params) + # lr_descriptions.extend([f"unet_block{idx}" + (" " + d if d else "") for d in descriptions]) + + # else: + params, descriptions = assemble_params( + self.unet_loras, + unet_lr if unet_lr is not None else default_lr, + self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions]) + + return all_params, lr_descriptions + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from library import train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def backup_weights(self): + # 重みのバックアップを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + sd = org_module.state_dict() + + org_weight = sd["weight"] + lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype) + sd["weight"] = org_weight + lora_weight + assert sd["weight"].shape == org_weight.shape + org_module.load_state_dict(sd) + + org_module._lora_restored = False + lora.enabled = False + + def apply_max_norm_regularization(self, max_norm_value, device): + downkeys = [] + upkeys = [] + alphakeys = [] + norms = [] + keys_scaled = 0 + + state_dict = self.state_dict() + for key in state_dict.keys(): + if "lora_down" in key and "weight" in key: + downkeys.append(key) + upkeys.append(key.replace("lora_down", "lora_up")) + alphakeys.append(key.replace("lora_down.weight", "alpha")) + + for i in range(len(downkeys)): + down = state_dict[downkeys[i]].to(device) + up = state_dict[upkeys[i]].to(device) + alpha = state_dict[alphakeys[i]].to(device) + dim = down.shape[0] + scale = alpha / dim + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): + updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + else: + updown = up @ down + + updown *= scale + + norm = updown.norm().clamp(min=max_norm_value / 2) + desired = torch.clamp(norm, max=max_norm_value) + ratio = desired.cpu() / norm.cpu() + sqrt_ratio = ratio**0.5 + if ratio != 1: + keys_scaled += 1 + state_dict[upkeys[i]] *= sqrt_ratio + state_dict[downkeys[i]] *= sqrt_ratio + scalednorm = updown.norm() * ratio + norms.append(scalednorm.item()) + + return keys_scaled, sum(norms) / len(norms), max(norms) diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 67ccae62c..4d6e3f184 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -52,6 +52,11 @@ def load_target_model(self, args, weight_dtype, accelerator): self.logit_scale = logit_scale self.ckpt_info = ckpt_info + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + vae.set_use_memory_efficient_attention_xformers(args.xformers) + return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet def get_tokenize_strategy(self, args): diff --git a/train_network.py b/train_network.py index 3828fed19..48d988624 100644 --- a/train_network.py +++ b/train_network.py @@ -100,6 +100,12 @@ def assert_extra_args(self, args, train_dataset_group): def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) + + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + vae.set_use_memory_efficient_attention_xformers(args.xformers) + return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet def get_tokenize_strategy(self, args): @@ -147,6 +153,81 @@ def all_reduce_network(self, accelerator, network): def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoder, unet): train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoder, unet) + # region SD/SDXL + + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + prepare_scheduler_for_custom_training(noise_scheduler, device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + return noise_scheduler + + def encode_images_to_latents(self, args, accelerator, vae, images): + return vae.encode(images).latent_dist.sample() + + def shift_scale_latents(self, args, latents): + return latents * self.vae_scale_factor + + def get_noise_pred_and_target( + self, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet, + network, + weight_dtype, + train_unet, + ): + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + for x in noisy_latents: + x.requires_grad_(True) + for t in text_encoder_conds: + t.requires_grad_(True) + + # Predict the noise residual + with accelerator.autocast(): + noise_pred = self.call_unet( + args, + accelerator, + unet, + noisy_latents.requires_grad_(train_unet), + timesteps, + text_encoder_conds, + batch, + weight_dtype, + ) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + return noise_pred, target, timesteps, huber_c, None + + def post_process_loss(self, loss, args, timesteps, noise_scheduler): + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + return loss + + # endregion + def train(self, args): session_id = random.randint(0, 2**32) training_started_at = time.time() @@ -253,11 +334,6 @@ def train(self, args): # text_encoder is List[CLIPTextModel] or CLIPTextModel text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] - # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) - if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える - vae.set_use_memory_efficient_attention_xformers(args.xformers) - # 差分追加学習のためにモデルを読み込む sys.path.append(os.path.dirname(__file__)) accelerator.print("import network module:", args.network_module) @@ -445,16 +521,19 @@ def train(self, args): unet_weight_dtype = torch.float8_e4m3fn te_weight_dtype = torch.float8_e4m3fn + unet.to(accelerator.device) # this makes faster `to(dtype)` below + unet.requires_grad_(False) - unet.to(dtype=unet_weight_dtype) + unet.to(dtype=unet_weight_dtype) # this takes long time and large memory for t_enc in text_encoders: t_enc.requires_grad_(False) # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16 if t_enc.device.type != "cpu": t_enc.to(dtype=te_weight_dtype) - # nn.Embedding not support FP8 - t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) + if hasattr(t_enc.text_model, "embeddings"): + # nn.Embedding not support FP8 + t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good if args.deepspeed: @@ -851,12 +930,7 @@ def load_model_hook(models, input_dir): global_step = 0 - noise_scheduler = DDPMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False - ) - prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) - if args.zero_terminal_snr: - custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + noise_scheduler = self.get_noise_scheduler(args, accelerator.device) if accelerator.is_main_process: init_kwargs = {} @@ -913,6 +987,13 @@ def remove_model(old_ckpt_name): initial_step -= len(train_dataloader) global_step = initial_step + # log device and dtype for each model + logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}") + for t_enc in text_encoders: + logger.info(f"text_encoder dtype: {te_weight_dtype}, device: {t_enc.device}") + + clean_memory_on_device(accelerator.device) + for epoch in range(epoch_to_start, num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 @@ -940,13 +1021,15 @@ def remove_model(old_ckpt_name): else: with torch.no_grad(): # latentに変換 - latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype) + latents = self.encode_images_to_latents(args, accelerator, vae, batch["images"].to(vae_dtype)) + latents = latents.to(dtype=weight_dtype) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): accelerator.print("NaN found in latents, replacing with zeros") latents = torch.nan_to_num(latents, 0, out=latents) - latents = latents * self.vae_scale_factor + + latents = self.shift_scale_latents(args, latents) # get multiplier for each sample if network_has_multiplier: @@ -985,41 +1068,25 @@ def remove_model(old_ckpt_name): if args.full_fp16: text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents + # sample noise, call unet, get target + noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target( + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet, + network, + weight_dtype, + train_unet, ) - # ensure the hidden state will require grad - if args.gradient_checkpointing: - for x in noisy_latents: - x.requires_grad_(True) - for t in text_encoder_conds: - t.requires_grad_(True) - - # Predict the noise residual - with accelerator.autocast(): - noise_pred = self.call_unet( - args, - accelerator, - unet, - noisy_latents.requires_grad_(train_unet), - timesteps, - text_encoder_conds, - batch, - weight_dtype, - ) - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - loss = train_util.conditional_loss( noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c ) + if weighting is not None: + loss = loss * weighting if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) @@ -1027,14 +1094,8 @@ def remove_model(old_ckpt_name): loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) - if args.v_pred_like_loss: - loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) - if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc. + loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし From 808d2d1f48e2f4e544d47464edb2727c03da2f53 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 9 Aug 2024 23:02:51 +0900 Subject: [PATCH 061/348] fix typos --- flux_train_network.py | 2 +- library/flux_models.py | 4 ++-- library/flux_utils.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index 7c762c86d..e4be97ad8 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -250,7 +250,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): # ) with accelerator.autocast(): - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = unet( img=packed_noisy_model_input, img_ids=img_ids, diff --git a/library/flux_models.py b/library/flux_models.py index d0955e375..92c79bcca 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -685,11 +685,11 @@ def _forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[T attn = attention(q, k, v, pe=pe) txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] - # calculate the img bloks + # calculate the img blocks img = img + img_mod1.gate * self.img_attn.proj(img_attn) img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) - # calculate the txt bloks + # calculate the txt blocks txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) return img, txt diff --git a/library/flux_utils.py b/library/flux_utils.py index ba828d508..166cd833b 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -20,7 +20,7 @@ def load_flow_model(name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> flux_models.Flux: - logger.info(f"Bulding Flux model {name}") + logger.info(f"Building Flux model {name}") with torch.device("meta"): model = flux_models.Flux(flux_models.configs[name].params).to(dtype) From 358f13f2c92a04fb524006f124fc029a9edb0eaf Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 10 Aug 2024 14:03:59 +0900 Subject: [PATCH 062/348] fix alpha is ignored --- networks/lora_flux.py | 41 +++++++++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 141137b46..332a73d97 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -307,7 +307,9 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh module_class = LoRAInfModule if for_inference else LoRAModule - network = LoRANetwork(text_encoders, flux, multiplier=multiplier, module_class=module_class) + network = LoRANetwork( + text_encoders, flux, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class + ) return network, weights_sd @@ -331,6 +333,8 @@ def __init__( conv_lora_dim: Optional[int] = None, conv_alpha: Optional[float] = None, module_class: Type[object] = LoRAModule, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, varbose: Optional[bool] = False, ) -> None: super().__init__() @@ -348,12 +352,15 @@ def __init__( self.loraplus_unet_lr_ratio = None self.loraplus_text_encoder_lr_ratio = None - logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") - logger.info( - f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" - ) - if self.conv_lora_dim is not None: - logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + if modules_dim is not None: + logger.info(f"create LoRA network from weights") + else: + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + if self.conv_lora_dim is not None: + logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") # create module instances def create_modules( @@ -381,13 +388,19 @@ def create_modules( dim = None alpha = None - # 通常、すべて対象とする - if is_linear or is_conv2d_1x1: - dim = self.lora_dim - alpha = self.alpha - elif self.conv_lora_dim is not None: - dim = self.conv_lora_dim - alpha = self.conv_alpha + if modules_dim is not None: + # モジュール指定あり + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + else: + # 通常、すべて対象とする + if is_linear or is_conv2d_1x1: + dim = self.lora_dim + alpha = self.alpha + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha if dim is None or dim == 0: # skipした情報を出力 From 8a0f12dde812994ec3facdcdb7c08b362dbceb0f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 10 Aug 2024 23:42:05 +0900 Subject: [PATCH 063/348] update FLUX LoRA training --- README.md | 29 ++++++++--- flux_train_network.py | 105 ++++++++++++++++++++++++++++++-------- library/sai_model_spec.py | 24 +++++++-- library/strategy_flux.py | 4 +- library/train_util.py | 9 ++-- networks/lora_flux.py | 2 +- train_network.py | 18 +++++-- 7 files changed, 150 insertions(+), 41 deletions(-) diff --git a/README.md b/README.md index a0b02f108..1089dd001 100644 --- a/README.md +++ b/README.md @@ -2,24 +2,41 @@ This repository contains training, generation and utility scripts for Stable Dif ## FLUX.1 LoRA training (WIP) -__Aug 9, 2024__: +This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training. + +Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI. Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. -We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. +We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below. It will work with 24GB VRAM GPUs. ``` -accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0 --loss_type l2 ``` +LoRAs for Text Encoders are not tested yet. + +We have added some new options (Aug 10, 2024): `--time_sampling`, `--sigmoid_scale`, `--model_prediction_type` and `--discrete_flow_shift`. The options are as follows: + +- `--timestep_sampling` is the method to sample timesteps (0-1): `sigma` (sigma-based, same as SD3), `uniform` (uniform random), or `sigmoid` (sigmoid of random normal, same as x-flux). +- `--sigmoid_scale` is the scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). The default is 1.0. Larger values will make the sampling more uniform. +- `--model_prediction_type` is how to interpret and process the model prediction: `raw` (use as is, same as x-flux), `additive` (add to noisy input), `sigma_scaled` (apply sigma scaling, same as SD3). +- `--discrete_flow_shift` is the discrete flow shift for the Euler Discrete Scheduler, default is 3.0 (same as SD3). + +`--loss_type` may be useful for FLUX.1 training. The default is `l2`. + +In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted. Other settings may work better, so please try different settings. + +We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. + +The trained LoRA model can be used with ComfyUI. + The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options. ``` -python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors +python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0 ``` -Unfortnately the training result is not good. Please let us know if you have any idea to improve the training. - ## SD3 training SD3 training is done with `sd3_train.py`. diff --git a/flux_train_network.py b/flux_train_network.py index e4be97ad8..69b6e8eaf 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -135,7 +135,7 @@ def sample_images(self, accelerator, args, epoch, global_step, device, vae, toke pass def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: - noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) return noise_scheduler @@ -211,21 +211,32 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): noise = torch.randn_like(latents) bsz = latents.shape[0] - # Sample a random timestep for each image - # for weighting schemes where we sample timesteps non-uniformly - u = compute_density_for_timestep_sampling( - weighting_scheme=args.weighting_scheme, - batch_size=bsz, - logit_mean=args.logit_mean, - logit_std=args.logit_std, - mode_scale=args.mode_scale, - ) - indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long() - timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=accelerator.device) - - # Add noise according to flow matching. - sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype) - noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": + # Simple random t-based noise sampling + if args.timestep_sampling == "sigmoid": + # https://github.com/XLabs-AI/x-flux/tree/main + t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=accelerator.device)) + else: + t = torch.rand((bsz,), device=accelerator.device) + timesteps = t * 1000.0 + t = t.view(-1, 1, 1, 1) + noisy_model_input = (1 - t) * latents + t * noise + else: + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=accelerator.device) + + # Add noise according to flow matching. + sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents # pack latents and get img_ids packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 @@ -264,11 +275,20 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): # unpack latents model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) - model_pred = model_pred * (-sigmas) + noisy_model_input - - # these weighting schemes use a uniform timestep sampling - # and instead post-weight the loss - weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + if args.model_prediction_type == "raw": + # use model_pred as is + weighting = None + elif args.model_prediction_type == "additive": + # add the model_pred to the noisy_model_input + model_pred = model_pred + noisy_model_input + weighting = None + elif args.model_prediction_type == "sigma_scaled": + # apply sigma scaling + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) # flow matching loss: this is different from SD3 target = noise - latents @@ -278,6 +298,21 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): def post_process_loss(self, loss, args, timesteps, noise_scheduler): return loss + def get_sai_model_spec(self, args): + return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev") + + def update_metadata(self, metadata, args): + metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask + metadata["ss_weighting_scheme"] = args.weighting_scheme + metadata["ss_logit_mean"] = args.logit_mean + metadata["ss_logit_std"] = args.logit_std + metadata["ss_mode_scale"] = args.mode_scale + metadata["ss_guidance_scale"] = args.guidance_scale + metadata["ss_timestep_sampling"] = args.timestep_sampling + metadata["ss_sigmoid_scale"] = args.sigmoid_scale + metadata["ss_model_prediction_type"] = args.model_prediction_type + metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift + def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() @@ -318,6 +353,34 @@ def setup_parser() -> argparse.ArgumentParser: default=3.5, help="the FLUX.1 dev variant is a guidance distilled model", ) + + parser.add_argument( + "--timestep_sampling", + choices=["sigma", "uniform", "sigmoid"], + default="sigma", + help="Method to sample timesteps: sigma-based, uniform random, or sigmoid of random normal. / タイムステップをサンプリングする方法:sigma、random uniform、またはrandom normalのsigmoid。", + ) + parser.add_argument( + "--sigmoid_scale", + type=float, + default=1.0, + help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。', + ) + parser.add_argument( + "--model_prediction_type", + choices=["raw", "additive", "sigma_scaled"], + default="sigma_scaled", + help="How to interpret and process the model prediction: " + "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)." + " / モデル予測の解釈と処理方法:" + "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。", + ) + parser.add_argument( + "--discrete_flow_shift", + type=float, + default=3.0, + help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。", + ) return parser diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index af073677e..ad72ec00d 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -59,6 +59,8 @@ ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base" ARCH_SD3_M = "stable-diffusion-3-medium" ARCH_SD3_UNKNOWN = "stable-diffusion-3" +ARCH_FLUX_1_DEV = "flux-1-dev" +ARCH_FLUX_1_UNKNOWN = "flux-1" ADAPTER_LORA = "lora" ADAPTER_TEXTUAL_INVERSION = "textual-inversion" @@ -66,6 +68,7 @@ IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models" IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI" IMPL_DIFFUSERS = "diffusers" +IMPL_FLUX = "https://github.com/black-forest-labs/flux" PRED_TYPE_EPSILON = "epsilon" PRED_TYPE_V = "v" @@ -118,10 +121,11 @@ def build_metadata( merged_from: Optional[str] = None, timesteps: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, - sd3: str = None, + sd3: Optional[str] = None, + flux: Optional[str] = None, ): """ - sd3: only supports "m" + sd3: only supports "m", flux: only supports "dev" """ # if state_dict is None, hash is not calculated @@ -140,6 +144,11 @@ def build_metadata( arch = ARCH_SD3_M else: arch = ARCH_SD3_UNKNOWN + elif flux is not None: + if flux == "dev": + arch = ARCH_FLUX_1_DEV + else: + arch = ARCH_FLUX_1_UNKNOWN elif v2: if v_parameterization: arch = ARCH_SD_V2_768_V @@ -158,7 +167,10 @@ def build_metadata( if not lora and not textual_inversion and is_stable_diffusion_ckpt is None: is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion - if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: + if flux is not None: + # Flux + impl = IMPL_FLUX + elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: # Stable Diffusion ckpt, TI, SDXL LoRA impl = IMPL_STABILITY_AI else: @@ -216,7 +228,7 @@ def build_metadata( reso = (reso[0], reso[0]) else: # resolution is defined in dataset, so use default - if sdxl or sd3 is not None: + if sdxl or sd3 is not None or flux is not None: reso = 1024 elif v2 and v_parameterization: reso = 768 @@ -227,7 +239,9 @@ def build_metadata( metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}" - if v_parameterization: + if flux is not None: + del metadata["modelspec.prediction_type"] + elif v_parameterization: metadata["modelspec.prediction_type"] = PRED_TYPE_V else: metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON diff --git a/library/strategy_flux.py b/library/strategy_flux.py index f194ccf6e..13459d32f 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -63,11 +63,11 @@ def encode_tokens( l_pooled = None if t5xxl is not None and t5_tokens is not None: - # t5_out is [1, max length, 4096] + # t5_out is [b, max length, 4096] t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), return_dict=False, output_hidden_states=True) if apply_t5_attn_mask: t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1) - txt_ids = torch.zeros(1, t5_out.shape[1], 3, device=t5_out.device) + txt_ids = torch.zeros(t5_out.shape[0], t5_out.shape[1], 3, device=t5_out.device) else: t5_out = None txt_ids = None diff --git a/library/train_util.py b/library/train_util.py index fc458a884..6b74bb3fa 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3186,6 +3186,7 @@ def get_sai_model_spec( textual_inversion: bool, is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA sd3: str = None, + flux: str = None, ): timestamp = time.time() @@ -3220,6 +3221,7 @@ def get_sai_model_spec( timesteps=timesteps, clip_skip=args.clip_skip, # None or int sd3=sd3, + flux=flux, ) return metadata @@ -3642,8 +3644,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: "--loss_type", type=str, default="l2", - choices=["l2", "huber", "smooth_l1"], - help="The type of loss function to use (L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L2、Huber、またはsmooth L1)、デフォルトはL2", + choices=["l1", "l2", "huber", "smooth_l1"], + help="The type of loss function to use (L1, L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L1、L2、Huber、またはsmooth L1)、デフォルトはL2", ) parser.add_argument( "--huber_schedule", @@ -5359,9 +5361,10 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): def conditional_loss( model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1 ): - if loss_type == "l2": loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction) + elif loss_type == "l1": + loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction) elif loss_type == "huber": loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 332a73d97..a4dab287a 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -316,7 +316,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh class LoRANetwork(torch.nn.Module): FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] - LORA_PREFIX_FLUX = "lora_flux" + LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1" LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2" diff --git a/train_network.py b/train_network.py index 48d988624..367203f54 100644 --- a/train_network.py +++ b/train_network.py @@ -226,6 +226,12 @@ def post_process_loss(self, loss, args, timesteps, noise_scheduler): loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) return loss + def get_sai_model_spec(self, args): + return train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False) + + def update_metadata(self, metadata, args): + pass + # endregion def train(self, args): @@ -521,10 +527,13 @@ def train(self, args): unet_weight_dtype = torch.float8_e4m3fn te_weight_dtype = torch.float8_e4m3fn - unet.to(accelerator.device) # this makes faster `to(dtype)` below + # unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM + # unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory + + unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above unet.requires_grad_(False) - unet.to(dtype=unet_weight_dtype) # this takes long time and large memory + unet.to(dtype=unet_weight_dtype) for t_enc in text_encoders: t_enc.requires_grad_(False) @@ -718,8 +727,11 @@ def load_model_hook(models, input_dir): "ss_loss_type": args.loss_type, "ss_huber_schedule": args.huber_schedule, "ss_huber_c": args.huber_c, + "ss_fp8_base": args.fp8_base, } + self.update_metadata(metadata, args) # architecture specific metadata + if use_user_config: # save metadata of multiple datasets # NOTE: pack "ss_datasets" value as json one time @@ -964,7 +976,7 @@ def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False metadata["ss_epoch"] = str(epoch_no) metadata_to_save = minimum_metadata if args.no_metadata else metadata - sai_metadata = train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False) + sai_metadata = self.get_sai_model_spec(args) metadata_to_save.update(sai_metadata) unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save) From 82314ac2e7926ed15eac6306bebe4ffb78280346 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 11 Aug 2024 11:14:08 +0900 Subject: [PATCH 064/348] update readme for ai toolkit settings --- README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 1089dd001..d016bcec4 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,11 @@ We have added some new options (Aug 10, 2024): `--time_sampling`, `--sigmoid_sca `--loss_type` may be useful for FLUX.1 training. The default is `l2`. -In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted. Other settings may work better, so please try different settings. +In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted. + +additional note (Aug 11): A quick check shows that the settings in [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) seems to be equivalent to `--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type). + +Other settings may work better, so please try different settings. We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. From d25ae361d06bb6f49c104ca2e6b4a9188a88c95f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 11 Aug 2024 19:07:07 +0900 Subject: [PATCH 065/348] fix apply_t5_attn_mask to work --- README.md | 2 ++ flux_train_network.py | 6 ++++-- library/strategy_flux.py | 18 +++++++++++++----- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index d016bcec4..d47776ca6 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,8 @@ This repository contains training, generation and utility scripts for Stable Dif This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training. +Aug 11, 2024: Fix `--apply_t5_attn_mask` option to work. Please remove and re-generate the latents cache file if you have used the option before. + Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI. Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. diff --git a/flux_train_network.py b/flux_train_network.py index 69b6e8eaf..59a666aae 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -67,14 +67,16 @@ def get_latents_caching_strategy(self, args): return latents_caching_strategy def get_text_encoding_strategy(self, args): - return strategy_flux.FluxTextEncodingStrategy() + return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask) def get_models_for_text_encoding(self, args, accelerator, text_encoders): return text_encoders # + [accelerator.unwrap_model(text_encoders[-1])] def get_text_encoder_outputs_caching_strategy(self, args): if args.cache_text_encoder_outputs: - return strategy_flux.FluxTextEncoderOutputsCachingStrategy(args.cache_text_encoder_outputs_to_disk, None, False) + return strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, False, apply_t5_attn_mask=args.apply_t5_attn_mask + ) else: return None diff --git a/library/strategy_flux.py b/library/strategy_flux.py index 13459d32f..3880a1e1b 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -41,17 +41,24 @@ def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: class FluxTextEncodingStrategy(TextEncodingStrategy): - def __init__(self) -> None: - pass + def __init__(self, apply_t5_attn_mask: Optional[bool] = None) -> None: + """ + Args: + apply_t5_attn_mask: Default value for apply_t5_attn_mask. + """ + self.apply_t5_attn_mask = apply_t5_attn_mask def encode_tokens( self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], - apply_t5_attn_mask: bool = False, + apply_t5_attn_mask: Optional[bool] = None, ) -> List[torch.Tensor]: - # supports single model inference only + # supports single model inference + + if apply_t5_attn_mask is None: + apply_t5_attn_mask = self.apply_t5_attn_mask clip_l, t5xxl = models l_tokens, t5_tokens = tokens[:2] @@ -137,8 +144,9 @@ def cache_batch_outputs( tokens_and_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): + # attn_mask is not applied when caching to disk: it is applied when loading from disk l_pooled, t5_out, txt_ids = flux_text_encoding_strategy.encode_tokens( - tokenize_strategy, models, tokens_and_masks, self.apply_t5_attn_mask + tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk ) if l_pooled.dtype == torch.bfloat16: From 9e09a69df1ea8aa76ec98df3b2eed961c66432e4 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 12 Aug 2024 08:19:45 +0900 Subject: [PATCH 066/348] update README --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index d47776ca6..ccc83e6e8 100644 --- a/README.md +++ b/README.md @@ -10,10 +10,10 @@ Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to mak Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. -We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below. It will work with 24GB VRAM GPUs. +We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below, settings are based on [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit). It will work with 24GB VRAM GPUs. ``` -accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0 --loss_type l2 +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 --loss_type l2 ``` LoRAs for Text Encoders are not tested yet. @@ -29,7 +29,7 @@ We have added some new options (Aug 10, 2024): `--time_sampling`, `--sigmoid_sca In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted. -additional note (Aug 11): A quick check shows that the settings in [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) seems to be equivalent to `--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type). +additional note (Aug 11): A quick check shows that the settings in [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) seems to be equivalent to `--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type). This seems to be a good starting point. Thanks to Ostris for the great work! Other settings may work better, so please try different settings. From 4af36f96320d553025cfdf067cae1e346af44a67 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Mon, 12 Aug 2024 13:24:10 +0900 Subject: [PATCH 067/348] update to work interactive mode --- README.md | 2 ++ flux_minimal_inference.py | 33 +++++++++++++++++++++++++++------ 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index ccc83e6e8..c0d50a5a2 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,8 @@ The trained LoRA model can be used with ComfyUI. The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options. +Aug 12: `--interactive` option is now working. + ``` python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0 ``` diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index f3affca80..b09f63808 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -5,7 +5,7 @@ import math import os import random -from typing import Callable, Optional, Tuple +from typing import Callable, List, Optional, Tuple import einops import numpy as np @@ -121,6 +121,9 @@ def generate_image( steps: Optional[int], guidance: float, ): + seed = seed if seed is not None else random.randint(0, 2**32 - 1) + logger.info(f"Seed: {seed}") + # make first noise with packed shape # original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2 packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16) @@ -183,9 +186,7 @@ def generate_image( steps = 4 if is_schnell else 50 img_ids = img_ids.to(device) - x = do_sample( - accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance_scale, is_schnell, device, flux_dtype - ) + x = do_sample(accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance, is_schnell, device, flux_dtype) if args.offload: model = model.cpu() # del model @@ -255,6 +256,7 @@ def generate_image( default=[], help="LoRA weights, only supports networks.lora_flux, each argument is a `path;multiplier` (semi-colon separated)", ) + parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model") parser.add_argument("--width", type=int, default=target_width) parser.add_argument("--height", type=int, default=target_height) parser.add_argument("--interactive", action="store_true") @@ -341,6 +343,7 @@ def is_fp8(dt): ae = accelerator.prepare(ae) # LoRA + lora_models: List[lora_flux.LoRANetwork] = [] for weights_file in args.lora_weights: if ";" in weights_file: weights_file, multiplier = weights_file.split(";") @@ -351,7 +354,16 @@ def is_fp8(dt): lora_model, weights_sd = lora_flux.create_network_from_weights( multiplier, weights_file, ae, [clip_l, t5xxl], model, None, True ) - lora_model.merge_to([clip_l, t5xxl], model, weights_sd) + if args.merge_lora_weights: + lora_model.merge_to([clip_l, t5xxl], model, weights_sd) + else: + lora_model.apply_to([clip_l, t5xxl], model) + info = lora_model.load_state_dict(weights_sd, strict=True) + logger.info(f"Loaded LoRA weights from {weights_file}: {info}") + lora_model.eval() + lora_model.to(device) + + lora_models.append(lora_model) if not args.interactive: generate_image(model, clip_l, t5xxl, ae, args.prompt, args.seed, args.width, args.height, args.steps, args.guidance) @@ -363,7 +375,9 @@ def is_fp8(dt): guidance = args.guidance while True: - print("Enter prompt (empty to exit). Options: --w --h --s --d --g ") + print( + "Enter prompt (empty to exit). Options: --w --h --s --d --g --m " + ) prompt = input() if prompt == "": break @@ -384,6 +398,13 @@ def is_fp8(dt): seed = int(opt[1:].strip()) elif opt.startswith("g"): guidance = float(opt[1:].strip()) + elif opt.startswith("m"): + mutipliers = opt[1:].strip().split(",") + if len(mutipliers) != len(lora_models): + logger.error(f"Invalid number of multipliers, expected {len(lora_models)}") + continue + for i, lora_model in enumerate(lora_models): + lora_model.set_multiplier(float(mutipliers[i])) generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance) From a7d5dabde3facb57d069eba0aa91e961e04303ad Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 12 Aug 2024 17:09:19 +0900 Subject: [PATCH 068/348] Update readme --- README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.md b/README.md index c0d50a5a2..19aed2212 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,12 @@ We have added a new training script for LoRA training. The script is `flux_train accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 --loss_type l2 ``` +The training can be done with 16GB VRAM GPUs with Adafactor optimizer. Please use settings like below: + +``` +--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False"` +``` + LoRAs for Text Encoders are not tested yet. We have added some new options (Aug 10, 2024): `--time_sampling`, `--sigmoid_scale`, `--model_prediction_type` and `--discrete_flow_shift`. The options are as follows: From 0415d200f5f3db89e33b33c9b36cb3c3e15d0266 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Tue, 13 Aug 2024 21:00:16 +0900 Subject: [PATCH 069/348] update dependencies closes #1450 --- requirements.txt | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/requirements.txt b/requirements.txt index e99775b8a..4ee19b3ee 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,12 @@ -accelerate==0.25.0 -transformers==4.36.2 +accelerate==0.33.0 +transformers==4.44.0 diffusers[torch]==0.25.0 ftfy==6.1.1 # albumentations==1.3.0 opencv-python==4.7.0.68 einops==0.7.0 pytorch-lightning==1.9.0 -bitsandbytes==0.43.0 +bitsandbytes==0.43.3 prodigyopt==1.0 lion-pytorch==0.0.6 tensorboard @@ -16,7 +16,7 @@ altair==4.2.2 easygui==0.98.3 toml==0.10.2 voluptuous==0.13.1 -huggingface-hub==0.20.1 +huggingface-hub==0.24.5 # for Image utils imagesize==1.4.1 # for BLIP captioning @@ -38,5 +38,7 @@ imagesize==1.4.1 # open-clip-torch==2.20.0 # For logging rich==13.7.0 +# for T5XXL tokenizer (SD3/FLUX) +sentencepiece==0.2.0 # for kohya_ss library -e . From 9711c96f96038df5fa1a15d073244198b93ef0a2 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Tue, 13 Aug 2024 21:03:17 +0900 Subject: [PATCH 070/348] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 19aed2212..3eb034ed4 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Aug 11, 2024: Fix `--apply_t5_attn_mask` option to work. Please remove and re-ge Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI. -Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. +__Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__ We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below, settings are based on [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit). It will work with 24GB VRAM GPUs. From 56d7651f0895c805c403a8db01083a522503eb7d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 13 Aug 2024 22:28:39 +0900 Subject: [PATCH 071/348] add experimental split mode for FLUX --- README.md | 22 +++++- flux_train_network.py | 110 +++++++++++++++++++++++---- library/flux_models.py | 165 +++++++++++++++++++++++++++++++++++++++++ networks/lora_flux.py | 30 ++++++-- 4 files changed, 304 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 3eb034ed4..64b018804 100644 --- a/README.md +++ b/README.md @@ -4,12 +4,22 @@ This repository contains training, generation and utility scripts for Stable Dif This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training. +__Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__ + +Aug 13, 2024: + +__Experimental__ A network argument `train_blocks` is added to `lora_flux`. This is to select the target blocks of LoRA from FLUX double blocks and single blocks. Specify like `--network_args "train_blocks=single"`. `all` trains both double blocks and single blocks, `double` trains only double blocks, and `single` trains only single blocks. The default (omission) is `all`. + +This argument is available even if `--split_mode` is not specified. + +__Experimental__ `--split_mode` option is added to `flux_train_network.py`. This splits FLUX into double blocks and single blocks for training. By enabling gradients only for the single blocks part, memory usage is reduced. When this option is specified, you need to specify `"train_blocks=single"` in the network arguments. + +This option enables training with 12GB VRAM GPUs, but the training speed is 2-3 times slower than the default. + Aug 11, 2024: Fix `--apply_t5_attn_mask` option to work. Please remove and re-generate the latents cache file if you have used the option before. Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI. -__Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__ - We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below, settings are based on [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit). It will work with 24GB VRAM GPUs. ``` @@ -19,7 +29,13 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t The training can be done with 16GB VRAM GPUs with Adafactor optimizer. Please use settings like below: ``` ---optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False"` +--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" +``` + +The training can be done with 12GB VRAM GPUs with Adafactor optimizer, `--split_mode` and `train_blocks=single` options. Please use settings like below: + +``` +--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --split_mode --network_args "train_blocks=single" ``` LoRAs for Text Encoders are not tested yet. diff --git a/flux_train_network.py b/flux_train_network.py index 59a666aae..1d1f00d84 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -37,10 +37,16 @@ def assert_extra_args(self, args, train_dataset_group): args.network_train_unet_only or not args.cache_text_encoder_outputs ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" - train_dataset_group.verify_bucket_reso_steps(32) + train_dataset_group.verify_bucket_reso_steps(32) # TODO check this def load_target_model(self, args, weight_dtype, accelerator): # currently offload to cpu for some models + name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" # TODO change this to a more robust way + # if we load to cpu, flux.to(fp8) takes a long time + model = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu") + + if args.split_mode: + model = self.prepare_split_model(model, weight_dtype, accelerator) clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu") clip_l.eval() @@ -49,13 +55,47 @@ def load_target_model(self, args, weight_dtype, accelerator): t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu") t5xxl.eval() - name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" # TODO change this to a more robust way - # if we load to cpu, flux.to(fp8) takes a long time - model = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu") ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu") return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model + def prepare_split_model(self, model, weight_dtype, accelerator): + from accelerate import init_empty_weights + + logger.info("prepare split model") + with init_empty_weights(): + flux_upper = flux_models.FluxUpper(model.params) + flux_lower = flux_models.FluxLower(model.params) + sd = model.state_dict() + + # lower (trainable) + logger.info("load state dict for lower") + flux_lower.load_state_dict(sd, strict=False, assign=True) + flux_lower.to(dtype=weight_dtype) + + # upper (frozen) + logger.info("load state dict for upper") + flux_upper.load_state_dict(sd, strict=False, assign=True) + + logger.info("prepare upper model") + target_dtype = torch.float8_e4m3fn if args.fp8_base else weight_dtype + flux_upper.to(accelerator.device, dtype=target_dtype) + flux_upper.eval() + + if args.fp8_base: + # this is required to run on fp8 + flux_upper = accelerator.prepare(flux_upper) + + flux_upper.to("cpu") + + self.flux_upper = flux_upper + del model # we don't need model anymore + clean_memory_on_device(accelerator.device) + + logger.info("split model prepared") + + return flux_lower + def get_tokenize_strategy(self, args): return strategy_flux.FluxTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) @@ -262,17 +302,51 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): # f"model_input: {noisy_model_input.shape}, img_ids: {img_ids.shape}, t5_out: {t5_out.shape}, txt_ids: {txt_ids.shape}, l_pooled: {l_pooled.shape}, timesteps: {timesteps.shape}, guidance_vec: {guidance_vec.shape}" # ) - with accelerator.autocast(): - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) - model_pred = unet( - img=packed_noisy_model_input, - img_ids=img_ids, - txt=t5_out, - txt_ids=txt_ids, - y=l_pooled, - timesteps=timesteps / 1000, - guidance=guidance_vec, - ) + if not args.split_mode: + # normal forward + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = unet( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + ) + else: + # split forward to reduce memory usage + assert network.train_blocks == "single", "train_blocks must be single for split mode" + with accelerator.autocast(): + # move flux lower to cpu, and then move flux upper to gpu + unet.to("cpu") + clean_memory_on_device(accelerator.device) + self.flux_upper.to(accelerator.device) + + # upper model does not require grad + with torch.no_grad(): + intermediate_img, intermediate_txt, vec, pe = self.flux_upper( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + ) + + # move flux upper back to cpu, and then move flux lower to gpu + self.flux_upper.to("cpu") + clean_memory_on_device(accelerator.device) + unet.to(accelerator.device) + + # lower model requires grad + intermediate_img.requires_grad_(True) + intermediate_txt.requires_grad_(True) + vec.requires_grad_(True) + pe.requires_grad_(True) + model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe) # unpack latents model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) @@ -331,6 +405,12 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", ) + parser.add_argument( + "--split_mode", + action="store_true", + help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required" + + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要", + ) # copy from Diffusers parser.add_argument( diff --git a/library/flux_models.py b/library/flux_models.py index 92c79bcca..3c7766b85 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -918,3 +918,168 @@ def forward( img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) return img + + +class FluxUpper(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, params: FluxParams): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + ) + + self.gradient_checkpointing = False + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + self.time_in.enable_gradient_checkpointing() + self.vector_in.enable_gradient_checkpointing() + self.guidance_in.enable_gradient_checkpointing() + + for block in self.double_blocks: + block.enable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing enabled.") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + self.time_in.disable_gradient_checkpointing() + self.vector_in.disable_gradient_checkpointing() + self.guidance_in.disable_gradient_checkpointing() + + for block in self.double_blocks: + block.disable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing disabled.") + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor | None = None, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + for block in self.double_blocks: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + + return img, txt, vec, pe + + +class FluxLower(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, params: FluxParams): + super().__init__() + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.out_channels = params.in_channels + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) + for _ in range(params.depth_single_blocks) + ] + ) + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + + self.gradient_checkpointing = False + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + for block in self.single_blocks: + block.enable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing enabled.") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + for block in self.single_blocks: + block.disable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing disabled.") + + def forward( + self, + img: Tensor, + txt: Tensor, + vec: Tensor | None = None, + pe: Tensor | None = None, + ) -> Tensor: + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe) + img = img[:, txt.shape[1] :, ...] + + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + return img diff --git a/networks/lora_flux.py b/networks/lora_flux.py index a4dab287a..4da33542f 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -252,6 +252,11 @@ def create_network( if module_dropout is not None: module_dropout = float(module_dropout) + # single or double blocks + train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double" + if train_blocks is not None: + assert train_blocks in ["all", "single", "double"], f"invalid train_blocks: {train_blocks}" + # すごく引数が多いな ( ^ω^)・・・ network = LoRANetwork( text_encoders, @@ -264,6 +269,7 @@ def create_network( module_dropout=module_dropout, conv_lora_dim=conv_dim, conv_alpha=conv_alpha, + train_blocks=train_blocks, varbose=True, ) @@ -314,9 +320,11 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh class LoRANetwork(torch.nn.Module): - FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"] + # FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"] + FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"] + FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] - LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible + LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1" LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2" @@ -335,6 +343,7 @@ def __init__( module_class: Type[object] = LoRAModule, modules_dim: Optional[Dict[str, int]] = None, modules_alpha: Optional[Dict[str, int]] = None, + train_blocks: Optional[str] = None, varbose: Optional[bool] = False, ) -> None: super().__init__() @@ -347,6 +356,7 @@ def __init__( self.dropout = dropout self.rank_dropout = rank_dropout self.module_dropout = module_dropout + self.train_blocks = train_blocks if train_blocks is not None else "all" self.loraplus_lr_ratio = None self.loraplus_unet_lr_ratio = None @@ -360,7 +370,9 @@ def __init__( f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" ) if self.conv_lora_dim is not None: - logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + logger.info( + f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + ) # create module instances def create_modules( @@ -434,9 +446,17 @@ def create_modules( skipped_te += skipped logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + # create LoRA for U-Net + if self.train_blocks == "all": + target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE + LoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE + elif self.train_blocks == "single": + target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE + elif self.train_blocks == "double": + target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE + self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] - self.unet_loras, skipped_un = create_modules(True, None, unet, LoRANetwork.FLUX_TARGET_REPLACE_MODULE) - logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + self.unet_loras, skipped_un = create_modules(True, None, unet, target_replace_modules) + logger.info(f"create LoRA for FLUX {self.train_blocks} blocks: {len(self.unet_loras)} modules.") skipped = skipped_te + skipped_un if varbose and len(skipped) > 0: From 9760d097b0bd7efbeb065d4320b2216a94e76efd Mon Sep 17 00:00:00 2001 From: DukeG Date: Wed, 14 Aug 2024 19:58:54 +0800 Subject: [PATCH 072/348] Fix AttributeError: 'T5EncoderModel' object has no attribute 'text_model' While loading T5 model in GPU. --- train_network.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/train_network.py b/train_network.py index 367203f54..405aa747c 100644 --- a/train_network.py +++ b/train_network.py @@ -540,9 +540,13 @@ def train(self, args): # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16 if t_enc.device.type != "cpu": t_enc.to(dtype=te_weight_dtype) - if hasattr(t_enc.text_model, "embeddings"): + if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"): # nn.Embedding not support FP8 - t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) + t_enc.text_model.embeddings.to( + dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) + elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"): + t_enc.encoder.embeddings.to( + dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good if args.deepspeed: From 7db422211907df3c50703b419655202276a53301 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 14 Aug 2024 22:15:26 +0900 Subject: [PATCH 073/348] add sample image generation during training --- README.md | 2 + flux_train_network.py | 67 +++++++- library/flux_train_utils.py | 297 ++++++++++++++++++++++++++++++++++++ train_network.py | 13 +- 4 files changed, 374 insertions(+), 5 deletions(-) create mode 100644 library/flux_train_utils.py diff --git a/README.md b/README.md index 64b018804..7dc954fbc 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,8 @@ This feature is experimental. The options and the training script may change in __Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__ +Aug 14, 2024: Sample image generation during training is now supported. Specify options such as `--sample_prompts` and `--sample_every_n_epochs`. It will be very slow when `--split_mode` is specified. + Aug 13, 2024: __Experimental__ A network argument `train_blocks` is added to `lora_flux`. This is to select the target blocks of LoRA from FLUX double blocks and single blocks. Specify like `--network_args "train_blocks=single"`. `all` trains both double blocks and single blocks, `double` trains only double blocks, and `single` trains only single blocks. The default (omission) is `all`. diff --git a/flux_train_network.py b/flux_train_network.py index 1d1f00d84..b8ea56223 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -10,7 +10,7 @@ init_ipex() -from library import flux_models, flux_utils, sd3_train_utils, sd3_utils, sdxl_model_util, sdxl_train_util, strategy_flux, train_util +from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util import train_network from library.utils import setup_logging @@ -28,6 +28,12 @@ def assert_extra_args(self, args, train_dataset_group): super().assert_extra_args(args, train_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + if args.cache_text_encoder_outputs: assert ( train_dataset_group.is_text_encoder_output_cacheable() @@ -139,8 +145,31 @@ def cache_text_encoder_outputs_if_needed( text_encoders[1].to(accelerator.device, dtype=weight_dtype) with accelerator.autocast(): dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process) + + # cache sample prompts + self.sample_prompts_te_outputs = None + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = sd3_train_utils.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask + ) + self.sample_prompts_te_outputs = sample_prompts_te_outputs + accelerator.wait_for_everyone() + # move back to cpu logger.info("move text encoders back to cpu") text_encoders[0].to("cpu") # , dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU text_encoders[1].to("cpu") # , dtype=torch.float32) @@ -172,9 +201,36 @@ def cache_text_encoder_outputs_if_needed( # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) # return noise_pred - def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): - # logger.warning("Sampling images is not supported for Flux model") - pass + def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux): + if not args.split_mode: + flux_train_utils.sample_images( + accelerator, args, epoch, global_step, flux, ae, text_encoder, self.sample_prompts_te_outputs + ) + return + + class FluxUpperLowerWrapper(torch.nn.Module): + def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device): + super().__init__() + self.flux_upper = flux_upper + self.flux_lower = flux_lower + self.target_device = device + + def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None): + self.flux_lower.to("cpu") + clean_memory_on_device(self.target_device) + self.flux_upper.to(self.target_device) + img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance) + self.flux_upper.to("cpu") + clean_memory_on_device(self.target_device) + self.flux_lower.to(self.target_device) + return self.flux_lower(img, txt, vec, pe) + + wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device) + clean_memory_on_device(accelerator.device) + flux_train_utils.sample_images( + accelerator, args, epoch, global_step, wrapper, ae, text_encoder, self.sample_prompts_te_outputs + ) + clean_memory_on_device(accelerator.device) def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) @@ -389,6 +445,9 @@ def update_metadata(self, metadata, args): metadata["ss_model_prediction_type"] = args.model_prediction_type metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift + def is_text_encoder_not_needed_for_training(self, args): + return args.cache_text_encoder_outputs + def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py new file mode 100644 index 000000000..91f522389 --- /dev/null +++ b/library/flux_train_utils.py @@ -0,0 +1,297 @@ +import argparse +import math +import os +import numpy as np +import toml +import json +import time +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +from accelerate import Accelerator, PartialState +from transformers import CLIPTextModel +from tqdm import tqdm +from PIL import Image + +from library import flux_models, flux_utils, strategy_base +from library.sd3_train_utils import load_prompts +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def sample_images( + accelerator: Accelerator, + args: argparse.Namespace, + epoch, + steps, + flux, + ae, + text_encoders, + sample_prompts_te_outputs, + prompt_replacement=None, +): + if steps == 0: + if not args.sample_at_first: + return + else: + if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: + return + if args.sample_every_n_epochs is not None: + # sample_every_n_steps は無視する + if epoch is None or epoch % args.sample_every_n_epochs != 0: + return + else: + if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + return + + logger.info("") + logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") + if not os.path.isfile(args.sample_prompts): + logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") + return + + distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here + + # unwrap unet and text_encoder(s) + flux = accelerator.unwrap_model(flux) + text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] + # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) + + prompts = load_prompts(args.sample_prompts) + + save_dir = args.output_dir + "/sample" + os.makedirs(save_dir, exist_ok=True) + + # save random state to restore later + rng_state = torch.get_rng_state() + cuda_rng_state = None + try: + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + except Exception: + pass + + if distributed_state.num_processes <= 1: + # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. + with torch.no_grad(): + for prompt_dict in prompts: + sample_image_inference( + accelerator, + args, + flux, + text_encoders, + ae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + ) + else: + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + per_process_prompts = [] # list of lists + for i in range(distributed_state.num_processes): + per_process_prompts.append(prompts[i :: distributed_state.num_processes]) + + with torch.no_grad(): + with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: + for prompt_dict in prompt_dict_lists[0]: + sample_image_inference( + accelerator, + args, + flux, + text_encoders, + ae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + ) + + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + + clean_memory_on_device(accelerator.device) + + +def sample_image_inference( + accelerator: Accelerator, + args: argparse.Namespace, + flux: flux_models.Flux, + text_encoders: List[CLIPTextModel], + ae: flux_models.AutoEncoder, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, +): + assert isinstance(prompt_dict, dict) + # negative_prompt = prompt_dict.get("negative_prompt") + sample_steps = prompt_dict.get("sample_steps", 20) + width = prompt_dict.get("width", 512) + height = prompt_dict.get("height", 512) + scale = prompt_dict.get("scale", 3.5) + seed = prompt_dict.get("seed") + # controlnet_image = prompt_dict.get("controlnet_image") + prompt: str = prompt_dict.get("prompt", "") + # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + # if negative_prompt is not None: + # negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + else: + # True random sample image generation + torch.seed() + torch.cuda.seed() + + # if negative_prompt is None: + # negative_prompt = "" + + height = max(64, height - height % 16) # round to divisible by 16 + width = max(64, width - width % 16) # round to divisible by 16 + logger.info(f"prompt: {prompt}") + # logger.info(f"negative_prompt: {negative_prompt}") + logger.info(f"height: {height}") + logger.info(f"width: {width}") + logger.info(f"sample_steps: {sample_steps}") + logger.info(f"scale: {scale}") + # logger.info(f"sample_sampler: {sampler_name}") + if seed is not None: + logger.info(f"seed: {seed}") + + # encode prompts + tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() + encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + + if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs: + te_outputs = sample_prompts_te_outputs[prompt] + else: + tokens_and_masks = tokenize_strategy.tokenize(prompt) + te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) + + l_pooled, t5_out, txt_ids = te_outputs + + # sample image + weight_dtype = ae.dtype # TOFO give dtype as argument + packed_latent_height = height // 16 + packed_latent_width = width // 16 + noise = torch.randn( + 1, + packed_latent_height * packed_latent_width, + 16 * 2 * 2, + device=accelerator.device, + dtype=weight_dtype, + generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None, + ) + timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True + img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype) + + with accelerator.autocast(), torch.no_grad(): + x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale) + + x = x.float() + x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) + + # latent to image + clean_memory_on_device(accelerator.device) + org_vae_device = ae.device # will be on cpu + ae.to(accelerator.device) # distributed_state.device is same as accelerator.device + with accelerator.autocast(), torch.no_grad(): + x = ae.decode(x) + ae.to(org_vae_device) + clean_memory_on_device(accelerator.device) + + x = x.clamp(-1, 1) + x = x.permute(0, 2, 3, 1) + image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0]) + + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + # but adding 'enum' to the filename should be enough + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + i: int = prompt_dict["enum"] + img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" + image.save(os.path.join(save_dir, img_filename)) + + # wandb有効時のみログを送信 + try: + wandb_tracker = accelerator.get_tracker("wandb") + try: + import wandb + except ImportError: # 事前に一度確認するのでここはエラー出ないはず + raise ImportError("No wandb / wandb がインストールされていないようです") + + wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) + except: # wandb 無効時 + pass + + +def time_shift(mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # eastimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +def denoise( + model: flux_models.Flux, + img: torch.Tensor, + img_ids: torch.Tensor, + txt: torch.Tensor, + txt_ids: torch.Tensor, + vec: torch.Tensor, + timesteps: list[float], + guidance: float = 4.0, +): + # this is ignored for schnell + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + pred = model(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec, guidance=guidance_vec) + + img = img + (t_prev - t_curr) * pred + + return img diff --git a/train_network.py b/train_network.py index 367203f54..53d71b57d 100644 --- a/train_network.py +++ b/train_network.py @@ -232,6 +232,9 @@ def get_sai_model_spec(self, args): def update_metadata(self, metadata, args): pass + def is_text_encoder_not_needed_for_training(self, args): + return False # use for sample images + # endregion def train(self, args): @@ -529,7 +532,7 @@ def train(self, args): # unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM # unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory - + unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above unet.requires_grad_(False) @@ -989,6 +992,14 @@ def remove_model(old_ckpt_name): accelerator.print(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) + # if text_encoder is not needed for training, delete it to save memory. + # TODO this can be automated after SDXL sample prompt cache is implemented + if self.is_text_encoder_not_needed_for_training(args): + logger.info("text_encoder is not needed for training. deleting to save memory.") + for t_enc in text_encoders: + del t_enc + text_encoders = [] + # For --sample_at_first self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) From 8aaa1967bd3d3a9b4b44e97e5432d23f2101cf51 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 15 Aug 2024 22:07:23 +0900 Subject: [PATCH 074/348] fix encoding latents closes #1456 --- flux_train_network.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index b8ea56223..daa65c857 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -238,8 +238,8 @@ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> return noise_scheduler def encode_images_to_latents(self, args, accelerator, vae, images): - return vae.encode(images).latent_dist.sample() - + return vae.encode(images) + def shift_scale_latents(self, args, latents): return latents From 35b6cb0cd1b319d5f34b44a8c24c81c42895fa2e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 15 Aug 2024 22:07:35 +0900 Subject: [PATCH 075/348] update for torchvision --- README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 7dc954fbc..bdb6bf2ed 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,10 @@ This repository contains training, generation and utility scripts for Stable Dif This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training. -__Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__ +__Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchvision==0.19.0` with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__ + +The command to install PyTorch is as follows: +`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` Aug 14, 2024: Sample image generation during training is now supported. Specify options such as `--sample_prompts` and `--sample_every_n_epochs`. It will be very slow when `--split_mode` is specified. From 08ef886bfeb058aa6d6f7e0a19589c0fd80b3757 Mon Sep 17 00:00:00 2001 From: DukeG Date: Fri, 16 Aug 2024 11:00:08 +0800 Subject: [PATCH 076/348] Fix AttributeError: 'FluxNetworkTrainer' object has no attribute 'sample_prompts_te_outputs' Move "self.sample_prompts_te_outputs = None" from Line 150 to Line 26. --- flux_train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flux_train_network.py b/flux_train_network.py index daa65c857..59b9d84b5 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -23,6 +23,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): def __init__(self): super().__init__() + self.sample_prompts_te_outputs = None def assert_extra_args(self, args, train_dataset_group): super().assert_extra_args(args, train_dataset_group) @@ -147,7 +148,6 @@ def cache_text_encoder_outputs_if_needed( dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process) # cache sample prompts - self.sample_prompts_te_outputs = None if args.sample_prompts is not None: logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") From 3921a4efda1cd1d7d873177ea7f51b77c3f15d3d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 16 Aug 2024 17:06:05 +0900 Subject: [PATCH 077/348] add t5xxl max token length, support schnell --- README.md | 8 ++++++++ flux_train_network.py | 32 ++++++++++++++++++++++++++++---- library/flux_models.py | 12 ++++++++---- 3 files changed, 44 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index bdb6bf2ed..6fb050dff 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,14 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 16, 2024: + +FLUX.1 schnell model based training is now supported (but not tested). If the name of the model file contains `schnell`, the model is treated as a schnell model. + +Added `--t5xxl_max_token_length` option to specify the maximum token length of T5XXL. The default is 512 in dev and 256 in schnell. + +Previously, when `--max_token_length` was specified, that value was used, and 512 was used when omitted (default). Therefore, there is no impact if `--max_token_length` was not specified. If `--max_token_length` was specified, please specify `--t5xxl_max_token_length` instead. `--max_token_length` is ignored during FLUX.1 training. + Aug 14, 2024: Sample image generation during training is now supported. Specify options such as `--sample_prompts` and `--sample_every_n_epochs`. It will be very slow when `--split_mode` is specified. Aug 13, 2024: diff --git a/flux_train_network.py b/flux_train_network.py index 59b9d84b5..b9a29c160 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -44,11 +44,18 @@ def assert_extra_args(self, args, train_dataset_group): args.network_train_unet_only or not args.cache_text_encoder_outputs ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" + if args.max_token_length is not None: + logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") + train_dataset_group.verify_bucket_reso_steps(32) # TODO check this + def get_flux_model_name(self, args): + return "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" + def load_target_model(self, args, weight_dtype, accelerator): # currently offload to cpu for some models - name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" # TODO change this to a more robust way + name = self.get_flux_model_name(args) + # if we load to cpu, flux.to(fp8) takes a long time model = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu") @@ -104,7 +111,18 @@ def prepare_split_model(self, model, weight_dtype, accelerator): return flux_lower def get_tokenize_strategy(self, args): - return strategy_flux.FluxTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + name = self.get_flux_model_name(args) + + if args.t5xxl_max_token_length is None: + if name == "schnell": + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}") + return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir) def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy): return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl] @@ -239,7 +257,7 @@ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> def encode_images_to_latents(self, args, accelerator, vae, images): return vae.encode(images) - + def shift_scale_latents(self, args, latents): return latents @@ -470,7 +488,13 @@ def setup_parser() -> argparse.ArgumentParser: help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required" + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要", ) - + parser.add_argument( + "--t5xxl_max_token_length", + type=int, + default=None, + help="maximum token length for T5-XXL. if omitted, 256 for schnell and 512 for dev" + " / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512", + ) # copy from Diffusers parser.add_argument( "--weighting_scheme", diff --git a/library/flux_models.py b/library/flux_models.py index 3c7766b85..ed0bc8c7d 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -863,7 +863,8 @@ def enable_gradient_checkpointing(self): self.time_in.enable_gradient_checkpointing() self.vector_in.enable_gradient_checkpointing() - self.guidance_in.enable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.enable_gradient_checkpointing() for block in self.double_blocks + self.single_blocks: block.enable_gradient_checkpointing() @@ -875,7 +876,8 @@ def disable_gradient_checkpointing(self): self.time_in.disable_gradient_checkpointing() self.vector_in.disable_gradient_checkpointing() - self.guidance_in.disable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.disable_gradient_checkpointing() for block in self.double_blocks + self.single_blocks: block.disable_gradient_checkpointing() @@ -972,7 +974,8 @@ def enable_gradient_checkpointing(self): self.time_in.enable_gradient_checkpointing() self.vector_in.enable_gradient_checkpointing() - self.guidance_in.enable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.enable_gradient_checkpointing() for block in self.double_blocks: block.enable_gradient_checkpointing() @@ -984,7 +987,8 @@ def disable_gradient_checkpointing(self): self.time_in.disable_gradient_checkpointing() self.vector_in.disable_gradient_checkpointing() - self.guidance_in.disable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.disable_gradient_checkpointing() for block in self.double_blocks: block.disable_gradient_checkpointing() From e45d3f8634c6dd4e358a8c7972f7c851f18f94d3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 16 Aug 2024 22:19:21 +0900 Subject: [PATCH 078/348] add merge LoRA script --- README.md | 24 +++ library/train_util.py | 2 +- networks/flux_merge_lora.py | 361 ++++++++++++++++++++++++++++++++++++ 3 files changed, 386 insertions(+), 1 deletion(-) create mode 100644 networks/flux_merge_lora.py diff --git a/README.md b/README.md index 6fb050dff..e231cc24e 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,8 @@ The command to install PyTorch is as follows: Aug 16, 2024: +Added a script `networks/flux_merge_lora.py` to merge LoRA into FLUX.1 checkpoint. See [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) for details. + FLUX.1 schnell model based training is now supported (but not tested). If the name of the model file contains `schnell`, the model is treated as a schnell model. Added `--t5xxl_max_token_length` option to specify the maximum token length of T5XXL. The default is 512 in dev and 256 in schnell. @@ -80,6 +82,28 @@ Aug 12: `--interactive` option is now working. python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0 ``` +### Merge LoRA to FLUX.1 checkpoint + +`networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint. __The script is experimental.__ + +``` +python networks/flux_merge_lora.py --flux_model flux1-dev.sft --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu +``` + +You can also merge multiple LoRA models into a FLUX.1 model. Specify multiple LoRA models in `--models`. Specify the same number of ratios in `--ratios`. + +`--loading_device` is the device to load the LoRA models. `--working_device` is the device to merge (calculate) the models. Default is `cpu` for both. Loading / working device examples are below (in the case of `--save_precision fp16` or `--save_precision bf16`): + +- 'cpu' / 'cpu': Uses >50GB of RAM, but works on any machine. +- 'cuda' / 'cpu': Uses 24GB of VRAM, but requires 30GB of RAM. +- 'cuda' / 'cuda': Uses 30GB of VRAM, but requires 30GB of RAM, faster than 'cuda' / 'cpu'. + +In the case of LoRA models are trained with `bf16`, we are not sure which is better, `fp16` or `bf16` for `--save_precision`. + +The script can merge multiple LoRA models. If you want to merge multiple LoRA models, specify `--concat` option to work the merged LoRA model properly. + +``` + ## SD3 training SD3 training is done with `sd3_train.py`. diff --git a/library/train_util.py b/library/train_util.py index 59ec3e56d..fa0eb9e51 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3160,7 +3160,7 @@ def load_metadata_from_safetensors(safetensors_file: str) -> dict: def build_minimum_network_metadata( - v2: Optional[bool], + v2: Optional[str], base_model: Optional[str], network_module: str, network_dim: str, diff --git a/networks/flux_merge_lora.py b/networks/flux_merge_lora.py new file mode 100644 index 000000000..c3986ef1f --- /dev/null +++ b/networks/flux_merge_lora.py @@ -0,0 +1,361 @@ +import math +import argparse +import os +import time +import torch +from safetensors import safe_open +from safetensors.torch import load_file, save_file +from tqdm import tqdm +from library import sai_model_spec, train_util +import networks.lora_flux as lora_flux +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def load_state_dict(file_name, dtype): + if os.path.splitext(file_name)[1] == ".safetensors": + sd = load_file(file_name) + metadata = train_util.load_metadata_from_safetensors(file_name) + else: + sd = torch.load(file_name, map_location="cpu") + metadata = {} + + for key in list(sd.keys()): + if type(sd[key]) == torch.Tensor: + sd[key] = sd[key].to(dtype) + + return sd, metadata + + +def save_to_file(file_name, state_dict, dtype, metadata): + if dtype is not None: + logger.info(f"converting to {dtype}...") + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + logger.info(f"saving to: {file_name}") + save_file(state_dict, file_name, metadata=metadata) + + +def merge_to_flux_model(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype): + # create module map without loading state_dict + logger.info(f"loading keys from FLUX.1 model: {flux_model}") + lora_name_to_module_key = {} + with safe_open(flux_model, framework="pt", device=loading_device) as flux_file: + keys = list(flux_file.keys()) + for key in keys: + if key.endswith(".weight"): + module_name = ".".join(key.split(".")[:-1]) + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_") + lora_name_to_module_key[lora_name] = key + + flux_state_dict = load_file(flux_model, device=loading_device) + for model, ratio in zip(models, ratios): + logger.info(f"loading: {model}") + lora_sd, _ = load_state_dict(model, merge_dtype) # loading on CPU + + logger.info(f"merging...") + for key in tqdm(lora_sd.keys()): + if "lora_down" in key: + lora_name = key[: key.rfind(".lora_down")] + up_key = key.replace("lora_down", "lora_up") + alpha_key = key[: key.index("lora_down")] + "alpha" + + if lora_name not in lora_name_to_module_key: + logger.warning(f"no module found for LoRA weight: {key}. LoRA for Text Encoder is not supported yet.") + continue + + down_weight = lora_sd[key] + up_weight = lora_sd[up_key] + + dim = down_weight.size()[0] + alpha = lora_sd.get(alpha_key, dim) + scale = alpha / dim + + # W <- W + U * D + module_weight_key = lora_name_to_module_key[lora_name] + if module_weight_key not in flux_state_dict: + weight = flux_file.get_tensor(module_weight_key) + else: + weight = flux_state_dict[module_weight_key] + + weight = weight.to(working_device, merge_dtype) + up_weight = up_weight.to(working_device, merge_dtype) + down_weight = down_weight.to(working_device, merge_dtype) + + # logger.info(module_name, down_weight.size(), up_weight.size()) + if len(weight.size()) == 2: + # linear + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + ratio + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + ratio * conved * scale + + flux_state_dict[module_weight_key] = weight.to(loading_device, save_dtype) + del up_weight + del down_weight + del weight + + return flux_state_dict + + +def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): + base_alphas = {} # alpha for merged model + base_dims = {} + + merged_sd = {} + base_model = None + for model, ratio in zip(models, ratios): + logger.info(f"loading: {model}") + lora_sd, lora_metadata = load_state_dict(model, merge_dtype) + + if lora_metadata is not None: + if base_model is None: + base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) + + # get alpha and dim + alphas = {} # alpha for current model + dims = {} # dims for current model + for key in lora_sd.keys(): + if "alpha" in key: + lora_module_name = key[: key.rfind(".alpha")] + alpha = float(lora_sd[key].detach().numpy()) + alphas[lora_module_name] = alpha + if lora_module_name not in base_alphas: + base_alphas[lora_module_name] = alpha + elif "lora_down" in key: + lora_module_name = key[: key.rfind(".lora_down")] + dim = lora_sd[key].size()[0] + dims[lora_module_name] = dim + if lora_module_name not in base_dims: + base_dims[lora_module_name] = dim + + for lora_module_name in dims.keys(): + if lora_module_name not in alphas: + alpha = dims[lora_module_name] + alphas[lora_module_name] = alpha + if lora_module_name not in base_alphas: + base_alphas[lora_module_name] = alpha + + logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") + + # merge + logger.info(f"merging...") + for key in tqdm(lora_sd.keys()): + if "alpha" in key: + continue + + if "lora_up" in key and concat: + concat_dim = 1 + elif "lora_down" in key and concat: + concat_dim = 0 + else: + concat_dim = None + + lora_module_name = key[: key.rfind(".lora_")] + + base_alpha = base_alphas[lora_module_name] + alpha = alphas[lora_module_name] + + scale = math.sqrt(alpha / base_alpha) * ratio + scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 + + if key in merged_sd: + assert ( + merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None + ), f"weights shape mismatch, different dims? / 重みのサイズが合いません。dimが異なる可能性があります。" + if concat_dim is not None: + merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim) + else: + merged_sd[key] = merged_sd[key] + lora_sd[key] * scale + else: + merged_sd[key] = lora_sd[key] * scale + + # set alpha to sd + for lora_module_name, alpha in base_alphas.items(): + key = lora_module_name + ".alpha" + merged_sd[key] = torch.tensor(alpha) + if shuffle: + key_down = lora_module_name + ".lora_down.weight" + key_up = lora_module_name + ".lora_up.weight" + dim = merged_sd[key_down].shape[0] + perm = torch.randperm(dim) + merged_sd[key_down] = merged_sd[key_down][perm] + merged_sd[key_up] = merged_sd[key_up][:, perm] + + logger.info("merged model") + logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") + + # check all dims are same + dims_list = list(set(base_dims.values())) + alphas_list = list(set(base_alphas.values())) + all_same_dims = True + all_same_alphas = True + for dims in dims_list: + if dims != dims_list[0]: + all_same_dims = False + break + for alphas in alphas_list: + if alphas != alphas_list[0]: + all_same_alphas = False + break + + # build minimum metadata + dims = f"{dims_list[0]}" if all_same_dims else "Dynamic" + alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic" + metadata = train_util.build_minimum_network_metadata(str(False), base_model, "networks.lora", dims, alphas, None) + + return merged_sd, metadata + + +def merge(args): + assert len(args.models) == len( + args.ratios + ), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + + def str_to_dtype(p): + if p == "float": + return torch.float + if p == "fp16": + return torch.float16 + if p == "bf16": + return torch.bfloat16 + return None + + merge_dtype = str_to_dtype(args.precision) + save_dtype = str_to_dtype(args.save_precision) + if save_dtype is None: + save_dtype = merge_dtype + + dest_dir = os.path.dirname(args.save_to) + if not os.path.exists(dest_dir): + logger.info(f"creating directory: {dest_dir}") + os.makedirs(dest_dir) + + if args.flux_model is not None: + state_dict = merge_to_flux_model( + args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype + ) + + if args.no_metadata: + sai_metadata = None + else: + merged_from = sai_model_spec.build_merged_from([args.flux_model] + args.models) + title = os.path.splitext(os.path.basename(args.save_to))[0] + sai_metadata = sai_model_spec.build_metadata( + None, False, False, False, False, False, time.time(), title=title, merged_from=merged_from, flux="dev" + ) + + logger.info(f"saving FLUX model to: {args.save_to}") + save_to_file(args.save_to, state_dict, save_dtype, sai_metadata) + + else: + state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) + + logger.info(f"calculating hashes and creating metadata...") + + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + if not args.no_metadata: + merged_from = sai_model_spec.build_merged_from(args.models) + title = os.path.splitext(os.path.basename(args.save_to))[0] + sai_metadata = sai_model_spec.build_metadata( + state_dict, False, False, False, True, False, time.time(), title=title, merged_from=merged_from, flux="dev" + ) + metadata.update(sai_metadata) + + logger.info(f"saving model to: {args.save_to}") + save_to_file(args.save_to, state_dict, save_dtype, metadata) + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ", + ) + parser.add_argument( + "--precision", + type=str, + default="float", + choices=["float", "fp16", "bf16"], + help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", + ) + parser.add_argument( + "--flux_model", + type=str, + default=None, + help="FLUX.1 model to load, merge LoRA models if omitted / 読み込むモデル、指定しない場合はLoRAモデルをマージする", + ) + parser.add_argument( + "--loading_device", + type=str, + default="cpu", + help="device to load FLUX.1 model. LoRA models are loaded on CPU / FLUX.1モデルを読み込むデバイス。LoRAモデルはCPUで読み込まれます", + ) + parser.add_argument( + "--working_device", + type=str, + default="cpu", + help="device to work (merge). Merging LoRA models are done on CPU." + + " / 作業(マージ)するデバイス。LoRAモデルのマージはCPUで行われます。", + ) + parser.add_argument( + "--save_to", + type=str, + default=None, + help="destination file name: safetensors file / 保存先のファイル名、safetensorsファイル", + ) + parser.add_argument( + "--models", + type=str, + nargs="*", + help="LoRA models to merge: safetensors file / マージするLoRAモデル、safetensorsファイル", + ) + parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") + parser.add_argument( + "--no_metadata", + action="store_true", + help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " + + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", + ) + parser.add_argument( + "--concat", + action="store_true", + help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / " + + "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)", + ) + parser.add_argument( + "--shuffle", + action="store_true", + help="shuffle lora weight./ " + "LoRAの重みをシャッフルする", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + merge(args) From 7367584e6749448cb9b012df0d3bcbe4f0531ea5 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 17 Aug 2024 14:38:34 +0900 Subject: [PATCH 079/348] fix sd3 training to work without cachine TE outputs #1465 --- sd3_train.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sd3_train.py b/sd3_train.py index 9c37cbce6..3b6c8a118 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -759,8 +759,9 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): input_ids_clip_l, input_ids_clip_g, _, l_attn_mask, g_attn_mask, _ = batch["input_ids_list"] with torch.set_grad_enabled(args.train_text_encoder): # TODO support weighted captions - input_ids_clip_l = input_ids_clip_l.to(accelerator.device) - input_ids_clip_g = input_ids_clip_g.to(accelerator.device) + # text models in sd3_models require "cpu" for input_ids + input_ids_clip_l = input_ids_clip_l.to("cpu") + input_ids_clip_g = input_ids_clip_g.to("cpu") lg_out, _, lg_pooled = text_encoding_strategy.encode_tokens( sd3_tokenize_strategy, [clip_l, clip_g, None], @@ -770,7 +771,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): if t5_out is None: _, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"] with torch.no_grad(): - input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device) if t5_out is None else None + input_ids_t5xxl = input_ids_t5xxl.to("cpu") if t5_out is None else None _, t5_out, _ = text_encoding_strategy.encode_tokens( sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask] ) From 400955d3ea4088e8da7a3917dec9b0664424e24a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 17 Aug 2024 15:36:18 +0900 Subject: [PATCH 080/348] add fine tuning FLUX.1 (WIP) --- flux_train.py | 729 ++++++++++++++++++++++++++++++++++++ flux_train_network.py | 168 +-------- library/flux_train_utils.py | 270 ++++++++++++- library/train_util.py | 2 +- 4 files changed, 1007 insertions(+), 162 deletions(-) create mode 100644 flux_train.py diff --git a/flux_train.py b/flux_train.py new file mode 100644 index 000000000..2ca20ded2 --- /dev/null +++ b/flux_train.py @@ -0,0 +1,729 @@ +# training with captions + +import argparse +import copy +import math +import os +from multiprocessing import Value +from typing import List +import toml + +from tqdm import tqdm + +import torch +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from accelerate.utils import set_seed +from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux +from library.sd3_train_utils import load_prompts, FlowMatchEulerDiscreteScheduler + +import library.train_util as train_util + +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import library.config_util as config_util + +# import library.sdxl_train_util as sdxl_train_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + # sdxl_train_util.verify_sdxl_training_args(args) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) + + # assert ( + # not args.weighted_captions + # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if args.cache_latents: + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + logger.info("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + logger.info("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認 + + if args.debug_dataset: + if args.cache_text_encoder_outputs: + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( + strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False + ) + ) + train_dataset_group.set_current_strategies() + train_util.debug_dataset(train_dataset_group, True) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" + + # load VAE for caching latents + ae = None + if cache_latents: + ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu") + ae.to(accelerator.device, dtype=weight_dtype) + ae.requires_grad_(False) + ae.eval() + + train_dataset_group.new_cache_latents(ae, accelerator.is_main_process) + + ae.to("cpu") # if no sampling, vae can be deleted + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # prepare tokenize strategy + if args.t5xxl_max_token_length is None: + if name == "schnell": + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + flux_tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length) + strategy_base.TokenizeStrategy.set_strategy(flux_tokenize_strategy) + + # load clip_l, t5xxl for caching text encoder outputs + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu") + t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu") + clip_l.eval() + t5xxl.eval() + clip_l.requires_grad_(False) + t5xxl.requires_grad_(False) + + text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + # cache text encoder outputs + sample_prompts_te_outputs = None + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad here + clip_l.to(accelerator.device) + t5xxl.to(accelerator.device) + + text_encoder_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False, args.apply_t5_attn_mask + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) + + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator.is_main_process) + + # cache sample prompt's embeddings to free text encoder's memory + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + + accelerator.wait_for_everyone() + + # now we can delete Text Encoders to free memory + clip_l = None + t5xxl = None + + # load FLUX + # if we load to cpu, flux.to(fp8) takes a long time + flux = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu") + + if args.gradient_checkpointing: + flux.enable_gradient_checkpointing() + + flux.requires_grad_(True) + + if not cache_latents: + # load VAE here if not cached + ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu") + ae.requires_grad_(False) + ae.eval() + ae.to(accelerator.device, dtype=weight_dtype) + + training_models = [] + params_to_optimize = [] + training_models.append(flux) + params_to_optimize.append({"params": list(flux.parameters()), "lr": args.learning_rate}) + + # calculate number of trainable parameters + n_params = 0 + for group in params_to_optimize: + for p in group["params"]: + n_params += p.numel() + + accelerator.print(f"number of trainable parameters: {n_params}") + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + if args.fused_optimizer_groups: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters. + # This balances memory usage and management complexity. + + # calculate total number of parameters + n_total_params = sum(len(params["params"]) for params in params_to_optimize) + params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups) + + # split params into groups, keeping the learning rate the same for all params in a group + # this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders) + grouped_params = [] + param_group = [] + param_group_lr = -1 + for group in params_to_optimize: + lr = group["lr"] + for p in group["params"]: + # if the learning rate is different for different params, start a new group + if lr != param_group_lr: + if param_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + param_group = [] + param_group_lr = lr + + param_group.append(p) + + # if the group has enough parameters, start a new group + if len(param_group) == params_per_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + param_group = [] + param_group_lr = -1 + + if param_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups") + + else: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + if args.fused_optimizer_groups: + # prepare lr schedulers for each optimizer + lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + flux.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + t5xxl.to(weight_dtype) # TODO check works with fp16 or not + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + flux.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + t5xxl.to(weight_dtype) + + # if we don't cache text encoder outputs, move them to device + if not args.cache_text_encoder_outputs: + clip_l.to(accelerator.device) + t5xxl.to(accelerator.device) + + clean_memory_on_device(accelerator.device) + + if args.deepspeed: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=flux) + # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] + + else: + # acceleratorがなんかよろしくやってくれるらしい + flux = accelerator.prepare(flux) + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. + # -> But we think it's ok to patch accelerator even if deepspeed is enabled. + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + parameter.register_post_accumulate_grad_hook(__grad_hook) + + elif args.fused_optimizer_groups: + # prepare for additional optimizers and lr schedulers + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + # counters are used to determine when to step the optimizer + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def optimizer_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(optimizer_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # accelerator.print( + # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + # ) + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + # For --sample_at_first + flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) + + loss_recorder = train_util.LossRecorder() + epoch = 0 # avoid error when max_train_steps is 0 + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for m in training_models: + m.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + + if args.fused_optimizer_groups: + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step + + with accelerator.accumulate(*training_models): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) + else: + with torch.no_grad(): + # encode images to latents. images are [-1, 1] + latents = ae.encode(batch["images"]) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list + else: + # not cached or training, so get from text encoders + tokens_and_masks = batch["input_ids_list"] + with torch.no_grad(): + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + text_encoder_conds = text_encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask + ) + if args.full_fp16: + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] + + # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, accelerator.device, weight_dtype + ) + + # pack latents and get img_ids + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + + # get guidance + guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device) + + # call model + l_pooled, t5_out, txt_ids = text_encoder_conds + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = flux( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + ) + + # unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + + # flow matching loss: this is different from SD3 + target = noise - latents + + # calculate loss + loss = train_util.conditional_loss( + model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None + ) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + loss = loss.mean() + + # backward + accelerator.backward(loss) + + if not (args.fused_backward_pass or args.fused_optimizer_groups): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.fused_optimizer_groups: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + flux_train_utils.sample_images( + accelerator, args, epoch, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs + ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(flux), + ) + + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if args.logging_dir is not None: + logs = {"loss": current_loss} + train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) + + accelerator.log(logs, step=global_step) + + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(flux), + ) + + flux_train_utils.sample_images( + accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs + ) + + is_main_process = accelerator.is_main_process + # if is_main_process: + flux = accelerator.unwrap_model(flux) + clip_l = accelerator.unwrap_model(clip_l) + clip_g = accelerator.unwrap_model(clip_g) + if t5xxl is not None: + t5xxl = accelerator.unwrap_model(t5xxl) + + accelerator.end_training() + + if args.save_state or args.save_state_on_train_end: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux, ae) + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) # TODO split this + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + add_custom_train_arguments(parser) # TODO remove this from here + flux_train_utils.add_flux_train_arguments(parser) + + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", + ) + parser.add_argument( + "--skip_latents_validity_check", + action="store_true", + help="skip latents validity check / latentsの正当性チェックをスキップする", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/flux_train_network.py b/flux_train_network.py index b9a29c160..002252c87 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -274,85 +274,14 @@ def get_noise_pred_and_target( weight_dtype, train_unet, ): - # copy from sd3_train.py and modified - - def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): - sigmas = self.noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) - schedule_timesteps = self.noise_scheduler_copy.timesteps.to(accelerator.device) - timesteps = timesteps.to(accelerator.device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < n_dim: - sigma = sigma.unsqueeze(-1) - return sigma - - def compute_density_for_timestep_sampling( - weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None - ): - """Compute the density for sampling the timesteps when doing SD3 training. - - Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. - - SD3 paper reference: https://arxiv.org/abs/2403.03206v1. - """ - if weighting_scheme == "logit_normal": - # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). - u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") - u = torch.nn.functional.sigmoid(u) - elif weighting_scheme == "mode": - u = torch.rand(size=(batch_size,), device="cpu") - u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) - else: - u = torch.rand(size=(batch_size,), device="cpu") - return u - - def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): - """Computes loss weighting scheme for SD3 training. - - Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. - - SD3 paper reference: https://arxiv.org/abs/2403.03206v1. - """ - if weighting_scheme == "sigma_sqrt": - weighting = (sigmas**-2.0).float() - elif weighting_scheme == "cosmap": - bot = 1 - 2 * sigmas + 2 * sigmas**2 - weighting = 2 / (math.pi * bot) - else: - weighting = torch.ones_like(sigmas) - return weighting - # Sample noise that we'll add to the latents noise = torch.randn_like(latents) bsz = latents.shape[0] - if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": - # Simple random t-based noise sampling - if args.timestep_sampling == "sigmoid": - # https://github.com/XLabs-AI/x-flux/tree/main - t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=accelerator.device)) - else: - t = torch.rand((bsz,), device=accelerator.device) - timesteps = t * 1000.0 - t = t.view(-1, 1, 1, 1) - noisy_model_input = (1 - t) * latents + t * noise - else: - # Sample a random timestep for each image - # for weighting schemes where we sample timesteps non-uniformly - u = compute_density_for_timestep_sampling( - weighting_scheme=args.weighting_scheme, - batch_size=bsz, - logit_mean=args.logit_mean, - logit_std=args.logit_std, - mode_scale=args.mode_scale, - ) - indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long() - timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=accelerator.device) - - # Add noise according to flow matching. - sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype) - noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, accelerator.device, weight_dtype + ) # pack latents and get img_ids packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 @@ -425,20 +354,8 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): # unpack latents model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) - if args.model_prediction_type == "raw": - # use model_pred as is - weighting = None - elif args.model_prediction_type == "additive": - # add the model_pred to the noisy_model_input - model_pred = model_pred + noisy_model_input - weighting = None - elif args.model_prediction_type == "sigma_scaled": - # apply sigma scaling - model_pred = model_pred * (-sigmas) + noisy_model_input - - # these weighting schemes use a uniform timestep sampling - # and instead post-weight the loss - weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) # flow matching loss: this is different from SD3 target = noise - latents @@ -469,83 +386,14 @@ def is_text_encoder_not_needed_for_training(self, args): def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() - # sdxl_train_util.add_sdxl_training_arguments(parser) - parser.add_argument("--clip_l", type=str, help="path to clip_l") - parser.add_argument("--t5xxl", type=str, help="path to t5xxl") - parser.add_argument("--ae", type=str, help="path to ae") - parser.add_argument("--apply_t5_attn_mask", action="store_true") - parser.add_argument( - "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" - ) - parser.add_argument( - "--cache_text_encoder_outputs_to_disk", - action="store_true", - help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", - ) + flux_train_utils.add_flux_train_arguments(parser) + parser.add_argument( "--split_mode", action="store_true", help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required" + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要", ) - parser.add_argument( - "--t5xxl_max_token_length", - type=int, - default=None, - help="maximum token length for T5-XXL. if omitted, 256 for schnell and 512 for dev" - " / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512", - ) - # copy from Diffusers - parser.add_argument( - "--weighting_scheme", - type=str, - default="none", - choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], - ) - parser.add_argument( - "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." - ) - parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.") - parser.add_argument( - "--mode_scale", - type=float, - default=1.29, - help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", - ) - parser.add_argument( - "--guidance_scale", - type=float, - default=3.5, - help="the FLUX.1 dev variant is a guidance distilled model", - ) - - parser.add_argument( - "--timestep_sampling", - choices=["sigma", "uniform", "sigmoid"], - default="sigma", - help="Method to sample timesteps: sigma-based, uniform random, or sigmoid of random normal. / タイムステップをサンプリングする方法:sigma、random uniform、またはrandom normalのsigmoid。", - ) - parser.add_argument( - "--sigmoid_scale", - type=float, - default=1.0, - help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。', - ) - parser.add_argument( - "--model_prediction_type", - choices=["raw", "additive", "sigma_scaled"], - default="sigma_scaled", - help="How to interpret and process the model prediction: " - "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)." - " / モデル予測の解釈と処理方法:" - "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。", - ) - parser.add_argument( - "--discrete_flow_shift", - type=float, - default=3.0, - help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。", - ) return parser diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 91f522389..167d61c7e 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -12,8 +12,9 @@ from transformers import CLIPTextModel from tqdm import tqdm from PIL import Image +from safetensors.torch import save_file -from library import flux_models, flux_utils, strategy_base +from library import flux_models, flux_utils, strategy_base, train_util from library.sd3_train_utils import load_prompts from library.device_utils import init_ipex, clean_memory_on_device @@ -27,6 +28,9 @@ logger = logging.getLogger(__name__) +# region sample images + + def sample_images( accelerator: Accelerator, args: argparse.Namespace, @@ -295,3 +299,267 @@ def denoise( img = img + (t_prev - t_curr) * pred return img + + +# endregion + + +# region train +def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(device) + timesteps = timesteps.to(device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None +): + """Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + +def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): + """Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + elif weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) + return weighting + + +def get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, device, dtype +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + bsz = latents.shape[0] + sigmas = None + + if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": + # Simple random t-based noise sampling + if args.timestep_sampling == "sigmoid": + # https://github.com/XLabs-AI/x-flux/tree/main + t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) + else: + t = torch.rand((bsz,), device=device) + timesteps = t * 1000.0 + t = t.view(-1, 1, 1, 1) + noisy_model_input = (1 - t) * latents + t * noise + else: + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler.config.num_train_timesteps).long() + timesteps = noise_scheduler.timesteps[indices].to(device=device) + + # Add noise according to flow matching. + sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + + return noisy_model_input, timesteps, sigmas + + +def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas): + weighting = None + if args.model_prediction_type == "raw": + pass + elif args.model_prediction_type == "additive": + # add the model_pred to the noisy_model_input + model_pred = model_pred + noisy_model_input + elif args.model_prediction_type == "sigma_scaled": + # apply sigma scaling + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + return model_pred, weighting + + +def save_models(ckpt_path: str, flux: flux_models.Flux, sai_metadata: Optional[dict], save_dtype: Optional[torch.dtype] = None): + state_dict = {} + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + if save_dtype is not None: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + update_sd("", flux.state_dict()) + + save_file(state_dict, ckpt_path, metadata=sai_metadata) + + +def save_flux_model_on_train_end( + args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, flux: flux_models.Flux +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev") + save_models(ckpt_file, flux, sai_metadata, save_dtype) + + train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None) + + +# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している +# on_epoch_end: Trueならepoch終了時、Falseならstep経過時 +def save_flux_model_on_epoch_end_or_stepwise( + args: argparse.Namespace, + on_epoch_end: bool, + accelerator, + save_dtype: torch.dtype, + epoch: int, + num_train_epochs: int, + global_step: int, + flux: flux_models.Flux, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev") + save_models(ckpt_file, flux, sai_metadata, save_dtype) + + train_util.save_sd_model_on_epoch_end_or_stepwise_common( + args, + on_epoch_end, + accelerator, + True, + True, + epoch, + num_train_epochs, + global_step, + sd_saver, + None, + ) + + +# endregion + + +def add_flux_train_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--clip_l", + type=str, + help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス(*.sftまたは*.safetensors)、float16が前提", + ) + parser.add_argument( + "--t5xxl", + type=str, + help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)、float16が前提", + ) + parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)") + parser.add_argument( + "--t5xxl_max_token_length", + type=int, + default=None, + help="maximum token length for T5-XXL. if omitted, 256 for schnell and 512 for dev" + " / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512", + ) + parser.add_argument( + "--apply_t5_attn_mask", + action="store_true", + help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する", + ) + parser.add_argument( + "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" + ) + parser.add_argument( + "--cache_text_encoder_outputs_to_disk", + action="store_true", + help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", + ) + parser.add_argument( + "--text_encoder_batch_size", + type=int, + default=None, + help="text encoder batch size (default: None, use dataset's batch size)" + + " / text encoderのバッチサイズ(デフォルト: None, データセットのバッチサイズを使用)", + ) + parser.add_argument( + "--disable_mmap_load_safetensors", + action="store_true", + help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる", + ) + + # copy from Diffusers + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.") + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the FLUX.1 dev variant is a guidance distilled model", + ) + + parser.add_argument( + "--timestep_sampling", + choices=["sigma", "uniform", "sigmoid"], + default="sigma", + help="Method to sample timesteps: sigma-based, uniform random, or sigmoid of random normal. / タイムステップをサンプリングする方法:sigma、random uniform、またはrandom normalのsigmoid。", + ) + parser.add_argument( + "--sigmoid_scale", + type=float, + default=1.0, + help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。', + ) + parser.add_argument( + "--model_prediction_type", + choices=["raw", "additive", "sigma_scaled"], + default="sigma_scaled", + help="How to interpret and process the model prediction: " + "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)." + " / モデル予測の解釈と処理方法:" + "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。", + ) + parser.add_argument( + "--discrete_flow_shift", + type=float, + default=3.0, + help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。", + ) diff --git a/library/train_util.py b/library/train_util.py index fa0eb9e51..f4ac8740a 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2629,7 +2629,7 @@ def __getitem__(self, idx): raise NotImplementedError -def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset: +def load_arbitrary_dataset(args, tokenizer=None) -> MinimalDataset: module = ".".join(args.dataset_class.split(".")[:-1]) dataset_class = args.dataset_class.split(".")[-1] module = importlib.import_module(module) From 25f77f6ef04ee760506338e7e7f9835c28657c59 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sat, 17 Aug 2024 15:54:32 +0900 Subject: [PATCH 081/348] fix flux fine tuning to work --- README.md | 4 ++++ flux_train.py | 6 ++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index e231cc24e..2b7b110f3 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,10 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` + +Aug 17. 2024: +Added a script `flux_train.py` to train FLUX.1. The script is experimental and not an optimized version. It needs >28GB VRAM for training. + Aug 16, 2024: Added a script `networks/flux_merge_lora.py` to merge LoRA into FLUX.1 checkpoint. See [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) for details. diff --git a/flux_train.py b/flux_train.py index 2ca20ded2..d2a9b3f32 100644 --- a/flux_train.py +++ b/flux_train.py @@ -674,9 +674,7 @@ def optimizer_hook(parameter: torch.Tensor): # if is_main_process: flux = accelerator.unwrap_model(flux) clip_l = accelerator.unwrap_model(clip_l) - clip_g = accelerator.unwrap_model(clip_g) - if t5xxl is not None: - t5xxl = accelerator.unwrap_model(t5xxl) + t5xxl = accelerator.unwrap_model(t5xxl) accelerator.end_training() @@ -686,7 +684,7 @@ def optimizer_hook(parameter: torch.Tensor): del accelerator # この後メモリを使うのでこれは消す if is_main_process: - flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux, ae) + flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux) logger.info("model saved.") From 7e688913aef4c852f54a703c9f91d135b17dff87 Mon Sep 17 00:00:00 2001 From: exveria1015 Date: Sun, 18 Aug 2024 12:38:05 +0900 Subject: [PATCH 082/348] =?UTF-8?q?fix:=20Flux=20=E3=81=AE=20LoRA=20?= =?UTF-8?q?=E3=83=9E=E3=83=BC=E3=82=B8=E6=A9=9F=E8=83=BD=E3=82=92=E4=BF=AE?= =?UTF-8?q?=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- networks/flux_merge_lora.py | 364 +++++++++++++++++++++++++++++------- 1 file changed, 297 insertions(+), 67 deletions(-) diff --git a/networks/flux_merge_lora.py b/networks/flux_merge_lora.py index c3986ef1f..df0ba606a 100644 --- a/networks/flux_merge_lora.py +++ b/networks/flux_merge_lora.py @@ -1,13 +1,14 @@ -import math import argparse +import math import os import time + import torch -from safetensors import safe_open from safetensors.torch import load_file, save_file from tqdm import tqdm + +import lora_flux as lora_flux from library import sai_model_spec, train_util -import networks.lora_flux as lora_flux from library.utils import setup_logging setup_logging() @@ -42,34 +43,181 @@ def save_to_file(file_name, state_dict, dtype, metadata): save_file(state_dict, file_name, metadata=metadata) -def merge_to_flux_model(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype): - # create module map without loading state_dict +def merge_to_flux_model( + loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype +): logger.info(f"loading keys from FLUX.1 model: {flux_model}") - lora_name_to_module_key = {} - with safe_open(flux_model, framework="pt", device=loading_device) as flux_file: - keys = list(flux_file.keys()) - for key in keys: - if key.endswith(".weight"): - module_name = ".".join(key.split(".")[:-1]) - lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_") - lora_name_to_module_key[lora_name] = key - flux_state_dict = load_file(flux_model, device=loading_device) + + def create_key_map(n_double_layers, n_single_layers, hidden_size): + key_map = {} + for index in range(n_double_layers): + prefix_from = f"transformer_blocks.{index}" + prefix_to = f"double_blocks.{index}" + + for end in ("weight", "bias"): + k = f"{prefix_from}.attn." + qkv_img = f"{prefix_to}.img_attn.qkv.{end}" + qkv_txt = f"{prefix_to}.txt_attn.qkv.{end}" + + key_map[f"{k}to_q.{end}"] = (qkv_img, (0, 0, hidden_size)) + key_map[f"{k}to_k.{end}"] = (qkv_img, (0, hidden_size, hidden_size)) + key_map[f"{k}to_v.{end}"] = (qkv_img, (0, hidden_size * 2, hidden_size)) + key_map[f"{k}add_q_proj.{end}"] = (qkv_txt, (0, 0, hidden_size)) + key_map[f"{k}add_k_proj.{end}"] = ( + qkv_txt, + (0, hidden_size, hidden_size), + ) + key_map[f"{k}add_v_proj.{end}"] = ( + qkv_txt, + (0, hidden_size * 2, hidden_size), + ) + + block_map = { + "attn.to_out.0.weight": "img_attn.proj.weight", + "attn.to_out.0.bias": "img_attn.proj.bias", + "norm1.linear.weight": "img_mod.lin.weight", + "norm1.linear.bias": "img_mod.lin.bias", + "norm1_context.linear.weight": "txt_mod.lin.weight", + "norm1_context.linear.bias": "txt_mod.lin.bias", + "attn.to_add_out.weight": "txt_attn.proj.weight", + "attn.to_add_out.bias": "txt_attn.proj.bias", + "ff.net.0.proj.weight": "img_mlp.0.weight", + "ff.net.0.proj.bias": "img_mlp.0.bias", + "ff.net.2.weight": "img_mlp.2.weight", + "ff.net.2.bias": "img_mlp.2.bias", + "ff_context.net.0.proj.weight": "txt_mlp.0.weight", + "ff_context.net.0.proj.bias": "txt_mlp.0.bias", + "ff_context.net.2.weight": "txt_mlp.2.weight", + "ff_context.net.2.bias": "txt_mlp.2.bias", + "attn.norm_q.weight": "img_attn.norm.query_norm.scale", + "attn.norm_k.weight": "img_attn.norm.key_norm.scale", + "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale", + "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale", + } + + for k, v in block_map.items(): + key_map[f"{prefix_from}.{k}"] = f"{prefix_to}.{v}" + + for index in range(n_single_layers): + prefix_from = f"single_transformer_blocks.{index}" + prefix_to = f"single_blocks.{index}" + + for end in ("weight", "bias"): + k = f"{prefix_from}.attn." + qkv = f"{prefix_to}.linear1.{end}" + key_map[f"{k}to_q.{end}"] = (qkv, (0, 0, hidden_size)) + key_map[f"{k}to_k.{end}"] = (qkv, (0, hidden_size, hidden_size)) + key_map[f"{k}to_v.{end}"] = (qkv, (0, hidden_size * 2, hidden_size)) + key_map[f"{prefix_from}.proj_mlp.{end}"] = ( + qkv, + (0, hidden_size * 3, hidden_size * 4), + ) + + block_map = { + "norm.linear.weight": "modulation.lin.weight", + "norm.linear.bias": "modulation.lin.bias", + "proj_out.weight": "linear2.weight", + "proj_out.bias": "linear2.bias", + "attn.norm_q.weight": "norm.query_norm.scale", + "attn.norm_k.weight": "norm.key_norm.scale", + } + + for k, v in block_map.items(): + key_map[f"{prefix_from}.{k}"] = f"{prefix_to}.{v}" + + return key_map + + key_map = create_key_map( + 18, 1, 2048 + ) # Assuming 18 double layers, 1 single layer, and hidden size of 2048 + + def find_matching_key(flux_dict, lora_key): + lora_key = lora_key.replace("diffusion_model.", "") + lora_key = lora_key.replace("transformer.", "") + lora_key = lora_key.replace("lora_A", "lora_down").replace("lora_B", "lora_up") + lora_key = lora_key.replace("single_transformer_blocks", "single_blocks") + lora_key = lora_key.replace("transformer_blocks", "double_blocks") + + double_block_map = { + "attn.to_out.0": "img_attn.proj", + "norm1.linear": "img_mod.lin", + "norm1_context.linear": "txt_mod.lin", + "attn.to_add_out": "txt_attn.proj", + "ff.net.0.proj": "img_mlp.0", + "ff.net.2": "img_mlp.2", + "ff_context.net.0.proj": "txt_mlp.0", + "ff_context.net.2": "txt_mlp.2", + "attn.norm_q": "img_attn.norm.query_norm", + "attn.norm_k": "img_attn.norm.key_norm", + "attn.norm_added_q": "txt_attn.norm.query_norm", + "attn.norm_added_k": "txt_attn.norm.key_norm", + "attn.to_q": "img_attn.qkv", + "attn.to_k": "img_attn.qkv", + "attn.to_v": "img_attn.qkv", + "attn.add_q_proj": "txt_attn.qkv", + "attn.add_k_proj": "txt_attn.qkv", + "attn.add_v_proj": "txt_attn.qkv", + } + + single_block_map = { + "norm.linear": "modulation.lin", + "proj_out": "linear2", + "attn.norm_q": "norm.query_norm", + "attn.norm_k": "norm.key_norm", + "attn.to_q": "linear1", + "attn.to_k": "linear1", + "attn.to_v": "linear1", + } + + for old, new in double_block_map.items(): + lora_key = lora_key.replace(old, new) + + for old, new in single_block_map.items(): + lora_key = lora_key.replace(old, new) + + if lora_key in key_map: + flux_key = key_map[lora_key] + if isinstance(flux_key, tuple): + flux_key = flux_key[0] + logger.info(f"Found matching key: {flux_key}") + return flux_key + + # If not found in key_map, try partial matching + potential_key = lora_key + ".weight" + logger.info(f"Searching for key: {potential_key}") + matches = [k for k in flux_dict.keys() if potential_key in k] + if matches: + logger.info(f"Found matching key: {matches[0]}") + return matches[0] + return None + + merged_keys = set() for model, ratio in zip(models, ratios): logger.info(f"loading: {model}") - lora_sd, _ = load_state_dict(model, merge_dtype) # loading on CPU + lora_sd, _ = load_state_dict(model, merge_dtype) - logger.info(f"merging...") + logger.info("merging...") for key in tqdm(lora_sd.keys()): - if "lora_down" in key: - lora_name = key[: key.rfind(".lora_down")] - up_key = key.replace("lora_down", "lora_up") - alpha_key = key[: key.index("lora_down")] + "alpha" - - if lora_name not in lora_name_to_module_key: - logger.warning(f"no module found for LoRA weight: {key}. LoRA for Text Encoder is not supported yet.") + if "lora_down" in key or "lora_A" in key: + lora_name = key[ + : key.rfind(".lora_down" if "lora_down" in key else ".lora_A") + ] + up_key = key.replace("lora_down", "lora_up").replace("lora_A", "lora_B") + alpha_key = ( + key[: key.index("lora_down" if "lora_down" in key else "lora_A")] + + "alpha" + ) + + logger.info(f"Processing LoRA key: {lora_name}") + flux_key = find_matching_key(flux_state_dict, lora_name) + + if flux_key is None: + logger.warning(f"no module found for LoRA weight: {key}") continue + logger.info(f"Merging LoRA key {lora_name} into Flux key {flux_key}") + down_weight = lora_sd[key] up_weight = lora_sd[up_key] @@ -77,40 +225,74 @@ def merge_to_flux_model(loading_device, working_device, flux_model, models, rati alpha = lora_sd.get(alpha_key, dim) scale = alpha / dim - # W <- W + U * D - module_weight_key = lora_name_to_module_key[lora_name] - if module_weight_key not in flux_state_dict: - weight = flux_file.get_tensor(module_weight_key) - else: - weight = flux_state_dict[module_weight_key] + weight = flux_state_dict[flux_key] weight = weight.to(working_device, merge_dtype) up_weight = up_weight.to(working_device, merge_dtype) down_weight = down_weight.to(working_device, merge_dtype) - # logger.info(module_name, down_weight.size(), up_weight.size()) - if len(weight.size()) == 2: - # linear - weight = weight + ratio * (up_weight @ down_weight) * scale - elif down_weight.size()[2:4] == (1, 1): - # conv2d 1x1 - weight = ( - weight - + ratio - * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) - * scale - ) + if lora_name.startswith("transformer."): + if "qkv" in flux_key: + hidden_size = weight.size(-1) // 3 + update = ratio * (up_weight @ down_weight) * scale + + if "img_attn" in flux_key or "txt_attn" in flux_key: + q, k, v = torch.chunk(weight, 3, dim=-1) + if "to_q" in lora_name or "add_q_proj" in lora_name: + q += update.reshape(q.shape) + elif "to_k" in lora_name or "add_k_proj" in lora_name: + k += update.reshape(k.shape) + elif "to_v" in lora_name or "add_v_proj" in lora_name: + v += update.reshape(v.shape) + weight = torch.cat([q, k, v], dim=-1) + else: + if len(weight.size()) == 2: + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + weight = ( + weight + + ratio + * ( + up_weight.squeeze(3).squeeze(2) + @ down_weight.squeeze(3).squeeze(2) + ) + .unsqueeze(2) + .unsqueeze(3) + * scale + ) + else: + conved = torch.nn.functional.conv2d( + down_weight.permute(1, 0, 2, 3), up_weight + ).permute(1, 0, 2, 3) + weight = weight + ratio * conved * scale else: - # conv2d 3x3 - conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # logger.info(conved.size(), weight.size(), module.stride, module.padding) - weight = weight + ratio * conved * scale - - flux_state_dict[module_weight_key] = weight.to(loading_device, save_dtype) + if len(weight.size()) == 2: + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + weight = ( + weight + + ratio + * ( + up_weight.squeeze(3).squeeze(2) + @ down_weight.squeeze(3).squeeze(2) + ) + .unsqueeze(2) + .unsqueeze(3) + * scale + ) + else: + conved = torch.nn.functional.conv2d( + down_weight.permute(1, 0, 2, 3), up_weight + ).permute(1, 0, 2, 3) + weight = weight + ratio * conved * scale + + flux_state_dict[flux_key] = weight.to(loading_device, save_dtype) + merged_keys.add(flux_key) del up_weight del down_weight del weight + logger.info(f"Merged keys: {sorted(list(merged_keys))}") return flux_state_dict @@ -126,7 +308,9 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): if lora_metadata is not None: if base_model is None: - base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) + base_model = lora_metadata.get( + train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None + ) # get alpha and dim alphas = {} # alpha for current model @@ -152,10 +336,12 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): if lora_module_name not in base_alphas: base_alphas[lora_module_name] = alpha - logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") + logger.info( + f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}" + ) # merge - logger.info(f"merging...") + logger.info("merging...") for key in tqdm(lora_sd.keys()): if "alpha" in key: continue @@ -173,14 +359,19 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): alpha = alphas[lora_module_name] scale = math.sqrt(alpha / base_alpha) * ratio - scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 + scale = ( + abs(scale) if "lora_up" in key else scale + ) # マイナスの重みに対応する。 if key in merged_sd: assert ( - merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None - ), f"weights shape mismatch, different dims? / 重みのサイズが合いません。dimが異なる可能性があります。" + merged_sd[key].size() == lora_sd[key].size() + or concat_dim is not None + ), "weights shape mismatch, different dims? / 重みのサイズが合いません。dimが異なる可能性があります。" if concat_dim is not None: - merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim) + merged_sd[key] = torch.cat( + [merged_sd[key], lora_sd[key] * scale], dim=concat_dim + ) else: merged_sd[key] = merged_sd[key] + lora_sd[key] * scale else: @@ -199,7 +390,9 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): merged_sd[key_up] = merged_sd[key_up][:, perm] logger.info("merged model") - logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") + logger.info( + f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}" + ) # check all dims are same dims_list = list(set(base_dims.values())) @@ -218,15 +411,17 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): # build minimum metadata dims = f"{dims_list[0]}" if all_same_dims else "Dynamic" alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic" - metadata = train_util.build_minimum_network_metadata(str(False), base_model, "networks.lora", dims, alphas, None) + metadata = train_util.build_minimum_network_metadata( + str(False), base_model, "networks.lora", dims, alphas, None + ) return merged_sd, metadata def merge(args): - assert len(args.models) == len( - args.ratios - ), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + assert ( + len(args.models) == len(args.ratios) + ), "number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" def str_to_dtype(p): if p == "float": @@ -249,27 +444,48 @@ def str_to_dtype(p): if args.flux_model is not None: state_dict = merge_to_flux_model( - args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype + args.loading_device, + args.working_device, + args.flux_model, + args.models, + args.ratios, + merge_dtype, + save_dtype, ) if args.no_metadata: sai_metadata = None else: - merged_from = sai_model_spec.build_merged_from([args.flux_model] + args.models) + merged_from = sai_model_spec.build_merged_from( + [args.flux_model] + args.models + ) title = os.path.splitext(os.path.basename(args.save_to))[0] sai_metadata = sai_model_spec.build_metadata( - None, False, False, False, False, False, time.time(), title=title, merged_from=merged_from, flux="dev" + None, + False, + False, + False, + False, + False, + time.time(), + title=title, + merged_from=merged_from, + flux="dev", ) logger.info(f"saving FLUX model to: {args.save_to}") save_to_file(args.save_to, state_dict, save_dtype, sai_metadata) else: - state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) + state_dict, metadata = merge_lora_models( + args.models, args.ratios, merge_dtype, args.concat, args.shuffle + ) - logger.info(f"calculating hashes and creating metadata...") + logger.info("calculating hashes and creating metadata...") - model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes( + state_dict, metadata + ) metadata["sshs_model_hash"] = model_hash metadata["sshs_legacy_hash"] = legacy_hash @@ -277,7 +493,16 @@ def str_to_dtype(p): merged_from = sai_model_spec.build_merged_from(args.models) title = os.path.splitext(os.path.basename(args.save_to))[0] sai_metadata = sai_model_spec.build_metadata( - state_dict, False, False, False, True, False, time.time(), title=title, merged_from=merged_from, flux="dev" + state_dict, + False, + False, + False, + True, + False, + time.time(), + title=title, + merged_from=merged_from, + flux="dev", ) metadata.update(sai_metadata) @@ -332,7 +557,12 @@ def setup_parser() -> argparse.ArgumentParser: nargs="*", help="LoRA models to merge: safetensors file / マージするLoRAモデル、safetensorsファイル", ) - parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") + parser.add_argument( + "--ratios", + type=float, + nargs="*", + help="ratios for each model / それぞれのLoRAモデルの比率", + ) parser.add_argument( "--no_metadata", action="store_true", From ef535ec6bb99918027afc1e31efa72cd3761d453 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 18 Aug 2024 16:54:18 +0900 Subject: [PATCH 083/348] add memory efficient training for FLUX.1 --- README.md | 64 ++++++++++++-- flux_train.py | 187 +++++++++++++++++++++++++++++------------ library/flux_models.py | 182 ++++++++++++++++++++++++++++++++++----- 3 files changed, 354 insertions(+), 79 deletions(-) diff --git a/README.md b/README.md index 2b7b110f3..521e82e86 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,11 @@ The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` -Aug 17. 2024: +Aug 18, 2024: +Memory-efficient training based on 2kpr's implementation is implemented in `flux_train.py`. Thanks to 2kpr. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. + + +Aug 17, 2024: Added a script `flux_train.py` to train FLUX.1. The script is experimental and not an optimized version. It needs >28GB VRAM for training. Aug 16, 2024: @@ -39,11 +43,23 @@ Aug 11, 2024: Fix `--apply_t5_attn_mask` option to work. Please remove and re-ge Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI. -We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below, settings are based on [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit). It will work with 24GB VRAM GPUs. + +### FLUX.1 LoRA training + +We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below, settings are based on [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit). It will work with 24GB VRAM GPUs. ``` -accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 --loss_type l2 +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py +--pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors +--ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers +--max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 +--network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 +--network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base +--highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml +--output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigmoid +--model_prediction_type raw --guidance_scale 1.0 --loss_type l2 ``` +(The command is multi-line for readability. Please combine it into one line.) The training can be done with 16GB VRAM GPUs with Adafactor optimizer. Please use settings like below: @@ -80,12 +96,44 @@ The trained LoRA model can be used with ComfyUI. The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options. -Aug 12: `--interactive` option is now working. - ``` python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0 ``` +### FLUX.1 fine-tuning + +Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GPUs, and 64GB main memory is recommended. + +``` +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train.py +--pretrained_model_name_or_path flux1-dev.sft --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.sft +--mixed_precision bf16 --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 +--seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 +--dataset_config dataset_1024_bs1.toml --output_dir path/to/output/dir --output_name test-bf16 +--learning_rate 5e-5 --max_train_epochs 4 --sdpa --highvram --cache_text_encoder_outputs_to_disk --cache_latents_to_disk --save_every_n_epochs 1 +--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" +--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 +--blockwise_fused_optimizer --double_blocks_to_swap 6 --cpu_offload_checkpointing +``` + +(Combine the command into one line.) + +Options are almost the same as LoRA training. The difference is `--blockwise_fused_optimizer`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available. + +`--blockwise_fused_optimizer` enables the fusing of the optimizer for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--fused_optimizer_groups` is deprecated due to the addition of this option for FLUX.1 training. + +`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--blockwise_fused_optimizer`. + +`--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage. + +All these options are experimental and may change in the future. + +The increasing the number of blocks to swap may reduce the memory usage, but the training speed will be slower. `--cpu_offload_checkpointing` also slows down the training. + +Swap 6 double blocks and use cpu offload checkpointing may be a good starting point. Please try different settings according to VRAM usage and training speed. + +The learning rate and the number of epochs are not optimized yet. Please adjust them according to the training results. + ### Merge LoRA to FLUX.1 checkpoint `networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint. __The script is experimental.__ @@ -298,7 +346,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr! - The memory usage during training is significantly reduced by integrating the optimizer's backward pass with step. The training results are the same as before, but if you have plenty of memory, the speed will be slower. - - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only AdaFactor is supported. Gradient accumulation is not available. + - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only Adafactor is supported. Gradient accumulation is not available. - Setting mixed precision to `no` seems to use less memory than `fp16` or `bf16`. - Training is possible with a memory usage of about 17GB with a batch size of 1 and fp32. If you specify the `--full_bf16` option, you can further reduce the memory usage (but the accuracy will be lower). With the same memory usage as before, you can increase the batch size. - PyTorch 2.1 or later is required because it uses the new API `Tensor.register_post_accumulate_grad_hook(hook)`. @@ -308,7 +356,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Memory usage is reduced by the same principle as Fused optimizer. The training results and speed are the same as Fused optimizer. - Specify the number of groups like `--fused_optimizer_groups 10` in `sdxl_train.py`. Increasing the number of groups reduces memory usage but slows down training. Since the effect is limited to a certain number, it is recommended to specify 4-10. - Any optimizer can be used, but optimizers that automatically calculate the learning rate (such as D-Adaptation and Prodigy) cannot be used. Gradient accumulation is not available. - - `--fused_optimizer_groups` cannot be used with `--fused_backward_pass`. When using AdaFactor, the memory usage is slightly larger than with Fused optimizer. PyTorch 2.1 or later is required. + - `--fused_optimizer_groups` cannot be used with `--fused_backward_pass`. When using Adafactor, the memory usage is slightly larger than with Fused optimizer. PyTorch 2.1 or later is required. - Mechanism: While Fused optimizer performs backward/step for individual parameters within the optimizer, optimizer groups reduce memory usage by grouping parameters and creating multiple optimizers to perform backward/step for each group. Fused optimizer requires implementation on the optimizer side, while optimizer groups are implemented only on the training script side. - LoRA+ is supported. PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) Thanks to rockerBOO! @@ -361,7 +409,7 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! - SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。 - optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。 - - `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は AdaFactor のみ対応しています。また gradient accumulation は使えません。 + - `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は Adafactor のみ対応しています。また gradient accumulation は使えません。 - mixed precision は `no` のほうが `fp16` や `bf16` よりも使用メモリ量が少ないようです。 - バッチサイズ 1、fp32 で 17GB 程度で学習可能なようです。`--full_bf16` オプションを指定するとさらに削減できます(精度は劣ります)。以前と同じメモリ使用量ではバッチサイズを増やせます。 - PyTorch 2.1 以降の新 API `Tensor.register_post_accumulate_grad_hook(hook)` を使用しているため、PyTorch 2.1 以降が必要です。 diff --git a/flux_train.py b/flux_train.py index d2a9b3f32..ecb3c7dda 100644 --- a/flux_train.py +++ b/flux_train.py @@ -1,5 +1,15 @@ # training with captions +# Swap blocks between CPU and GPU: +# This implementation is inspired by and based on the work of 2kpr. +# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading. +# The original idea has been adapted and extended to fit the current project's needs. + +# Key features: +# - CPU offloading during forward and backward passes +# - Use of fused optimizer and grad_hook for efficient gradient processing +# - Per-block fused optimizer instances + import argparse import copy import math @@ -54,6 +64,12 @@ def train(args): ) args.cache_text_encoder_outputs = True + if args.cpu_offload_checkpointing and not args.gradient_checkpointing: + logger.warning( + "cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります" + ) + args.gradient_checkpointing = True + cache_latents = args.cache_latents use_dreambooth_method = args.in_json is None @@ -232,16 +248,25 @@ def train(args): # now we can delete Text Encoders to free memory clip_l = None t5xxl = None + clean_memory_on_device(accelerator.device) # load FLUX # if we load to cpu, flux.to(fp8) takes a long time flux = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu") if args.gradient_checkpointing: - flux.enable_gradient_checkpointing() + flux.enable_gradient_checkpointing(args.cpu_offload_checkpointing) flux.requires_grad_(True) + if args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + # This idea is based on 2kpr's great work. Thank you! + logger.info( + f"enable block swap: double_blocks_to_swap={args.double_blocks_to_swap}, single_blocks_to_swap={args.single_blocks_to_swap}" + ) + flux.enable_block_swap(args.double_blocks_to_swap, args.single_blocks_to_swap) + if not cache_latents: # load VAE here if not cached ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu") @@ -265,40 +290,43 @@ def train(args): # 学習に必要なクラスを準備する accelerator.print("prepare optimizer, data loader etc.") - if args.fused_optimizer_groups: + if args.blockwise_fused_optimizers: # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html - # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters. + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters. # This balances memory usage and management complexity. - # calculate total number of parameters - n_total_params = sum(len(params["params"]) for params in params_to_optimize) - params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups) - - # split params into groups, keeping the learning rate the same for all params in a group - # this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders) + # split params into groups. currently different learning rates are not supported grouped_params = [] - param_group = [] - param_group_lr = -1 + param_group = {} for group in params_to_optimize: - lr = group["lr"] - for p in group["params"]: - # if the learning rate is different for different params, start a new group - if lr != param_group_lr: - if param_group: - grouped_params.append({"params": param_group, "lr": param_group_lr}) - param_group = [] - param_group_lr = lr - - param_group.append(p) - - # if the group has enough parameters, start a new group - if len(param_group) == params_per_group: - grouped_params.append({"params": param_group, "lr": param_group_lr}) - param_group = [] - param_group_lr = -1 - - if param_group: - grouped_params.append({"params": param_group, "lr": param_group_lr}) + named_parameters = list(flux.named_parameters()) + assert len(named_parameters) == len(group["params"]), "number of parameters does not match" + for p, np in zip(group["params"], named_parameters): + # determine target layer and block index for each parameter + block_type = "other" # double, single or other + if np[0].startswith("double_blocks"): + block_idx = int(np[0].split(".")[1]) + block_type = "double" + elif np[0].startswith("single_blocks"): + block_idx = int(np[0].split(".")[1]) + block_type = "single" + else: + block_idx = -1 + + param_group_key = (block_type, block_idx) + if param_group_key not in param_group: + param_group[param_group_key] = [] + param_group[param_group_key].append(p) + + block_types_and_indices = [] + for param_group_key, param_group in param_group.items(): + block_types_and_indices.append(param_group_key) + grouped_params.append({"params": param_group, "lr": args.learning_rate}) + + num_params = 0 + for p in param_group: + num_params += p.numel() + accelerator.print(f"block {param_group_key}: {num_params} parameters") # prepare optimizers for each group optimizers = [] @@ -307,7 +335,7 @@ def train(args): optimizers.append(optimizer) optimizer = optimizers[0] # avoid error in the following code - logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups") + logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers") else: _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) @@ -341,7 +369,7 @@ def train(args): train_dataset_group.set_max_train_steps(args.max_train_steps) # lr schedulerを用意する - if args.fused_optimizer_groups: + if args.blockwise_fused_optimizers: # prepare lr schedulers for each optimizer lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] lr_scheduler = lr_schedulers[0] # avoid error in the following code @@ -414,7 +442,7 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): parameter.register_post_accumulate_grad_hook(__grad_hook) - elif args.fused_optimizer_groups: + elif args.blockwise_fused_optimizers: # prepare for additional optimizers and lr schedulers for i in range(1, len(optimizers)): optimizers[i] = accelerator.prepare(optimizers[i]) @@ -429,22 +457,46 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): num_parameters_per_group = [0] * len(optimizers) parameter_optimizer_map = {} + double_blocks_to_swap = args.double_blocks_to_swap + single_blocks_to_swap = args.single_blocks_to_swap + num_double_blocks = len(flux.double_blocks) + num_single_blocks = len(flux.single_blocks) + for opt_idx, optimizer in enumerate(optimizers): for param_group in optimizer.param_groups: for parameter in param_group["params"]: if parameter.requires_grad: - - def optimizer_hook(parameter: torch.Tensor): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(parameter, args.max_grad_norm) - - i = parameter_optimizer_map[parameter] - optimizer_hooked_count[i] += 1 - if optimizer_hooked_count[i] == num_parameters_per_group[i]: - optimizers[i].step() - optimizers[i].zero_grad(set_to_none=True) - - parameter.register_post_accumulate_grad_hook(optimizer_hook) + block_type, block_idx = block_types_and_indices[opt_idx] + + def create_optimizer_hook(btype, bidx): + def optimizer_hook(parameter: torch.Tensor): + # print(f"optimizer_hook: {btype}, {bidx}") + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + # swap blocks if necessary + if btype == "double" and double_blocks_to_swap: + if bidx >= num_double_blocks - double_blocks_to_swap: + bidx_cuda = double_blocks_to_swap - (num_double_blocks - bidx) + flux.double_blocks[bidx].to("cpu") + flux.double_blocks[bidx_cuda].to(accelerator.device) + # print(f"Move double block {bidx} to cpu and {bidx_cuda} to device") + elif btype == "single" and single_blocks_to_swap: + if bidx >= num_single_blocks - single_blocks_to_swap: + bidx_cuda = single_blocks_to_swap - (num_single_blocks - bidx) + flux.single_blocks[bidx].to("cpu") + flux.single_blocks[bidx_cuda].to(accelerator.device) + # print(f"Move single block {bidx} to cpu and {bidx_cuda} to device") + + return optimizer_hook + + parameter.register_post_accumulate_grad_hook(create_optimizer_hook(block_type, block_idx)) parameter_optimizer_map[parameter] = opt_idx num_parameters_per_group[opt_idx] += 1 @@ -487,6 +539,9 @@ def optimizer_hook(parameter: torch.Tensor): init_kwargs=init_kwargs, ) + if args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None: + flux.prepare_block_swap_before_forward() + # For --sample_at_first flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) @@ -502,7 +557,7 @@ def optimizer_hook(parameter: torch.Tensor): for step, batch in enumerate(train_dataloader): current_step.value = global_step - if args.fused_optimizer_groups: + if args.blockwise_fused_optimizers: optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step with accelerator.accumulate(*training_models): @@ -591,7 +646,7 @@ def optimizer_hook(parameter: torch.Tensor): # backward accelerator.backward(loss) - if not (args.fused_backward_pass or args.fused_optimizer_groups): + if not (args.fused_backward_pass or args.blockwise_fused_optimizers): if accelerator.sync_gradients and args.max_grad_norm != 0.0: params_to_clip = [] for m in training_models: @@ -604,7 +659,7 @@ def optimizer_hook(parameter: torch.Tensor): else: # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook lr_scheduler.step() - if args.fused_optimizer_groups: + if args.blockwise_fused_optimizers: for i in range(1, len(optimizers)): lr_schedulers[i].step() @@ -614,7 +669,7 @@ def optimizer_hook(parameter: torch.Tensor): global_step += 1 flux_train_utils.sample_images( - accelerator, args, epoch, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs + accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs ) # 指定ステップごとにモデルを保存 @@ -673,8 +728,6 @@ def optimizer_hook(parameter: torch.Tensor): is_main_process = accelerator.is_main_process # if is_main_process: flux = accelerator.unwrap_model(flux) - clip_l = accelerator.unwrap_model(clip_l) - t5xxl = accelerator.unwrap_model(t5xxl) accelerator.end_training() @@ -707,13 +760,43 @@ def setup_parser() -> argparse.ArgumentParser: "--fused_optimizer_groups", type=int, default=None, - help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", + help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます", + ) + parser.add_argument( + "--blockwise_fused_optimizers", + action="store_true", + help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする", ) parser.add_argument( "--skip_latents_validity_check", action="store_true", help="skip latents validity check / latentsの正当性チェックをスキップする", ) + parser.add_argument( + "--double_blocks_to_swap", + type=int, + default=None, + help="[EXPERIMENTAL] " + "Sets the number of 'double_blocks' (~640MB) to swap during the forward and backward passes." + "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." + " / 順伝播および逆伝播中にスワップする'変換ブロック'(約640MB)の数を設定します。" + "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", + ) + parser.add_argument( + "--single_blocks_to_swap", + type=int, + default=None, + help="[EXPERIMENTAL] " + "Sets the number of 'single_blocks' (~320MB) to swap during the forward and backward passes." + "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." + " / 順伝播および逆伝播中にスワップする'変換ブロック'(約320MB)の数を設定します。" + "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", + ) + parser.add_argument( + "--cpu_offload_checkpointing", + action="store_true", + help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", + ) return parser diff --git a/library/flux_models.py b/library/flux_models.py index ed0bc8c7d..3f44068f9 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -4,6 +4,11 @@ from dataclasses import dataclass import math +from typing import Optional + +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() import torch from einops import rearrange @@ -466,6 +471,33 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tenso # region layers + + +# for cpu_offload_checkpointing + + +def to_cuda(x): + if isinstance(x, torch.Tensor): + return x.cuda() + elif isinstance(x, (list, tuple)): + return [to_cuda(elem) for elem in x] + elif isinstance(x, dict): + return {k: to_cuda(v) for k, v in x.items()} + else: + return x + + +def to_cpu(x): + if isinstance(x, torch.Tensor): + return x.cpu() + elif isinstance(x, (list, tuple)): + return [to_cpu(elem) for elem in x] + elif isinstance(x, dict): + return {k: to_cpu(v) for k, v in x.items()} + else: + return x + + class EmbedND(nn.Module): def __init__(self, dim: int, theta: int, axes_dim: list[int]): super().__init__() @@ -648,16 +680,15 @@ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: ) self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False - def enable_gradient_checkpointing(self): + def enable_gradient_checkpointing(self, cpu_offload: bool = False): self.gradient_checkpointing = True - # self.img_attn.enable_gradient_checkpointing() - # self.txt_attn.enable_gradient_checkpointing() + self.cpu_offload_checkpointing = cpu_offload def disable_gradient_checkpointing(self): self.gradient_checkpointing = False - # self.img_attn.disable_gradient_checkpointing() - # self.txt_attn.disable_gradient_checkpointing() + self.cpu_offload_checkpointing = False def _forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: img_mod1, img_mod2 = self.img_mod(vec) @@ -694,11 +725,24 @@ def _forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[T txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) return img, txt - def forward(self, *args, **kwargs): + def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: if self.training and self.gradient_checkpointing: - return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + if not self.cpu_offload_checkpointing: + return checkpoint(self._forward, img, txt, vec, pe, use_reentrant=False) + # cpu offload checkpointing + + def create_custom_forward(func): + def custom_forward(*inputs): + cuda_inputs = to_cuda(inputs) + outputs = func(*cuda_inputs) + return to_cpu(outputs) + + return custom_forward + + return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), img, txt, vec, pe) + else: - return self._forward(*args, **kwargs) + return self._forward(img, txt, vec, pe) # def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor): # if self.training and self.gradient_checkpointing: @@ -747,12 +791,15 @@ def __init__( self.modulation = Modulation(hidden_size, double=False) self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False - def enable_gradient_checkpointing(self): + def enable_gradient_checkpointing(self, cpu_offload: bool = False): self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload def disable_gradient_checkpointing(self): self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False def _forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: mod, _ = self.modulation(vec) @@ -768,11 +815,24 @@ def _forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) return x + mod.gate * output - def forward(self, *args, **kwargs): + def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: if self.training and self.gradient_checkpointing: - return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + if not self.cpu_offload_checkpointing: + return checkpoint(self._forward, x, vec, pe, use_reentrant=False) + + # cpu offload checkpointing + + def create_custom_forward(func): + def custom_forward(*inputs): + cuda_inputs = to_cuda(inputs) + outputs = func(*cuda_inputs) + return to_cpu(outputs) + + return custom_forward + + return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe) else: - return self._forward(*args, **kwargs) + return self._forward(x, vec, pe) # def forward(self, x: Tensor, vec: Tensor, pe: Tensor): # if self.training and self.gradient_checkpointing: @@ -849,6 +909,9 @@ def __init__(self, params: FluxParams): self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + self.double_blocks_to_swap = None + self.single_blocks_to_swap = None @property def device(self): @@ -858,8 +921,9 @@ def device(self): def dtype(self): return next(self.parameters()).dtype - def enable_gradient_checkpointing(self): + def enable_gradient_checkpointing(self, cpu_offload: bool = False): self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload self.time_in.enable_gradient_checkpointing() self.vector_in.enable_gradient_checkpointing() @@ -867,12 +931,13 @@ def enable_gradient_checkpointing(self): self.guidance_in.enable_gradient_checkpointing() for block in self.double_blocks + self.single_blocks: - block.enable_gradient_checkpointing() + block.enable_gradient_checkpointing(cpu_offload=cpu_offload) - print("FLUX: Gradient checkpointing enabled.") + print(f"FLUX: Gradient checkpointing enabled. CPU offload: {cpu_offload}") def disable_gradient_checkpointing(self): self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False self.time_in.disable_gradient_checkpointing() self.vector_in.disable_gradient_checkpointing() @@ -884,6 +949,24 @@ def disable_gradient_checkpointing(self): print("FLUX: Gradient checkpointing disabled.") + def enable_block_swap(self, double_blocks: Optional[int], single_blocks: Optional[int]): + self.double_blocks_to_swap = double_blocks + self.single_blocks_to_swap = single_blocks + + def prepare_block_swap_before_forward(self): + # move last n blocks to cpu: they are on cuda + if self.double_blocks_to_swap: + for i in range(len(self.double_blocks) - self.double_blocks_to_swap): + self.double_blocks[i].to(self.device) + for i in range(len(self.double_blocks) - self.double_blocks_to_swap, len(self.double_blocks)): + self.double_blocks[i].to("cpu") # , non_blocking=True) + if self.single_blocks_to_swap: + for i in range(len(self.single_blocks) - self.single_blocks_to_swap): + self.single_blocks[i].to(self.device) + for i in range(len(self.single_blocks) - self.single_blocks_to_swap, len(self.single_blocks)): + self.single_blocks[i].to("cpu") # , non_blocking=True) + clean_memory_on_device(self.device) + def forward( self, img: Tensor, @@ -910,14 +993,75 @@ def forward( ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) - for block in self.double_blocks: - img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + if not self.double_blocks_to_swap: + for block in self.double_blocks: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + else: + # make sure first n blocks are on cuda, and last n blocks are on cpu at beginning + for block_idx in range(self.double_blocks_to_swap): + block = self.double_blocks[len(self.double_blocks) - self.double_blocks_to_swap + block_idx] + if block.parameters().__next__().device.type != "cpu": + block.to("cpu") # , non_blocking=True) + # print(f"Moved double block {len(self.double_blocks) - self.double_blocks_to_swap + block_idx} to cpu.") + + block = self.double_blocks[block_idx] + if block.parameters().__next__().device.type == "cpu": + block.to(self.device) + # print(f"Moved double block {block_idx} to cuda.") + + to_cpu_block_index = 0 + for block_idx, block in enumerate(self.double_blocks): + # move last n blocks to cuda: they are on cpu, and move first n blocks to cpu: they are on cuda + moving = block_idx >= len(self.double_blocks) - self.double_blocks_to_swap + if moving: + block.to(self.device) # move to cuda + # print(f"Moved double block {block_idx} to cuda.") + + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + + if moving: + self.double_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True) + # print(f"Moved double block {to_cpu_block_index} to cpu.") + to_cpu_block_index += 1 img = torch.cat((txt, img), 1) - for block in self.single_blocks: - img = block(img, vec=vec, pe=pe) + + if not self.single_blocks_to_swap: + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe) + else: + # make sure first n blocks are on cuda, and last n blocks are on cpu at beginning + for block_idx in range(self.single_blocks_to_swap): + block = self.single_blocks[len(self.single_blocks) - self.single_blocks_to_swap + block_idx] + if block.parameters().__next__().device.type != "cpu": + block.to("cpu") # , non_blocking=True) + # print(f"Moved single block {len(self.single_blocks) - self.single_blocks_to_swap + block_idx} to cpu.") + + block = self.single_blocks[block_idx] + if block.parameters().__next__().device.type == "cpu": + block.to(self.device) + # print(f"Moved single block {block_idx} to cuda.") + + to_cpu_block_index = 0 + for block_idx, block in enumerate(self.single_blocks): + # move last n blocks to cuda: they are on cpu, and move first n blocks to cpu: they are on cuda + moving = block_idx >= len(self.single_blocks) - self.single_blocks_to_swap + if moving: + block.to(self.device) # move to cuda + # print(f"Moved single block {block_idx} to cuda.") + + img = block(img, vec=vec, pe=pe) + + if moving: + self.single_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True) + # print(f"Moved single block {to_cpu_block_index} to cpu.") + img = img[:, txt.shape[1] :, ...] + if self.training and self.cpu_offload_checkpointing: + img = img.to(self.device) + vec = vec.to(self.device) + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) return img From a45048892802dce43e86a7e377ba84e89b51fdf5 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 18 Aug 2024 16:56:50 +0900 Subject: [PATCH 084/348] update readme --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 521e82e86..df2a612d7 100644 --- a/README.md +++ b/README.md @@ -9,10 +9,8 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` - Aug 18, 2024: -Memory-efficient training based on 2kpr's implementation is implemented in `flux_train.py`. Thanks to 2kpr. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. - +Memory-efficient training based on 2kpr's implementation is implemented in `flux_train.py`. Thanks to 2kpr! See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. Aug 17, 2024: Added a script `flux_train.py` to train FLUX.1. The script is experimental and not an optimized version. It needs >28GB VRAM for training. @@ -118,6 +116,8 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t (Combine the command into one line.) +Sample image generation during training is not tested yet. + Options are almost the same as LoRA training. The difference is `--blockwise_fused_optimizer`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available. `--blockwise_fused_optimizer` enables the fusing of the optimizer for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--fused_optimizer_groups` is deprecated due to the addition of this option for FLUX.1 training. From d034032a5dff4a5ee1a108e4f1cec41d8efadab0 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 19 Aug 2024 13:08:49 +0900 Subject: [PATCH 085/348] update README fix option name --- README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index df2a612d7..9a603b281 100644 --- a/README.md +++ b/README.md @@ -105,24 +105,24 @@ Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GP ``` accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train.py --pretrained_model_name_or_path flux1-dev.sft --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.sft ---mixed_precision bf16 --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 +--save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 ---dataset_config dataset_1024_bs1.toml --output_dir path/to/output/dir --output_name test-bf16 +--dataset_config dataset_1024_bs1.toml --output_dir path/to/output/dir --output_name output-name --learning_rate 5e-5 --max_train_epochs 4 --sdpa --highvram --cache_text_encoder_outputs_to_disk --cache_latents_to_disk --save_every_n_epochs 1 --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 ---blockwise_fused_optimizer --double_blocks_to_swap 6 --cpu_offload_checkpointing +--blockwise_fused_optimizers --double_blocks_to_swap 6 --cpu_offload_checkpointing ``` (Combine the command into one line.) Sample image generation during training is not tested yet. -Options are almost the same as LoRA training. The difference is `--blockwise_fused_optimizer`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available. +Options are almost the same as LoRA training. The difference is `--blockwise_fused_optimizers`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available. -`--blockwise_fused_optimizer` enables the fusing of the optimizer for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--fused_optimizer_groups` is deprecated due to the addition of this option for FLUX.1 training. +`--blockwise_fused_optimizers` enables the fusing of the optimizer for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--fused_optimizer_groups` is deprecated due to the addition of this option for FLUX.1 training. -`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--blockwise_fused_optimizer`. +`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--blockwise_fused_optimizers`. `--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage. From 6e72a799c8f55f148a248693d2c0c3fb1912b04e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 19 Aug 2024 21:55:28 +0900 Subject: [PATCH 086/348] reduce peak VRAM usage by excluding some blocks to cuda --- flux_train.py | 15 +++++++++------ library/flux_models.py | 16 ++++++++++++++++ 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/flux_train.py b/flux_train.py index ecb3c7dda..b294ce42a 100644 --- a/flux_train.py +++ b/flux_train.py @@ -251,7 +251,6 @@ def train(args): clean_memory_on_device(accelerator.device) # load FLUX - # if we load to cpu, flux.to(fp8) takes a long time flux = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu") if args.gradient_checkpointing: @@ -259,7 +258,8 @@ def train(args): flux.requires_grad_(True) - if args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None: + is_swapping_blocks = args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None + if is_swapping_blocks: # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. # This idea is based on 2kpr's great work. Thank you! logger.info( @@ -412,8 +412,11 @@ def train(args): training_models = [ds_model] else: - # acceleratorがなんかよろしくやってくれるらしい - flux = accelerator.prepare(flux) + # accelerator does some magic + # if we doesn't swap blocks, we can move the model to device + flux = accelerator.prepare(flux, device_placement=[not is_swapping_blocks]) + if is_swapping_blocks: + flux.move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする @@ -539,7 +542,7 @@ def optimizer_hook(parameter: torch.Tensor): init_kwargs=init_kwargs, ) - if args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None: + if is_swapping_blocks: flux.prepare_block_swap_before_forward() # For --sample_at_first @@ -595,7 +598,7 @@ def optimizer_hook(parameter: torch.Tensor): # get noisy model input and timesteps noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents, noise, accelerator.device, weight_dtype + args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype ) # pack latents and get img_ids diff --git a/library/flux_models.py b/library/flux_models.py index 3f44068f9..11ef647ad 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -953,6 +953,22 @@ def enable_block_swap(self, double_blocks: Optional[int], single_blocks: Optiona self.double_blocks_to_swap = double_blocks self.single_blocks_to_swap = single_blocks + def move_to_device_except_swap_blocks(self, device: torch.device): + # assume model is on cpu + if self.double_blocks_to_swap: + save_double_blocks = self.double_blocks + self.double_blocks = None + if self.single_blocks_to_swap: + save_single_blocks = self.single_blocks + self.single_blocks = None + + self.to(device) + + if self.double_blocks_to_swap: + self.double_blocks = save_double_blocks + if self.single_blocks_to_swap: + self.single_blocks = save_single_blocks + def prepare_block_swap_before_forward(self): # move last n blocks to cpu: they are on cuda if self.double_blocks_to_swap: From 486fe8f70a53166f21f08b1c896bd9ba1e31d7e7 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 19 Aug 2024 22:30:24 +0900 Subject: [PATCH 087/348] feat: reduce memory usage and add memory efficient option for model saving --- README.md | 5 +++ flux_train.py | 6 +++ library/flux_train_utils.py | 21 ++++++++--- library/utils.py | 75 ++++++++++++++++++++++++++++++++++++- 4 files changed, 100 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 9a603b281..51e4635bb 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,11 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 19, 2024: +In `flux_train.py`, the memory consumption during model saving is reduced when `--save_precision` is set to the same value as `--mixed_precision` (about 22GB). Please set the same value unless there is a reason. + +An experimental option `--mem_eff_save` is also added. When specified, it can further reduce memory consumption (about 22GB), but since it is a custom implementation, unexpected problems may occur. We do not recommend using it unless you are familiar with the code. + Aug 18, 2024: Memory-efficient training based on 2kpr's implementation is implemented in `flux_train.py`. Thanks to 2kpr! See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. diff --git a/flux_train.py b/flux_train.py index b294ce42a..669963856 100644 --- a/flux_train.py +++ b/flux_train.py @@ -759,6 +759,12 @@ def setup_parser() -> argparse.ArgumentParser: add_custom_train_arguments(parser) # TODO remove this from here flux_train_utils.add_flux_train_arguments(parser) + parser.add_argument( + "--mem_eff_save", + action="store_true", + help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う", + ) + parser.add_argument( "--fused_optimizer_groups", type=int, diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 167d61c7e..3f9e8660f 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -20,7 +20,7 @@ init_ipex() -from .utils import setup_logging +from .utils import setup_logging, mem_eff_save_file setup_logging() import logging @@ -409,19 +409,28 @@ def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas): return model_pred, weighting -def save_models(ckpt_path: str, flux: flux_models.Flux, sai_metadata: Optional[dict], save_dtype: Optional[torch.dtype] = None): +def save_models( + ckpt_path: str, + flux: flux_models.Flux, + sai_metadata: Optional[dict], + save_dtype: Optional[torch.dtype] = None, + use_mem_eff_save: bool = False, +): state_dict = {} def update_sd(prefix, sd): for k, v in sd.items(): key = prefix + k - if save_dtype is not None: + if save_dtype is not None and v.dtype != save_dtype: v = v.detach().clone().to("cpu").to(save_dtype) state_dict[key] = v update_sd("", flux.state_dict()) - save_file(state_dict, ckpt_path, metadata=sai_metadata) + if not use_mem_eff_save: + save_file(state_dict, ckpt_path, metadata=sai_metadata) + else: + mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata) def save_flux_model_on_train_end( @@ -429,7 +438,7 @@ def save_flux_model_on_train_end( ): def sd_saver(ckpt_file, epoch_no, global_step): sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev") - save_models(ckpt_file, flux, sai_metadata, save_dtype) + save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save) train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None) @@ -448,7 +457,7 @@ def save_flux_model_on_epoch_end_or_stepwise( ): def sd_saver(ckpt_file, epoch_no, global_step): sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev") - save_models(ckpt_file, flux, sai_metadata, save_dtype) + save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save) train_util.save_sd_model_on_epoch_end_or_stepwise_common( args, diff --git a/library/utils.py b/library/utils.py index 3037c055d..7de22d5a9 100644 --- a/library/utils.py +++ b/library/utils.py @@ -1,9 +1,12 @@ import logging import sys import threading +from typing import * +import json +import struct + import torch from torchvision import transforms -from typing import * from diffusers import EulerAncestralDiscreteScheduler import diffusers.schedulers.scheduling_euler_ancestral_discrete from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput @@ -79,6 +82,76 @@ def setup_logging(args=None, log_level=None, reset=False): logger.info(msg_init) +def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None): + """ + memory efficient save file + """ + + _TYPES = { + torch.float64: "F64", + torch.float32: "F32", + torch.float16: "F16", + torch.bfloat16: "BF16", + torch.int64: "I64", + torch.int32: "I32", + torch.int16: "I16", + torch.int8: "I8", + torch.uint8: "U8", + torch.bool: "BOOL", + getattr(torch, "float8_e5m2", None): "F8_E5M2", + getattr(torch, "float8_e4m3fn", None): "F8_E4M3", + } + _ALIGN = 256 + + def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]: + validated = {} + for key, value in metadata.items(): + if not isinstance(key, str): + raise ValueError(f"Metadata key must be a string, got {type(key)}") + if not isinstance(value, str): + print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.") + validated[key] = str(value) + else: + validated[key] = value + return validated + + print(f"Using memory efficient save file: {filename}") + + header = {} + offset = 0 + if metadata: + header["__metadata__"] = validate_metadata(metadata) + for k, v in tensors.items(): + if v.numel() == 0: # empty tensor + header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]} + else: + size = v.numel() * v.element_size() + header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]} + offset += size + + hjson = json.dumps(header).encode("utf-8") + hjson += b" " * (-(len(hjson) + 8) % _ALIGN) + + with open(filename, "wb") as f: + f.write(struct.pack(" Date: Tue, 20 Aug 2024 08:19:00 +0900 Subject: [PATCH 088/348] Fix debug_dataset to work --- train_network.py | 1 + 1 file changed, 1 insertion(+) diff --git a/train_network.py b/train_network.py index 086b314a5..cab0ec52e 100644 --- a/train_network.py +++ b/train_network.py @@ -313,6 +313,7 @@ def train(self, args): collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) if args.debug_dataset: + train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: From c62c95e8626bdb727cedc8f037c82ab3a8e66059 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 20 Aug 2024 08:21:01 +0900 Subject: [PATCH 089/348] update about multi-resolution training in FLUX.1 --- README.md | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/README.md b/README.md index 51e4635bb..165eed341 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,13 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 20, 2024: +FLUX.1 supports multi-resolution inference, so training at multiple resolutions may be possible and the results may be improved (like 1024x1024, 768x768 and 512x512 ... you can use any resolution). + +The script seems to support multi-resolution even in the current version, __if `--cache_latents_to_disk` is not specified__. Please try if you are interested. See [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) for details. + +We will support multi-resolution caching to disk in the near future. + Aug 19, 2024: In `flux_train.py`, the memory consumption during model saving is reduced when `--save_precision` is set to the same value as `--mixed_precision` (about 22GB). Please set the same value unless there is a reason. @@ -159,6 +166,51 @@ In the case of LoRA models are trained with `bf16`, we are not sure which is bet The script can merge multiple LoRA models. If you want to merge multiple LoRA models, specify `--concat` option to work the merged LoRA model properly. +### FLUX.1 Multi-resolution training + +You can define multiple resolutions in the dataset configuration file. __Caching latents to disk is not supported yet.__ + +The dataset configuration file is like below. You can define multiple resolutions with different batch sizes. The resolutions are defined in the `[[datasets]]` section. The `[[datasets.subsets]]` section is for the dataset directory. Please specify the same directory for each resolution. + +``` +[general] +# define common settings here +flip_aug = true +color_aug = false +keep_tokens_separator= "|||" +shuffle_caption = false +caption_tag_dropout_rate = 0 +caption_extension = ".txt" + +[[datasets]] +# define the first resolution here +batch_size = 2 +enable_bucket = true +resolution = [1024, 1024] + + [[datasets.subsets]] + image_dir = "path/to/image/dir" + num_repeats = 1 + +[[datasets]] +# define the second resolution here +batch_size = 3 +enable_bucket = true +resolution = [768, 768] + + [[datasets.subsets]] + image_dir = "path/to/image/dir" + num_repeats = 1 + +[[datasets]] +# define the third resolution here +batch_size = 4 +enable_bucket = true +resolution = [512, 512] + + [[datasets.subsets]] + image_dir = "path/to/image/dir" + num_repeats = 1 ``` ## SD3 training From 6f6faf9b5a99b7f741f657a06a42f63754e450c0 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 20 Aug 2024 19:16:25 +0900 Subject: [PATCH 090/348] fix to work with ai-toolkit LoRA --- networks/flux_merge_lora.py | 163 +++++++++++++++--------------------- 1 file changed, 68 insertions(+), 95 deletions(-) diff --git a/networks/flux_merge_lora.py b/networks/flux_merge_lora.py index df0ba606a..1ba1f314d 100644 --- a/networks/flux_merge_lora.py +++ b/networks/flux_merge_lora.py @@ -7,8 +7,6 @@ from safetensors.torch import load_file, save_file from tqdm import tqdm -import lora_flux as lora_flux -from library import sai_model_spec, train_util from library.utils import setup_logging setup_logging() @@ -16,6 +14,9 @@ logger = logging.getLogger(__name__) +import lora_flux as lora_flux +from library import sai_model_spec, train_util + def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": @@ -43,13 +44,11 @@ def save_to_file(file_name, state_dict, dtype, metadata): save_file(state_dict, file_name, metadata=metadata) -def merge_to_flux_model( - loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype -): +def merge_to_flux_model(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype): logger.info(f"loading keys from FLUX.1 model: {flux_model}") flux_state_dict = load_file(flux_model, device=loading_device) - def create_key_map(n_double_layers, n_single_layers, hidden_size): + def create_key_map(n_double_layers, n_single_layers): key_map = {} for index in range(n_double_layers): prefix_from = f"transformer_blocks.{index}" @@ -60,18 +59,12 @@ def create_key_map(n_double_layers, n_single_layers, hidden_size): qkv_img = f"{prefix_to}.img_attn.qkv.{end}" qkv_txt = f"{prefix_to}.txt_attn.qkv.{end}" - key_map[f"{k}to_q.{end}"] = (qkv_img, (0, 0, hidden_size)) - key_map[f"{k}to_k.{end}"] = (qkv_img, (0, hidden_size, hidden_size)) - key_map[f"{k}to_v.{end}"] = (qkv_img, (0, hidden_size * 2, hidden_size)) - key_map[f"{k}add_q_proj.{end}"] = (qkv_txt, (0, 0, hidden_size)) - key_map[f"{k}add_k_proj.{end}"] = ( - qkv_txt, - (0, hidden_size, hidden_size), - ) - key_map[f"{k}add_v_proj.{end}"] = ( - qkv_txt, - (0, hidden_size * 2, hidden_size), - ) + key_map[f"{k}to_q.{end}"] = qkv_img + key_map[f"{k}to_k.{end}"] = qkv_img + key_map[f"{k}to_v.{end}"] = qkv_img + key_map[f"{k}add_q_proj.{end}"] = qkv_txt + key_map[f"{k}add_k_proj.{end}"] = qkv_txt + key_map[f"{k}add_v_proj.{end}"] = qkv_txt block_map = { "attn.to_out.0.weight": "img_attn.proj.weight", @@ -106,13 +99,10 @@ def create_key_map(n_double_layers, n_single_layers, hidden_size): for end in ("weight", "bias"): k = f"{prefix_from}.attn." qkv = f"{prefix_to}.linear1.{end}" - key_map[f"{k}to_q.{end}"] = (qkv, (0, 0, hidden_size)) - key_map[f"{k}to_k.{end}"] = (qkv, (0, hidden_size, hidden_size)) - key_map[f"{k}to_v.{end}"] = (qkv, (0, hidden_size * 2, hidden_size)) - key_map[f"{prefix_from}.proj_mlp.{end}"] = ( - qkv, - (0, hidden_size * 3, hidden_size * 4), - ) + key_map[f"{k}to_q.{end}"] = qkv + key_map[f"{k}to_k.{end}"] = qkv + key_map[f"{k}to_v.{end}"] = qkv + key_map[f"{prefix_from}.proj_mlp.{end}"] = qkv block_map = { "norm.linear.weight": "modulation.lin.weight", @@ -126,11 +116,14 @@ def create_key_map(n_double_layers, n_single_layers, hidden_size): for k, v in block_map.items(): key_map[f"{prefix_from}.{k}"] = f"{prefix_to}.{v}" + # add as-is keys + values = list([(v if isinstance(v, str) else v[0]) for v in set(key_map.values())]) + values.sort() + key_map.update({v: v for v in values}) + return key_map - key_map = create_key_map( - 18, 1, 2048 - ) # Assuming 18 double layers, 1 single layer, and hidden size of 2048 + key_map = create_key_map(18, 38) # 18 double layers, 38 single layers def find_matching_key(flux_dict, lora_key): lora_key = lora_key.replace("diffusion_model.", "") @@ -159,7 +152,6 @@ def find_matching_key(flux_dict, lora_key): "attn.add_k_proj": "txt_attn.qkv", "attn.add_v_proj": "txt_attn.qkv", } - single_block_map = { "norm.linear": "modulation.lin", "proj_out": "linear2", @@ -168,18 +160,22 @@ def find_matching_key(flux_dict, lora_key): "attn.to_q": "linear1", "attn.to_k": "linear1", "attn.to_v": "linear1", + "proj_mlp": "linear1", } + # same key exists in both single_block_map and double_block_map, so we must care about single/double + # print("lora_key before double_block_map", lora_key) for old, new in double_block_map.items(): - lora_key = lora_key.replace(old, new) - + if "double" in lora_key: + lora_key = lora_key.replace(old, new) + # print("lora_key before single_block_map", lora_key) for old, new in single_block_map.items(): - lora_key = lora_key.replace(old, new) + if "single" in lora_key: + lora_key = lora_key.replace(old, new) + # print("lora_key after mapping", lora_key) if lora_key in key_map: flux_key = key_map[lora_key] - if isinstance(flux_key, tuple): - flux_key = flux_key[0] logger.info(f"Found matching key: {flux_key}") return flux_key @@ -198,16 +194,11 @@ def find_matching_key(flux_dict, lora_key): lora_sd, _ = load_state_dict(model, merge_dtype) logger.info("merging...") - for key in tqdm(lora_sd.keys()): + for key in lora_sd.keys(): if "lora_down" in key or "lora_A" in key: - lora_name = key[ - : key.rfind(".lora_down" if "lora_down" in key else ".lora_A") - ] + lora_name = key[: key.rfind(".lora_down" if "lora_down" in key else ".lora_A")] up_key = key.replace("lora_down", "lora_up").replace("lora_A", "lora_B") - alpha_key = ( - key[: key.index("lora_down" if "lora_down" in key else "lora_A")] - + "alpha" - ) + alpha_key = key[: key.index("lora_down" if "lora_down" in key else "lora_A")] + "alpha" logger.info(f"Processing LoRA key: {lora_name}") flux_key = find_matching_key(flux_state_dict, lora_name) @@ -231,20 +222,35 @@ def find_matching_key(flux_dict, lora_key): up_weight = up_weight.to(working_device, merge_dtype) down_weight = down_weight.to(working_device, merge_dtype) + # print(up_weight.size(), down_weight.size(), weight.size()) + if lora_name.startswith("transformer."): - if "qkv" in flux_key: - hidden_size = weight.size(-1) // 3 + if "qkv" in flux_key or "linear1" in flux_key: # combined qkv or qkv+mlp update = ratio * (up_weight @ down_weight) * scale + # print(update.shape) if "img_attn" in flux_key or "txt_attn" in flux_key: - q, k, v = torch.chunk(weight, 3, dim=-1) + q, k, v = torch.chunk(weight, 3, dim=0) if "to_q" in lora_name or "add_q_proj" in lora_name: q += update.reshape(q.shape) elif "to_k" in lora_name or "add_k_proj" in lora_name: k += update.reshape(k.shape) elif "to_v" in lora_name or "add_v_proj" in lora_name: v += update.reshape(v.shape) - weight = torch.cat([q, k, v], dim=-1) + weight = torch.cat([q, k, v], dim=0) + elif "linear1" in flux_key: + q, k, v = torch.chunk(weight[: int(update.shape[-1] * 3)], 3, dim=0) + mlp = weight[int(update.shape[-1] * 3) :] + # print(q.shape, k.shape, v.shape, mlp.shape) + if "to_q" in lora_name: + q += update.reshape(q.shape) + elif "to_k" in lora_name: + k += update.reshape(k.shape) + elif "to_v" in lora_name: + v += update.reshape(v.shape) + elif "proj_mlp" in lora_name: + mlp += update.reshape(mlp.shape) + weight = torch.cat([q, k, v, mlp], dim=0) else: if len(weight.size()) == 2: weight = weight + ratio * (up_weight @ down_weight) * scale @@ -252,18 +258,11 @@ def find_matching_key(flux_dict, lora_key): weight = ( weight + ratio - * ( - up_weight.squeeze(3).squeeze(2) - @ down_weight.squeeze(3).squeeze(2) - ) - .unsqueeze(2) - .unsqueeze(3) + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale ) else: - conved = torch.nn.functional.conv2d( - down_weight.permute(1, 0, 2, 3), up_weight - ).permute(1, 0, 2, 3) + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) weight = weight + ratio * conved * scale else: if len(weight.size()) == 2: @@ -272,18 +271,11 @@ def find_matching_key(flux_dict, lora_key): weight = ( weight + ratio - * ( - up_weight.squeeze(3).squeeze(2) - @ down_weight.squeeze(3).squeeze(2) - ) - .unsqueeze(2) - .unsqueeze(3) + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale ) else: - conved = torch.nn.functional.conv2d( - down_weight.permute(1, 0, 2, 3), up_weight - ).permute(1, 0, 2, 3) + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) weight = weight + ratio * conved * scale flux_state_dict[flux_key] = weight.to(loading_device, save_dtype) @@ -308,9 +300,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): if lora_metadata is not None: if base_model is None: - base_model = lora_metadata.get( - train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None - ) + base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) # get alpha and dim alphas = {} # alpha for current model @@ -336,9 +326,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): if lora_module_name not in base_alphas: base_alphas[lora_module_name] = alpha - logger.info( - f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}" - ) + logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") # merge logger.info("merging...") @@ -359,19 +347,14 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): alpha = alphas[lora_module_name] scale = math.sqrt(alpha / base_alpha) * ratio - scale = ( - abs(scale) if "lora_up" in key else scale - ) # マイナスの重みに対応する。 + scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 if key in merged_sd: assert ( - merged_sd[key].size() == lora_sd[key].size() - or concat_dim is not None + merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None ), "weights shape mismatch, different dims? / 重みのサイズが合いません。dimが異なる可能性があります。" if concat_dim is not None: - merged_sd[key] = torch.cat( - [merged_sd[key], lora_sd[key] * scale], dim=concat_dim - ) + merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim) else: merged_sd[key] = merged_sd[key] + lora_sd[key] * scale else: @@ -390,9 +373,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): merged_sd[key_up] = merged_sd[key_up][:, perm] logger.info("merged model") - logger.info( - f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}" - ) + logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") # check all dims are same dims_list = list(set(base_dims.values())) @@ -411,16 +392,14 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): # build minimum metadata dims = f"{dims_list[0]}" if all_same_dims else "Dynamic" alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic" - metadata = train_util.build_minimum_network_metadata( - str(False), base_model, "networks.lora", dims, alphas, None - ) + metadata = train_util.build_minimum_network_metadata(str(False), base_model, "networks.lora", dims, alphas, None) return merged_sd, metadata def merge(args): - assert ( - len(args.models) == len(args.ratios) + assert len(args.models) == len( + args.ratios ), "number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" def str_to_dtype(p): @@ -456,9 +435,7 @@ def str_to_dtype(p): if args.no_metadata: sai_metadata = None else: - merged_from = sai_model_spec.build_merged_from( - [args.flux_model] + args.models - ) + merged_from = sai_model_spec.build_merged_from([args.flux_model] + args.models) title = os.path.splitext(os.path.basename(args.save_to))[0] sai_metadata = sai_model_spec.build_metadata( None, @@ -477,15 +454,11 @@ def str_to_dtype(p): save_to_file(args.save_to, state_dict, save_dtype, sai_metadata) else: - state_dict, metadata = merge_lora_models( - args.models, args.ratios, merge_dtype, args.concat, args.shuffle - ) + state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) logger.info("calculating hashes and creating metadata...") - model_hash, legacy_hash = train_util.precalculate_safetensors_hashes( - state_dict, metadata - ) + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) metadata["sshs_model_hash"] = model_hash metadata["sshs_legacy_hash"] = legacy_hash From 9381332020b7089a41eb8d041938f8ba417528d1 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 20 Aug 2024 19:32:26 +0900 Subject: [PATCH 091/348] revert merge function add add option to use new func --- README.md | 3 + networks/flux_merge_lora.py | 120 +++++++++++++++++++++++++++--------- 2 files changed, 94 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 165eed341..3f5c4daa5 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,9 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 20, 2024 (update 2): +`flux_merge_lora.py` now supports LoRA from AI-toolkit (Diffusers based keys). Specify `--diffusers` option to merge LoRA with Diffusers based keys. Thanks to exveria1015! + Aug 20, 2024: FLUX.1 supports multi-resolution inference, so training at multiple resolutions may be possible and the results may be improved (like 1024x1024, 768x768 and 512x512 ... you can use any resolution). diff --git a/networks/flux_merge_lora.py b/networks/flux_merge_lora.py index 1ba1f314d..fd9cc4e3a 100644 --- a/networks/flux_merge_lora.py +++ b/networks/flux_merge_lora.py @@ -4,6 +4,7 @@ import time import torch +from safetensors import safe_open from safetensors.torch import load_file, save_file from tqdm import tqdm @@ -45,6 +46,81 @@ def save_to_file(file_name, state_dict, dtype, metadata): def merge_to_flux_model(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype): + # create module map without loading state_dict + logger.info(f"loading keys from FLUX.1 model: {flux_model}") + lora_name_to_module_key = {} + with safe_open(flux_model, framework="pt", device=loading_device) as flux_file: + keys = list(flux_file.keys()) + for key in keys: + if key.endswith(".weight"): + module_name = ".".join(key.split(".")[:-1]) + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_") + lora_name_to_module_key[lora_name] = key + + flux_state_dict = load_file(flux_model, device=loading_device) + for model, ratio in zip(models, ratios): + logger.info(f"loading: {model}") + lora_sd, _ = load_state_dict(model, merge_dtype) # loading on CPU + + logger.info(f"merging...") + for key in tqdm(list(lora_sd.keys())): + if "lora_down" in key: + lora_name = key[: key.rfind(".lora_down")] + up_key = key.replace("lora_down", "lora_up") + alpha_key = key[: key.index("lora_down")] + "alpha" + + if lora_name not in lora_name_to_module_key: + logger.warning(f"no module found for LoRA weight: {key}. LoRA for Text Encoder is not supported yet.") + continue + + down_weight = lora_sd.pop(key) + up_weight = lora_sd.pop(up_key) + + dim = down_weight.size()[0] + alpha = lora_sd.pop(alpha_key, dim) + scale = alpha / dim + + # W <- W + U * D + module_weight_key = lora_name_to_module_key[lora_name] + if module_weight_key not in flux_state_dict: + weight = flux_file.get_tensor(module_weight_key) + else: + weight = flux_state_dict[module_weight_key] + + weight = weight.to(working_device, merge_dtype) + up_weight = up_weight.to(working_device, merge_dtype) + down_weight = down_weight.to(working_device, merge_dtype) + + # logger.info(module_name, down_weight.size(), up_weight.size()) + if len(weight.size()) == 2: + # linear + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + ratio + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + ratio * conved * scale + + flux_state_dict[module_weight_key] = weight.to(loading_device, save_dtype) + del up_weight + del down_weight + del weight + + if len(lora_sd) > 0: + logger.warning(f"Unused keys in LoRA model: {list(lora_sd.keys())}") + + return flux_state_dict + + +def merge_to_flux_model_diffusers(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype): logger.info(f"loading keys from FLUX.1 model: {flux_model}") flux_state_dict = load_file(flux_model, device=loading_device) @@ -422,15 +498,14 @@ def str_to_dtype(p): os.makedirs(dest_dir) if args.flux_model is not None: - state_dict = merge_to_flux_model( - args.loading_device, - args.working_device, - args.flux_model, - args.models, - args.ratios, - merge_dtype, - save_dtype, - ) + if not args.diffusers: + state_dict = merge_to_flux_model( + args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype + ) + else: + state_dict = merge_to_flux_model_diffusers( + args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype + ) if args.no_metadata: sai_metadata = None @@ -438,16 +513,7 @@ def str_to_dtype(p): merged_from = sai_model_spec.build_merged_from([args.flux_model] + args.models) title = os.path.splitext(os.path.basename(args.save_to))[0] sai_metadata = sai_model_spec.build_metadata( - None, - False, - False, - False, - False, - False, - time.time(), - title=title, - merged_from=merged_from, - flux="dev", + None, False, False, False, False, False, time.time(), title=title, merged_from=merged_from, flux="dev" ) logger.info(f"saving FLUX model to: {args.save_to}") @@ -466,16 +532,7 @@ def str_to_dtype(p): merged_from = sai_model_spec.build_merged_from(args.models) title = os.path.splitext(os.path.basename(args.save_to))[0] sai_metadata = sai_model_spec.build_metadata( - state_dict, - False, - False, - False, - True, - False, - time.time(), - title=title, - merged_from=merged_from, - flux="dev", + state_dict, False, False, False, True, False, time.time(), title=title, merged_from=merged_from, flux="dev" ) metadata.update(sai_metadata) @@ -553,6 +610,11 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="shuffle lora weight./ " + "LoRAの重みをシャッフルする", ) + parser.add_argument( + "--diffusers", + action="store_true", + help="merge Diffusers (?) LoRA models / Diffusers (?) LoRAモデルをマージする", + ) return parser From dbed5126bd1133da832dae31ce73ba6c41afc9d3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 20 Aug 2024 19:33:47 +0900 Subject: [PATCH 092/348] chore: formatting --- networks/flux_merge_lora.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/networks/flux_merge_lora.py b/networks/flux_merge_lora.py index fd9cc4e3a..d5e82920d 100644 --- a/networks/flux_merge_lora.py +++ b/networks/flux_merge_lora.py @@ -113,7 +113,7 @@ def merge_to_flux_model(loading_device, working_device, flux_model, models, rati del up_weight del down_weight del weight - + if len(lora_sd) > 0: logger.warning(f"Unused keys in LoRA model: {list(lora_sd.keys())}") @@ -587,12 +587,7 @@ def setup_parser() -> argparse.ArgumentParser: nargs="*", help="LoRA models to merge: safetensors file / マージするLoRAモデル、safetensorsファイル", ) - parser.add_argument( - "--ratios", - type=float, - nargs="*", - help="ratios for each model / それぞれのLoRAモデルの比率", - ) + parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") parser.add_argument( "--no_metadata", action="store_true", From 6ab48b09d8e46973d5e5fa47baeae3a464d06d04 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 20 Aug 2024 21:39:43 +0900 Subject: [PATCH 093/348] feat: Support multi-resolution training with caching latents to disk --- README.md | 11 +++- library/strategy_base.py | 112 ++++++++++++++++++++++++++------------- library/strategy_flux.py | 11 +++- library/train_util.py | 2 +- 4 files changed, 93 insertions(+), 43 deletions(-) diff --git a/README.md b/README.md index 3f5c4daa5..1d44c9e58 100644 --- a/README.md +++ b/README.md @@ -9,13 +9,20 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 20, 2024 (update 3): +__Experimental__ The multi-resolution training is now supported with caching latents to disk. + +The cache files now hold latents for multiple resolutions. Since the latents are appended to the current cache file, it is recommended to delete the cache file in advance (if not, the old latents is kept in .npz file). + +See [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) for details. + Aug 20, 2024 (update 2): `flux_merge_lora.py` now supports LoRA from AI-toolkit (Diffusers based keys). Specify `--diffusers` option to merge LoRA with Diffusers based keys. Thanks to exveria1015! Aug 20, 2024: FLUX.1 supports multi-resolution inference, so training at multiple resolutions may be possible and the results may be improved (like 1024x1024, 768x768 and 512x512 ... you can use any resolution). -The script seems to support multi-resolution even in the current version, __if `--cache_latents_to_disk` is not specified__. Please try if you are interested. See [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) for details. +The script seems to support multi-resolution even in the current version, ~~if `--cache_latents_to_disk` is not specified~~ -> `--cache_latents_to_disk` is now supported for multi-resolution training. Please try if you are interested. See [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) for details. We will support multi-resolution caching to disk in the near future. @@ -171,7 +178,7 @@ The script can merge multiple LoRA models. If you want to merge multiple LoRA mo ### FLUX.1 Multi-resolution training -You can define multiple resolutions in the dataset configuration file. __Caching latents to disk is not supported yet.__ +You can define multiple resolutions in the dataset configuration file. The dataset configuration file is like below. You can define multiple resolutions with different batch sizes. The resolutions are defined in the `[[datasets]]` section. The `[[datasets.subsets]]` section is for the dataset directory. Please specify the same directory for each resolution. diff --git a/library/strategy_base.py b/library/strategy_base.py index a99a08290..e7d3a97ef 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -219,7 +219,13 @@ def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mas raise NotImplementedError def _default_is_disk_cached_latents_expected( - self, latents_stride: int, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool + self, + latents_stride: int, + bucket_reso: Tuple[int, int], + npz_path: str, + flip_aug: bool, + alpha_mask: bool, + multi_resolution: bool = False, ): if not self.cache_to_disk: return False @@ -230,25 +236,17 @@ def _default_is_disk_cached_latents_expected( expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H) + # e.g. "_32x64", HxW + key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else "" + try: npz = np.load(npz_path) - if npz["latents"].shape[1:3] != expected_latents_size: + if "latents" + key_reso_suffix not in npz: + return False + if flip_aug and "latents_flipped" + key_reso_suffix not in npz: + return False + if alpha_mask and "alpha_mask" + key_reso_suffix not in npz: return False - - if flip_aug: - if "latents_flipped" not in npz: - return False - if npz["latents_flipped"].shape[1:3] != expected_latents_size: - return False - - if alpha_mask: - if "alpha_mask" not in npz: - return False - if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]): - return False - else: - if "alpha_mask" in npz: - return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e @@ -257,7 +255,15 @@ def _default_is_disk_cached_latents_expected( # TODO remove circular dependency for ImageInfo def _default_cache_batch_latents( - self, encode_by_vae, vae_device, vae_dtype, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool + self, + encode_by_vae, + vae_device, + vae_dtype, + image_infos: List, + flip_aug: bool, + alpha_mask: bool, + random_crop: bool, + multi_resolution: bool = False, ): """ Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common. @@ -287,8 +293,13 @@ def _default_cache_batch_latents( original_size = original_sizes[i] crop_ltrb = crop_ltrbs[i] + latents_size = latents.shape[1:3] # H, W + key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" if multi_resolution else "" # e.g. "_32x64", HxW + if self.cache_to_disk: - self.save_latents_to_disk(info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask) + self.save_latents_to_disk( + info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask, key_reso_suffix + ) else: info.latents_original_size = original_size info.latents_crop_ltrb = crop_ltrb @@ -298,31 +309,56 @@ def _default_cache_batch_latents( info.alpha_mask = alpha_mask def load_latents_from_disk( - self, npz_path: str + self, npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + """ + for SD/SDXL/SD3.0 + """ + return self._default_load_latents_from_disk(None, npz_path, bucket_reso) + + def _default_load_latents_from_disk( + self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int] ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + if latents_stride is None: + key_reso_suffix = "" + else: + latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H) + key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" # e.g. "_32x64", HxW + npz = np.load(npz_path) - if "latents" not in npz: - raise ValueError(f"error: npz is old format. please re-generate {npz_path}") - - latents = npz["latents"] - original_size = npz["original_size"].tolist() - crop_ltrb = npz["crop_ltrb"].tolist() - flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None - alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None + if "latents" + key_reso_suffix not in npz: + raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}") + + latents = npz["latents" + key_reso_suffix] + original_size = npz["original_size" + key_reso_suffix].tolist() + crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist() + flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None + alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None return latents, original_size, crop_ltrb, flipped_latents, alpha_mask def save_latents_to_disk( - self, npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None + self, + npz_path, + latents_tensor, + original_size, + crop_ltrb, + flipped_latents_tensor=None, + alpha_mask=None, + key_reso_suffix="", ): kwargs = {} + + if os.path.exists(npz_path): + # load existing npz and update it + npz = np.load(npz_path) + for key in npz.files: + kwargs[key] = npz[key] + + kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy() + kwargs["original_size" + key_reso_suffix] = np.array(original_size) + kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb) if flipped_latents_tensor is not None: - kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() + kwargs["latents_flipped" + key_reso_suffix] = flipped_latents_tensor.float().cpu().numpy() if alpha_mask is not None: - kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() - np.savez( - npz_path, - latents=latents_tensor.float().cpu().numpy(), - original_size=np.array(original_size), - crop_ltrb=np.array(crop_ltrb), - **kwargs, - ) + kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy() + np.savez(npz_path, **kwargs) diff --git a/library/strategy_flux.py b/library/strategy_flux.py index 3880a1e1b..5c620f3d6 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -200,7 +200,12 @@ def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) ) def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): - return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, True) + + def load_latents_from_disk( + self, npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution # TODO remove circular dependency for ImageInfo def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): @@ -208,7 +213,9 @@ def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask vae_device = vae.device vae_dtype = vae.dtype - self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop) + self._default_cache_batch_latents( + encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, True + ) if not train_util.HIGH_VRAM: train_util.clean_memory_on_device(vae.device) diff --git a/library/train_util.py b/library/train_util.py index f4ac8740a..8929c192f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1381,7 +1381,7 @@ def __getitem__(self, index): image = None elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 latents, original_size, crop_ltrb, flipped_latents, alpha_mask = ( - self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz) + self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz, image_info.bucket_reso) ) if flipped: latents = flipped_latents From 7e459c00b2e142e40a9452341934c2eb9f70a172 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 21 Aug 2024 08:02:33 +0900 Subject: [PATCH 094/348] Update T5 attention mask handling in FLUX --- README.md | 3 +++ flux_minimal_inference.py | 33 +++++++++++++++++++----- flux_train.py | 6 ++++- flux_train_network.py | 13 +++++----- library/flux_models.py | 51 +++++++++++++++++++++---------------- library/flux_train_utils.py | 20 ++++++++++++--- library/strategy_flux.py | 25 ++++++++++-------- 7 files changed, 101 insertions(+), 50 deletions(-) diff --git a/README.md b/README.md index 1d44c9e58..43edbbed6 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,9 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 21, 2024: +The specification of `--apply_t5_attn_mask` has been changed. Previously, the T5 output was zero-padded, but now, two steps are taken: "1. Apply mask when encoding T5" and "2. Apply mask in the attention of Double Block". Fine tuning, LoRA training, and inference in `flux_mini_inference.py` have been changed. + Aug 20, 2024 (update 3): __Experimental__ The multi-resolution training is now supported with caching latents to disk. diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index b09f63808..5b8aa2506 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -70,12 +70,22 @@ def denoise( vec: torch.Tensor, timesteps: list[float], guidance: float = 4.0, + t5_attn_mask: Optional[torch.Tensor] = None, ): # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) - pred = model(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec, guidance=guidance_vec) + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) img = img + (t_prev - t_curr) * pred @@ -92,6 +102,7 @@ def do_sample( txt_ids: torch.Tensor, num_steps: int, guidance: float, + t5_attn_mask: Optional[torch.Tensor], is_schnell: bool, device: torch.device, flux_dtype: torch.dtype, @@ -101,10 +112,14 @@ def do_sample( # denoise initial noise if accelerator: with accelerator.autocast(), torch.no_grad(): - x = denoise(model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance) + x = denoise( + model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask + ) else: with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad(): - x = denoise(model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance) + x = denoise( + model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask + ) return x @@ -156,14 +171,14 @@ def generate_image( clip_l.to(clip_l_dtype) t5xxl.to(t5xxl_dtype) with accelerator.autocast(): - _, t5_out, txt_ids = encoding_strategy.encode_tokens( + _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask ) else: with torch.autocast(device_type=device.type, dtype=clip_l_dtype): - l_pooled, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) with torch.autocast(device_type=device.type, dtype=t5xxl_dtype): - _, t5_out, txt_ids = encoding_strategy.encode_tokens( + _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask ) @@ -186,7 +201,11 @@ def generate_image( steps = 4 if is_schnell else 50 img_ids = img_ids.to(device) - x = do_sample(accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance, is_schnell, device, flux_dtype) + t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None + + x = do_sample( + accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance, t5_attn_mask, is_schnell, device, flux_dtype + ) if args.offload: model = model.cpu() # del model diff --git a/flux_train.py b/flux_train.py index 669963856..ecb8a1086 100644 --- a/flux_train.py +++ b/flux_train.py @@ -610,7 +610,10 @@ def optimizer_hook(parameter: torch.Tensor): guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device) # call model - l_pooled, t5_out, txt_ids = text_encoder_conds + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None + with accelerator.autocast(): # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = flux( @@ -621,6 +624,7 @@ def optimizer_hook(parameter: torch.Tensor): y=l_pooled, timesteps=timesteps / 1000, guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, ) # unpack latents diff --git a/flux_train_network.py b/flux_train_network.py index 002252c87..49bd270c7 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -233,11 +233,11 @@ def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.Fl self.flux_lower = flux_lower self.target_device = device - def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None): + def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None): self.flux_lower.to("cpu") clean_memory_on_device(self.target_device) self.flux_upper.to(self.target_device) - img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance) + img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask) self.flux_upper.to("cpu") clean_memory_on_device(self.target_device) self.flux_lower.to(self.target_device) @@ -300,10 +300,9 @@ def get_noise_pred_and_target( guidance_vec.requires_grad_(True) # Predict the noise residual - l_pooled, t5_out, txt_ids = text_encoder_conds - # print( - # f"model_input: {noisy_model_input.shape}, img_ids: {img_ids.shape}, t5_out: {t5_out.shape}, txt_ids: {txt_ids.shape}, l_pooled: {l_pooled.shape}, timesteps: {timesteps.shape}, guidance_vec: {guidance_vec.shape}" - # ) + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None if not args.split_mode: # normal forward @@ -317,6 +316,7 @@ def get_noise_pred_and_target( y=l_pooled, timesteps=timesteps / 1000, guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, ) else: # split forward to reduce memory usage @@ -337,6 +337,7 @@ def get_noise_pred_and_target( y=l_pooled, timesteps=timesteps / 1000, guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, ) # move flux upper back to cpu, and then move flux lower to gpu diff --git a/library/flux_models.py b/library/flux_models.py index 11ef647ad..6f28da603 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -440,10 +440,10 @@ class ModelSpec: # region math -def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor: q, k = apply_rope(q, k, pe) - x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) x = rearrange(x, "B H L D -> B L (H D)") return x @@ -607,11 +607,7 @@ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): self.norm = QKNorm(head_dim) self.proj = nn.Linear(dim, dim) - # self.gradient_checkpointing = False - - # def enable_gradient_checkpointing(self): - # self.gradient_checkpointing = True - + # this is not called from DoubleStreamBlock/SingleStreamBlock because they uses attention function directly def forward(self, x: Tensor, pe: Tensor) -> Tensor: qkv = self.qkv(x) q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) @@ -620,12 +616,6 @@ def forward(self, x: Tensor, pe: Tensor) -> Tensor: x = self.proj(x) return x - # def forward(self, *args, **kwargs): - # if self.training and self.gradient_checkpointing: - # return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) - # else: - # return self._forward(*args, **kwargs) - @dataclass class ModulationOut: @@ -690,7 +680,9 @@ def disable_gradient_checkpointing(self): self.gradient_checkpointing = False self.cpu_offload_checkpointing = False - def _forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: + def _forward( + self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None + ) -> tuple[Tensor, Tensor]: img_mod1, img_mod2 = self.img_mod(vec) txt_mod1, txt_mod2 = self.txt_mod(vec) @@ -713,7 +705,18 @@ def _forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[T k = torch.cat((txt_k, img_k), dim=2) v = torch.cat((txt_v, img_v), dim=2) - attn = attention(q, k, v, pe=pe) + # make attention mask if not None + attn_mask = None + if txt_attention_mask is not None: + attn_mask = txt_attention_mask # b, seq_len + attn_mask = torch.cat( + (attn_mask, torch.ones(attn_mask.shape[0], img.shape[1]).to(attn_mask.device)), dim=1 + ) # b, seq_len + img_len + + # broadcast attn_mask to all heads + attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1) + + attn = attention(q, k, v, pe=pe, attn_mask=attn_mask) txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] # calculate the img blocks @@ -725,10 +728,12 @@ def _forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[T txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) return img, txt - def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: + def forward( + self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None + ) -> tuple[Tensor, Tensor]: if self.training and self.gradient_checkpointing: if not self.cpu_offload_checkpointing: - return checkpoint(self._forward, img, txt, vec, pe, use_reentrant=False) + return checkpoint(self._forward, img, txt, vec, pe, txt_attention_mask, use_reentrant=False) # cpu offload checkpointing def create_custom_forward(func): @@ -739,10 +744,10 @@ def custom_forward(*inputs): return custom_forward - return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), img, txt, vec, pe) + return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask) else: - return self._forward(img, txt, vec, pe) + return self._forward(img, txt, vec, pe, txt_attention_mask) # def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor): # if self.training and self.gradient_checkpointing: @@ -992,6 +997,7 @@ def forward( timesteps: Tensor, y: Tensor, guidance: Tensor | None = None, + txt_attention_mask: Tensor | None = None, ) -> Tensor: if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") @@ -1011,7 +1017,7 @@ def forward( if not self.double_blocks_to_swap: for block in self.double_blocks: - img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) else: # make sure first n blocks are on cuda, and last n blocks are on cpu at beginning for block_idx in range(self.double_blocks_to_swap): @@ -1033,7 +1039,7 @@ def forward( block.to(self.device) # move to cuda # print(f"Moved double block {block_idx} to cuda.") - img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) if moving: self.double_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True) @@ -1164,6 +1170,7 @@ def forward( timesteps: Tensor, y: Tensor, guidance: Tensor | None = None, + txt_attention_mask: Tensor | None = None, ) -> Tensor: if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") @@ -1182,7 +1189,7 @@ def forward( pe = self.pe_embedder(ids) for block in self.double_blocks: - img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) return img, txt, vec, pe diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 3f9e8660f..1d3f80d72 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -190,9 +190,10 @@ def sample_image_inference( te_outputs = sample_prompts_te_outputs[prompt] else: tokens_and_masks = tokenize_strategy.tokenize(prompt) + # strategy has apply_t5_attn_mask option te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) - l_pooled, t5_out, txt_ids = te_outputs + l_pooled, t5_out, txt_ids, t5_attn_mask = te_outputs # sample image weight_dtype = ae.dtype # TOFO give dtype as argument @@ -208,9 +209,10 @@ def sample_image_inference( ) timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype) + t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None with accelerator.autocast(), torch.no_grad(): - x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale) + x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask) x = x.float() x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) @@ -289,12 +291,22 @@ def denoise( vec: torch.Tensor, timesteps: list[float], guidance: float = 4.0, + t5_attn_mask: Optional[torch.Tensor] = None, ): # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) - pred = model(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec, guidance=guidance_vec) + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) img = img + (t_prev - t_curr) * pred @@ -498,7 +510,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--apply_t5_attn_mask", action="store_true", - help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する", + help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する", ) parser.add_argument( "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" diff --git a/library/strategy_flux.py b/library/strategy_flux.py index 5c620f3d6..737af390a 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -64,22 +64,25 @@ def encode_tokens( l_tokens, t5_tokens = tokens[:2] t5_attn_mask = tokens[2] if len(tokens) > 2 else None + # clip_l is None when using T5 only if clip_l is not None and l_tokens is not None: l_pooled = clip_l(l_tokens.to(clip_l.device))["pooler_output"] else: l_pooled = None + # t5xxl is None when using CLIP only if t5xxl is not None and t5_tokens is not None: # t5_out is [b, max length, 4096] - t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), return_dict=False, output_hidden_states=True) - if apply_t5_attn_mask: - t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1) + attention_mask = None if not apply_t5_attn_mask else t5_attn_mask.to(t5xxl.device) + t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), attention_mask, return_dict=False, output_hidden_states=True) + # if zero_pad_t5_output: + # t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1) txt_ids = torch.zeros(t5_out.shape[0], t5_out.shape[1], 3, device=t5_out.device) else: t5_out = None txt_ids = None - return [l_pooled, t5_out, txt_ids] + return [l_pooled, t5_out, txt_ids, t5_attn_mask] # returns t5_attn_mask for attention mask in transformer class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): @@ -115,6 +118,8 @@ def is_disk_cached_outputs_expected(self, npz_path: str): return False if "txt_ids" not in npz: return False + if "t5_attn_mask" not in npz: + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e @@ -129,12 +134,12 @@ def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: l_pooled = data["l_pooled"] t5_out = data["t5_out"] txt_ids = data["txt_ids"] + t5_attn_mask = data["t5_attn_mask"] if self.apply_t5_attn_mask: - t5_attn_mask = data["t5_attn_mask"] t5_out = self.mask_t5_attn(t5_out, t5_attn_mask) - return [l_pooled, t5_out, txt_ids] + return [l_pooled, t5_out, txt_ids, t5_attn_mask] def cache_batch_outputs( self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List @@ -145,7 +150,7 @@ def cache_batch_outputs( tokens_and_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): # attn_mask is not applied when caching to disk: it is applied when loading from disk - l_pooled, t5_out, txt_ids = flux_text_encoding_strategy.encode_tokens( + l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens( tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk ) @@ -159,15 +164,15 @@ def cache_batch_outputs( l_pooled = l_pooled.cpu().numpy() t5_out = t5_out.cpu().numpy() txt_ids = txt_ids.cpu().numpy() + t5_attn_mask = tokens_and_masks[2].cpu().numpy() for i, info in enumerate(infos): l_pooled_i = l_pooled[i] t5_out_i = t5_out[i] txt_ids_i = txt_ids[i] + t5_attn_mask_i = t5_attn_mask[i] if self.cache_to_disk: - t5_attn_mask = tokens_and_masks[2] - t5_attn_mask_i = t5_attn_mask[i].cpu().numpy() np.savez( info.text_encoder_outputs_npz, l_pooled=l_pooled_i, @@ -176,7 +181,7 @@ def cache_batch_outputs( t5_attn_mask=t5_attn_mask_i, ) else: - info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i) + info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i) class FluxLatentsCachingStrategy(LatentsCachingStrategy): From e17c42cb0de8a1303a607ecc75af092dc12dc272 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 21 Aug 2024 12:28:45 +0900 Subject: [PATCH 095/348] Add BFL/Diffusers LoRA converter #1467 #1458 #1483 --- networks/convert_flux_lora.py | 403 ++++++++++++++++++++++++++++++++++ 1 file changed, 403 insertions(+) create mode 100644 networks/convert_flux_lora.py diff --git a/networks/convert_flux_lora.py b/networks/convert_flux_lora.py new file mode 100644 index 000000000..dd962ebfe --- /dev/null +++ b/networks/convert_flux_lora.py @@ -0,0 +1,403 @@ +# convert key mapping and data format from some LoRA format to another +""" +Original LoRA format: Based on Black Forest Labs, QKV and MLP are unified into one module +alpha is scalar for each LoRA module + +0 to 18 +lora_unet_double_blocks_0_img_attn_proj.alpha torch.Size([]) +lora_unet_double_blocks_0_img_attn_proj.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_img_attn_proj.lora_up.weight torch.Size([3072, 4]) +lora_unet_double_blocks_0_img_attn_qkv.alpha torch.Size([]) +lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_img_attn_qkv.lora_up.weight torch.Size([9216, 4]) +lora_unet_double_blocks_0_img_mlp_0.alpha torch.Size([]) +lora_unet_double_blocks_0_img_mlp_0.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_img_mlp_0.lora_up.weight torch.Size([12288, 4]) +lora_unet_double_blocks_0_img_mlp_2.alpha torch.Size([]) +lora_unet_double_blocks_0_img_mlp_2.lora_down.weight torch.Size([4, 12288]) +lora_unet_double_blocks_0_img_mlp_2.lora_up.weight torch.Size([3072, 4]) +lora_unet_double_blocks_0_img_mod_lin.alpha torch.Size([]) +lora_unet_double_blocks_0_img_mod_lin.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_img_mod_lin.lora_up.weight torch.Size([18432, 4]) +lora_unet_double_blocks_0_txt_attn_proj.alpha torch.Size([]) +lora_unet_double_blocks_0_txt_attn_proj.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_txt_attn_proj.lora_up.weight torch.Size([3072, 4]) +lora_unet_double_blocks_0_txt_attn_qkv.alpha torch.Size([]) +lora_unet_double_blocks_0_txt_attn_qkv.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_txt_attn_qkv.lora_up.weight torch.Size([9216, 4]) +lora_unet_double_blocks_0_txt_mlp_0.alpha torch.Size([]) +lora_unet_double_blocks_0_txt_mlp_0.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_txt_mlp_0.lora_up.weight torch.Size([12288, 4]) +lora_unet_double_blocks_0_txt_mlp_2.alpha torch.Size([]) +lora_unet_double_blocks_0_txt_mlp_2.lora_down.weight torch.Size([4, 12288]) +lora_unet_double_blocks_0_txt_mlp_2.lora_up.weight torch.Size([3072, 4]) +lora_unet_double_blocks_0_txt_mod_lin.alpha torch.Size([]) +lora_unet_double_blocks_0_txt_mod_lin.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_txt_mod_lin.lora_up.weight torch.Size([18432, 4]) + +0 to 37 +lora_unet_single_blocks_0_linear1.alpha torch.Size([]) +lora_unet_single_blocks_0_linear1.lora_down.weight torch.Size([4, 3072]) +lora_unet_single_blocks_0_linear1.lora_up.weight torch.Size([21504, 4]) +lora_unet_single_blocks_0_linear2.alpha torch.Size([]) +lora_unet_single_blocks_0_linear2.lora_down.weight torch.Size([4, 15360]) +lora_unet_single_blocks_0_linear2.lora_up.weight torch.Size([3072, 4]) +lora_unet_single_blocks_0_modulation_lin.alpha torch.Size([]) +lora_unet_single_blocks_0_modulation_lin.lora_down.weight torch.Size([4, 3072]) +lora_unet_single_blocks_0_modulation_lin.lora_up.weight torch.Size([9216, 4]) +""" +""" +ai-toolkit: Based on Diffusers, QKV and MLP are separated into 3 modules. +A is down, B is up. No alpha for each LoRA module. + +0 to 18 +transformer.transformer_blocks.0.attn.add_k_proj.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.add_k_proj.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.add_v_proj.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.add_v_proj.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.to_add_out.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.to_add_out.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.to_k.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.to_k.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.to_out.0.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.to_out.0.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.to_q.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.to_q.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.to_v.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.to_v.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.ff.net.0.proj.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.ff.net.0.proj.lora_B.weight torch.Size([12288, 16]) +transformer.transformer_blocks.0.ff.net.2.lora_A.weight torch.Size([16, 12288]) +transformer.transformer_blocks.0.ff.net.2.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.ff_context.net.0.proj.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.ff_context.net.0.proj.lora_B.weight torch.Size([12288, 16]) +transformer.transformer_blocks.0.ff_context.net.2.lora_A.weight torch.Size([16, 12288]) +transformer.transformer_blocks.0.ff_context.net.2.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.norm1.linear.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.norm1.linear.lora_B.weight torch.Size([18432, 16]) +transformer.transformer_blocks.0.norm1_context.linear.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.norm1_context.linear.lora_B.weight torch.Size([18432, 16]) + +0 to 37 +transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight torch.Size([16, 3072]) +transformer.single_transformer_blocks.0.attn.to_k.lora_B.weight torch.Size([3072, 16]) +transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight torch.Size([16, 3072]) +transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight torch.Size([3072, 16]) +transformer.single_transformer_blocks.0.attn.to_v.lora_A.weight torch.Size([16, 3072]) +transformer.single_transformer_blocks.0.attn.to_v.lora_B.weight torch.Size([3072, 16]) +transformer.single_transformer_blocks.0.norm.linear.lora_A.weight torch.Size([16, 3072]) +transformer.single_transformer_blocks.0.norm.linear.lora_B.weight torch.Size([9216, 16]) +transformer.single_transformer_blocks.0.proj_mlp.lora_A.weight torch.Size([16, 3072]) +transformer.single_transformer_blocks.0.proj_mlp.lora_B.weight torch.Size([12288, 16]) +transformer.single_transformer_blocks.0.proj_out.lora_A.weight torch.Size([16, 15360]) +transformer.single_transformer_blocks.0.proj_out.lora_B.weight torch.Size([3072, 16]) +""" +""" +xlabs: Unknown format. +0 to 18 +double_blocks.0.processor.proj_lora1.down.weight torch.Size([16, 3072]) +double_blocks.0.processor.proj_lora1.up.weight torch.Size([3072, 16]) +double_blocks.0.processor.proj_lora2.down.weight torch.Size([16, 3072]) +double_blocks.0.processor.proj_lora2.up.weight torch.Size([3072, 16]) +double_blocks.0.processor.qkv_lora1.down.weight torch.Size([16, 3072]) +double_blocks.0.processor.qkv_lora1.up.weight torch.Size([9216, 16]) +double_blocks.0.processor.qkv_lora2.down.weight torch.Size([16, 3072]) +double_blocks.0.processor.qkv_lora2.up.weight torch.Size([9216, 16]) +""" + + +import argparse +from safetensors.torch import save_file +from safetensors import safe_open +import torch + + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def convert_to_sd_scripts(sds_sd, ait_sd, sds_key, ait_key): + ait_down_key = ait_key + ".lora_A.weight" + if ait_down_key not in ait_sd: + return + ait_up_key = ait_key + ".lora_B.weight" + + down_weight = ait_sd.pop(ait_down_key) + sds_sd[sds_key + ".lora_down.weight"] = down_weight + sds_sd[sds_key + ".lora_up.weight"] = ait_sd.pop(ait_up_key) + rank = down_weight.shape[0] + sds_sd[sds_key + ".alpha"] = torch.scalar_tensor(rank, dtype=down_weight.dtype, device=down_weight.device) + + +def convert_to_sd_scripts_cat(sds_sd, ait_sd, sds_key, ait_keys): + ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] + if ait_down_keys[0] not in ait_sd: + return + ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] + + down_weights = [ait_sd.pop(k) for k in ait_down_keys] + up_weights = [ait_sd.pop(k) for k in ait_up_keys] + + # lora_down is concatenated along dim=0, so rank is multiplied by the number of splits + rank = down_weights[0].shape[0] + num_splits = len(ait_keys) + sds_sd[sds_key + ".lora_down.weight"] = torch.cat(down_weights, dim=0) + + merged_up_weights = torch.zeros( + (sum(w.shape[0] for w in up_weights), rank * num_splits), + dtype=up_weights[0].dtype, + device=up_weights[0].device, + ) + + i = 0 + for j, up_weight in enumerate(up_weights): + merged_up_weights[i : i + up_weight.shape[0], j * rank : (j + 1) * rank] = up_weight + i += up_weight.shape[0] + + sds_sd[sds_key + ".lora_up.weight"] = merged_up_weights + + # set alpha to new_rank + new_rank = rank * num_splits + sds_sd[sds_key + ".alpha"] = torch.scalar_tensor(new_rank, dtype=down_weights[0].dtype, device=down_weights[0].device) + + +def convert_ai_toolkit_to_sd_scripts(ait_sd): + sds_sd = {} + for i in range(19): + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_out.0" + ) + convert_to_sd_scripts_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.to_q", + f"transformer.transformer_blocks.{i}.attn.to_k", + f"transformer.transformer_blocks.{i}.attn.to_v", + ], + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_0", f"transformer.transformer_blocks.{i}.ff.net.0.proj" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_2", f"transformer.transformer_blocks.{i}.ff.net.2" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mod_lin", f"transformer.transformer_blocks.{i}.norm1.linear" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_add_out" + ) + convert_to_sd_scripts_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.add_q_proj", + f"transformer.transformer_blocks.{i}.attn.add_k_proj", + f"transformer.transformer_blocks.{i}.attn.add_v_proj", + ], + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_0", f"transformer.transformer_blocks.{i}.ff_context.net.0.proj" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_2", f"transformer.transformer_blocks.{i}.ff_context.net.2" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mod_lin", f"transformer.transformer_blocks.{i}.norm1_context.linear" + ) + + for i in range(38): + convert_to_sd_scripts_cat( + sds_sd, + ait_sd, + f"lora_unet_single_blocks_{i}_linear1", + [ + f"transformer.single_transformer_blocks.{i}.attn.to_q", + f"transformer.single_transformer_blocks.{i}.attn.to_k", + f"transformer.single_transformer_blocks.{i}.attn.to_v", + f"transformer.single_transformer_blocks.{i}.proj_mlp", + ], + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_linear2", f"transformer.single_transformer_blocks.{i}.proj_out" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_modulation_lin", f"transformer.single_transformer_blocks.{i}.norm.linear" + ) + + if len(ait_sd) > 0: + logger.warning(f"Unsuppored keys for sd-scripts: {ait_sd.keys()}") + return sds_sd + + +def convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key): + if sds_key + ".lora_down.weight" not in sds_sd: + return + down_weight = sds_sd.pop(sds_key + ".lora_down.weight") + + # scale weight by alpha and dim + rank = down_weight.shape[0] + alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar + scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here + print(f"rank: {rank}, alpha: {alpha}, scale: {scale}") + + # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2 + scale_down = scale + scale_up = 1.0 + while scale_down * 2 < scale_up: + scale_down *= 2 + scale_up /= 2 + # print(f"scale: {scale}, scale_down: {scale_down}, scale_up: {scale_up}") + + ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down + ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up + + +def convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None): + if sds_key + ".lora_down.weight" not in sds_sd: + return + down_weight = sds_sd.pop(sds_key + ".lora_down.weight") + + # scale weight by alpha and dim + rank = down_weight.shape[0] + alpha = sds_sd.pop(sds_key + ".alpha") + scale = alpha / rank + + # calculate scale_down and scale_up + scale_down = scale + scale_up = 1.0 + while scale_down * 2 < scale_up: + scale_down *= 2 + scale_up /= 2 + + ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] + ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] + + num_splits = len(ait_keys) + up_weight = sds_sd.pop(sds_key + ".lora_up.weight") + + # down_weight is copied to each split + ait_sd.update({k: down_weight * scale_down for k in ait_down_keys}) + + # calculate dims if not provided + if dims is None: + dims = [up_weight.shape[0] // num_splits] * num_splits + else: + assert sum(dims) == up_weight.shape[0] + + # up_weight is split to each split + ait_sd.update({k: v * scale_up for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) + + +def convert_sd_scripts_to_ai_toolkit(sds_sd): + ait_sd = {} + for i in range(19): + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_out.0" + ) + convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.to_q", + f"transformer.transformer_blocks.{i}.attn.to_k", + f"transformer.transformer_blocks.{i}.attn.to_v", + ], + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_0", f"transformer.transformer_blocks.{i}.ff.net.0.proj" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_2", f"transformer.transformer_blocks.{i}.ff.net.2" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mod_lin", f"transformer.transformer_blocks.{i}.norm1.linear" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_add_out" + ) + convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.add_q_proj", + f"transformer.transformer_blocks.{i}.attn.add_k_proj", + f"transformer.transformer_blocks.{i}.attn.add_v_proj", + ], + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_0", f"transformer.transformer_blocks.{i}.ff_context.net.0.proj" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_2", f"transformer.transformer_blocks.{i}.ff_context.net.2" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mod_lin", f"transformer.transformer_blocks.{i}.norm1_context.linear" + ) + + for i in range(38): + convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_single_blocks_{i}_linear1", + [ + f"transformer.single_transformer_blocks.{i}.attn.to_q", + f"transformer.single_transformer_blocks.{i}.attn.to_k", + f"transformer.single_transformer_blocks.{i}.attn.to_v", + f"transformer.single_transformer_blocks.{i}.proj_mlp", + ], + dims=[3072, 3072, 3072, 12288], + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_linear2", f"transformer.single_transformer_blocks.{i}.proj_out" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_modulation_lin", f"transformer.single_transformer_blocks.{i}.norm.linear" + ) + + if len(sds_sd) > 0: + logger.warning(f"Unsuppored keys for ai-toolkit: {sds_sd.keys()}") + return ait_sd + + +def main(args): + # load source safetensors + logger.info(f"Loading source file {args.src_path}") + state_dict = {} + with safe_open(args.src_path, framework="pt") as f: + metadata = f.metadata() + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + + logger.info(f"Converting {args.src} to {args.dst} format") + if args.src == "ai-toolkit" and args.dst == "sd-scripts": + state_dict = convert_ai_toolkit_to_sd_scripts(state_dict) + elif args.src == "sd-scripts" and args.dst == "ai-toolkit": + state_dict = convert_sd_scripts_to_ai_toolkit(state_dict) + else: + raise NotImplementedError(f"Conversion from {args.src} to {args.dst} is not supported") + + # save destination safetensors + logger.info(f"Saving destination file {args.dst_path}") + save_file(state_dict, args.dst_path, metadata=metadata) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert LoRA format") + parser.add_argument("--src", type=str, default="ai-toolkit", help="source format, ai-toolkit or sd-scripts") + parser.add_argument("--dst", type=str, default="sd-scripts", help="destination format, ai-toolkit or sd-scripts") + parser.add_argument("--src_path", type=str, default=None, help="source path") + parser.add_argument("--dst_path", type=str, default=None, help="destination path") + args = parser.parse_args() + main(args) From 2b07a92c8d970a8538a47dd1bcad3122da4e195a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 21 Aug 2024 12:30:23 +0900 Subject: [PATCH 096/348] Fix error in applying mask in Attention and add LoRA converter script --- README.md | 6 ++++++ library/flux_models.py | 5 +++-- networks/convert_flux_lora.py | 2 +- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 43edbbed6..f4056851f 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,12 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 21, 2024 (update 2): +Fixed an error in applying mask in Attention. The attention mask was float, but it should be bool. + +Added a script `convert_flux_lora.py` to convert LoRA between sd-scripts format (BFL-based) and AI-toolkit format (Diffusers-based). See `--help` for details. BFL-based LoRA has a large module, so converting it to Diffusers format may reduce temporary memory usage in the inference environment. Note that re-conversion will increase the size of LoRA. + + Aug 21, 2024: The specification of `--apply_t5_attn_mask` has been changed. Previously, the T5 output was zero-padded, but now, two steps are taken: "1. Apply mask when encoding T5" and "2. Apply mask in the attention of Double Block". Fine tuning, LoRA training, and inference in `flux_mini_inference.py` have been changed. diff --git a/library/flux_models.py b/library/flux_models.py index 6f28da603..e38119cd7 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -708,9 +708,10 @@ def _forward( # make attention mask if not None attn_mask = None if txt_attention_mask is not None: - attn_mask = txt_attention_mask # b, seq_len + # F.scaled_dot_product_attention expects attn_mask to be bool for binary mask + attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len attn_mask = torch.cat( - (attn_mask, torch.ones(attn_mask.shape[0], img.shape[1]).to(attn_mask.device)), dim=1 + (attn_mask, torch.ones(attn_mask.shape[0], img.shape[1], device=attn_mask.device, dtype=torch.bool)), dim=1 ) # b, seq_len + img_len # broadcast attn_mask to all heads diff --git a/networks/convert_flux_lora.py b/networks/convert_flux_lora.py index dd962ebfe..e9743534d 100644 --- a/networks/convert_flux_lora.py +++ b/networks/convert_flux_lora.py @@ -248,7 +248,7 @@ def convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key): rank = down_weight.shape[0] alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here - print(f"rank: {rank}, alpha: {alpha}, scale: {scale}") + # print(f"rank: {rank}, alpha: {alpha}, scale: {scale}") # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2 scale_down = scale From e1cd19c0c0ef55709e8eb1e5babe25045f65031f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 21 Aug 2024 21:04:10 +0900 Subject: [PATCH 097/348] add stochastic rounding, fix single block --- README.md | 19 ++++++-- flux_train.py | 95 ++++++++++++++++++++++++++++++++++---- library/adafactor_fused.py | 36 ++++++++++++++- library/flux_models.py | 1 + 4 files changed, 136 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index f4056851f..45349ba38 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,15 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 21, 2024 (update 3): +- There is a bug that `--full_bf16` option is enabled even if it is not specified in `flux_train.py`. The bug will be fixed sooner. __Please specify the `--full_bf16` option explicitly, especially when training with 24GB VRAM.__ +- Stochastic rounding is now implemented when `--fused_backward_pass` is specified. The implementation is +based on the code provided by 2kpr. Thank you so much! + - With this change, `--fused_backward_pass` is recommended over `--blockwise_fused_optimizers` when `--full_bf16` is specified. + - Please note that `--fused_backward_pass` is only supported with Adafactor. +- The sample command in [FLUX.1 fine-tuning](#flux1-fine-tuning) is updated to reflect these changes. +- Fixed `--single_blocks_to_swap` is not working in `flux_train.py`. + Aug 21, 2024 (update 2): Fixed an error in applying mask in Attention. The attention mask was float, but it should be bool. @@ -142,7 +151,7 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t --learning_rate 5e-5 --max_train_epochs 4 --sdpa --highvram --cache_text_encoder_outputs_to_disk --cache_latents_to_disk --save_every_n_epochs 1 --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 ---blockwise_fused_optimizers --double_blocks_to_swap 6 --cpu_offload_checkpointing +--fused_backward_pass --double_blocks_to_swap 6 --cpu_offload_checkpointing --full_bf16 ``` (Combine the command into one line.) @@ -151,9 +160,13 @@ Sample image generation during training is not tested yet. Options are almost the same as LoRA training. The difference is `--blockwise_fused_optimizers`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available. -`--blockwise_fused_optimizers` enables the fusing of the optimizer for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--fused_optimizer_groups` is deprecated due to the addition of this option for FLUX.1 training. +`--full_bf16` enables the training with bf16 (weights and gradients). + +`--fused_backward_pass` enables the fusing of the optimizer step into the backward pass for each parameter. This reduces the memory usage during training. Only Adafactor optimizer is supported for now. Stochastic rounding is also enabled when `--fused_backward_pass` and `--full_bf16` are specified. + +`--blockwise_fused_optimizers` enables the fusing of the optimizer step into the backward pass for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--blockwise_fused_optimizers` cannot be used with `--fused_backward_pass`. Stochastic rounding is not supported for now. -`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--blockwise_fused_optimizers`. +`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. `--double_blocks_to_swap` can be specified with `--single_blocks_to_swap`. The recommended maximum number of blocks to swap is 9 for double blocks and 18 for single blocks. `--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage. diff --git a/flux_train.py b/flux_train.py index ecb8a1086..bcf4b9564 100644 --- a/flux_train.py +++ b/flux_train.py @@ -277,7 +277,10 @@ def train(args): training_models = [] params_to_optimize = [] training_models.append(flux) - params_to_optimize.append({"params": list(flux.parameters()), "lr": args.learning_rate}) + name_and_params = list(flux.named_parameters()) + # single param group for now + params_to_optimize.append({"params": [p for _, p in name_and_params], "lr": args.learning_rate}) + param_names = [[n for n, _ in name_and_params]] # calculate number of trainable parameters n_params = 0 @@ -433,17 +436,89 @@ def train(args): import library.adafactor_fused library.adafactor_fused.patch_adafactor_fused(optimizer) - for param_group in optimizer.param_groups: - for parameter in param_group["params"]: - if parameter.requires_grad: - def __grad_hook(tensor: torch.Tensor, param_group=param_group): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None + double_blocks_to_swap = args.double_blocks_to_swap + single_blocks_to_swap = args.single_blocks_to_swap + num_double_blocks = len(flux.double_blocks) + num_single_blocks = len(flux.single_blocks) + handled_double_block_indices = set() + handled_single_block_indices = set() - parameter.register_post_accumulate_grad_hook(__grad_hook) + for param_group, param_name_group in zip(optimizer.param_groups, param_names): + for parameter, param_name in zip(param_group["params"], param_name_group): + if parameter.requires_grad: + grad_hook = None + + if double_blocks_to_swap: + if param_name.startswith("double_blocks"): + block_idx = int(param_name.split(".")[1]) + if ( + block_idx not in handled_double_block_indices + and block_idx >= (num_double_blocks - double_blocks_to_swap) - 1 + and block_idx < num_double_blocks - 1 + ): + # swap next (already backpropagated) block + handled_double_block_indices.add(block_idx) + block_idx_cpu = block_idx + 1 + block_idx_cuda = double_blocks_to_swap - (num_double_blocks - block_idx_cpu) + + # create swap hook + def create_double_swap_grad_hook(bidx, bidx_cuda): + def __grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + # swap blocks if necessary + flux.double_blocks[bidx].to("cpu") + flux.double_blocks[bidx_cuda].to(accelerator.device) + # print(f"Move double block {bidx} to cpu and {bidx_cuda} to device") + + return __grad_hook + + grad_hook = create_double_swap_grad_hook(block_idx_cpu, block_idx_cuda) + if single_blocks_to_swap: + if param_name.startswith("single_blocks"): + block_idx = int(param_name.split(".")[1]) + if ( + block_idx not in handled_single_block_indices + and block_idx >= (num_single_blocks - single_blocks_to_swap) - 1 + and block_idx < num_single_blocks - 1 + ): + handled_single_block_indices.add(block_idx) + block_idx_cpu = block_idx + 1 + block_idx_cuda = single_blocks_to_swap - (num_single_blocks - block_idx_cpu) + # print(param_name, block_idx_cpu, block_idx_cuda) + + # create swap hook + def create_single_swap_grad_hook(bidx, bidx_cuda): + def __grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + # swap blocks if necessary + flux.single_blocks[bidx].to("cpu") + flux.single_blocks[bidx_cuda].to(accelerator.device) + # print(f"Move single block {bidx} to cpu and {bidx_cuda} to device") + + return __grad_hook + + grad_hook = create_single_swap_grad_hook(block_idx_cpu, block_idx_cuda) + + if grad_hook is None: + + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + grad_hook = __grad_hook + + parameter.register_post_accumulate_grad_hook(grad_hook) elif args.blockwise_fused_optimizers: # prepare for additional optimizers and lr schedulers diff --git a/library/adafactor_fused.py b/library/adafactor_fused.py index bdfc32ced..b5afa236b 100644 --- a/library/adafactor_fused.py +++ b/library/adafactor_fused.py @@ -2,6 +2,32 @@ import torch from transformers import Adafactor +# stochastic rounding for bfloat16 +# The implementation was provided by 2kpr. Thank you very much! + +def copy_stochastic_(target: torch.Tensor, source: torch.Tensor): + """ + copies source into target using stochastic rounding + + Args: + target: the target tensor with dtype=bfloat16 + source: the target tensor with dtype=float32 + """ + # create a random 16 bit integer + result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16)) + + # add the random number to the lower 16 bit of the mantissa + result.add_(source.view(dtype=torch.int32)) + + # mask off the lower 16 bit of the mantissa + result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32 + + # copy the higher 16 bit into the target tensor + target.copy_(result.view(dtype=torch.float32)) + + del result + + @torch.no_grad() def adafactor_step_param(self, p, group): if p.grad is None: @@ -48,7 +74,7 @@ def adafactor_step_param(self, p, group): lr = Adafactor._get_lr(group, state) beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) - update = (grad ** 2) + group["eps"][0] + update = (grad**2) + group["eps"][0] if factored: exp_avg_sq_row = state["exp_avg_sq_row"] exp_avg_sq_col = state["exp_avg_sq_col"] @@ -78,7 +104,12 @@ def adafactor_step_param(self, p, group): p_data_fp32.add_(-update) - if p.dtype in {torch.float16, torch.bfloat16}: + # if p.dtype in {torch.float16, torch.bfloat16}: + # p.copy_(p_data_fp32) + + if p.dtype == torch.bfloat16: + copy_stochastic_(p, p_data_fp32) + elif p.dtype == torch.float16: p.copy_(p_data_fp32) @@ -101,6 +132,7 @@ def adafactor_step(self, closure=None): return loss + def patch_adafactor_fused(optimizer: Adafactor): optimizer.step_param = adafactor_step_param.__get__(optimizer) optimizer.step = adafactor_step.__get__(optimizer) diff --git a/library/flux_models.py b/library/flux_models.py index e38119cd7..c98d52ec0 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1078,6 +1078,7 @@ def forward( if moving: self.single_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True) # print(f"Moved single block {to_cpu_block_index} to cpu.") + to_cpu_block_index += 1 img = img[:, txt.shape[1] :, ...] From 98c91a762513bbce9ebce137da720a448a3da6c9 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Thu, 22 Aug 2024 12:37:41 +0900 Subject: [PATCH 098/348] Fix bug in FLUX multi GPU training --- README.md | 6 +++ flux_train.py | 29 ++++++------- flux_train_network.py | 10 +++-- library/flux_models.py | 6 ++- library/flux_utils.py | 40 ++++++++++++++---- library/strategy_flux.py | 4 +- library/train_util.py | 10 ++--- library/utils.py | 89 ++++++++++++++++++++++++++++++++++++++++ 8 files changed, 156 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index 45349ba38..5125c6631 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,12 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 22, 2024: +Fixed a bug in multi-GPU training. It should work with fine-tuning and LoRA training. `--double_blocks_to_swap` and `--single_blocks_to_swap` cannot be used in multi-GPU training. + +`--disable_mmap_load_safetensors` option now works in `flux_train.py`. It speeds up model loading during training in WSL2. It is also effective in reducing memory usage when loading models during multi-GPU training. Please always check if the model is loaded correctly, as it uses a custom implementation of safetensors loading. + + Aug 21, 2024 (update 3): - There is a bug that `--full_bf16` option is enabled even if it is not specified in `flux_train.py`. The bug will be fixed sooner. __Please specify the `--full_bf16` option explicitly, especially when training with 24GB VRAM.__ - Stochastic rounding is now implemented when `--fused_backward_pass` is specified. The implementation is diff --git a/flux_train.py b/flux_train.py index bcf4b9564..e7d45e04d 100644 --- a/flux_train.py +++ b/flux_train.py @@ -174,7 +174,7 @@ def train(args): # load VAE for caching latents ae = None if cache_latents: - ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu") + ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) ae.to(accelerator.device, dtype=weight_dtype) ae.requires_grad_(False) ae.eval() @@ -199,8 +199,8 @@ def train(args): strategy_base.TokenizeStrategy.set_strategy(flux_tokenize_strategy) # load clip_l, t5xxl for caching text encoder outputs - clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu") - t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu") + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors) clip_l.eval() t5xxl.eval() clip_l.requires_grad_(False) @@ -228,7 +228,6 @@ def train(args): if args.sample_prompts is not None: logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") - tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() prompts = load_prompts(args.sample_prompts) @@ -238,9 +237,9 @@ def train(args): for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: if p not in sample_prompts_te_outputs: logger.info(f"cache Text Encoder outputs for prompt: {p}") - tokens_and_masks = tokenize_strategy.tokenize(p) + tokens_and_masks = flux_tokenize_strategy.tokenize(p) sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( - tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + flux_tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask ) accelerator.wait_for_everyone() @@ -251,7 +250,9 @@ def train(args): clean_memory_on_device(accelerator.device) # load FLUX - flux = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu") + flux = flux_utils.load_flow_model( + name, args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors + ) if args.gradient_checkpointing: flux.enable_gradient_checkpointing(args.cpu_offload_checkpointing) @@ -419,7 +420,7 @@ def train(args): # if we doesn't swap blocks, we can move the model to device flux = accelerator.prepare(flux, device_placement=[not is_swapping_blocks]) if is_swapping_blocks: - flux.move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする @@ -439,8 +440,8 @@ def train(args): double_blocks_to_swap = args.double_blocks_to_swap single_blocks_to_swap = args.single_blocks_to_swap - num_double_blocks = len(flux.double_blocks) - num_single_blocks = len(flux.single_blocks) + num_double_blocks = 19 # len(flux.double_blocks) + num_single_blocks = 38 # len(flux.single_blocks) handled_double_block_indices = set() handled_single_block_indices = set() @@ -537,8 +538,8 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): double_blocks_to_swap = args.double_blocks_to_swap single_blocks_to_swap = args.single_blocks_to_swap - num_double_blocks = len(flux.double_blocks) - num_single_blocks = len(flux.single_blocks) + num_double_blocks = 19 # len(flux.double_blocks) + num_single_blocks = 38 # len(flux.single_blocks) for opt_idx, optimizer in enumerate(optimizers): for param_group in optimizer.param_groups: @@ -618,7 +619,7 @@ def optimizer_hook(parameter: torch.Tensor): ) if is_swapping_blocks: - flux.prepare_block_swap_before_forward() + accelerator.unwrap_model(flux).prepare_block_swap_before_forward() # For --sample_at_first flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) @@ -660,7 +661,7 @@ def optimizer_hook(parameter: torch.Tensor): with torch.no_grad(): input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] text_encoder_conds = text_encoding_strategy.encode_tokens( - tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask + flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask ) if args.full_fp16: text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] diff --git a/flux_train_network.py b/flux_train_network.py index 49bd270c7..3e2057e91 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -57,19 +57,21 @@ def load_target_model(self, args, weight_dtype, accelerator): name = self.get_flux_model_name(args) # if we load to cpu, flux.to(fp8) takes a long time - model = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu") + model = flux_utils.load_flow_model( + name, args.pretrained_model_name_or_path, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors + ) if args.split_mode: model = self.prepare_split_model(model, weight_dtype, accelerator) - clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu") + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) clip_l.eval() # loading t5xxl to cpu takes a long time, so we should load to gpu in future - t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu") + t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) t5xxl.eval() - ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu") + ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model diff --git a/library/flux_models.py b/library/flux_models.py index c98d52ec0..c045aef6b 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -745,7 +745,9 @@ def custom_forward(*inputs): return custom_forward - return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask) + return torch.utils.checkpoint.checkpoint( + create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask, use_reentrant=False + ) else: return self._forward(img, txt, vec, pe, txt_attention_mask) @@ -836,7 +838,7 @@ def custom_forward(*inputs): return custom_forward - return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe) + return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=False) else: return self._forward(x, vec, pe) diff --git a/library/flux_utils.py b/library/flux_utils.py index 166cd833b..37166933a 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -9,7 +9,7 @@ from library import flux_models -from library.utils import setup_logging +from library.utils import setup_logging, MemoryEfficientSafeOpen setup_logging() import logging @@ -19,32 +19,54 @@ MODEL_VERSION_FLUX_V1 = "flux1" -def load_flow_model(name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> flux_models.Flux: +# temporary copy from sd3_utils TODO refactor +def load_safetensors(path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: torch.dtype = torch.float32): + if disable_mmap: + # return safetensors.torch.load(open(path, "rb").read()) + # use experimental loader + logger.info(f"Loading without mmap (experimental)") + state_dict = {} + with MemoryEfficientSafeOpen(path) as f: + for key in f.keys(): + state_dict[key] = f.get_tensor(key).to(device, dtype=dtype) + return state_dict + else: + try: + return load_file(path, device=device) + except: + return load_file(path) # prevent device invalid Error + + +def load_flow_model( + name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False +) -> flux_models.Flux: logger.info(f"Building Flux model {name}") with torch.device("meta"): model = flux_models.Flux(flux_models.configs[name].params).to(dtype) # load_sft doesn't support torch.device logger.info(f"Loading state dict from {ckpt_path}") - sd = load_file(ckpt_path, device=str(device)) + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) info = model.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded Flux: {info}") return model -def load_ae(name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> flux_models.AutoEncoder: +def load_ae( + name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False +) -> flux_models.AutoEncoder: logger.info("Building AutoEncoder") with torch.device("meta"): ae = flux_models.AutoEncoder(flux_models.configs[name].ae_params).to(dtype) logger.info(f"Loading state dict from {ckpt_path}") - sd = load_file(ckpt_path, device=str(device)) + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) info = ae.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded AE: {info}") return ae -def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> CLIPTextModel: +def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> CLIPTextModel: logger.info("Building CLIP") CLIPL_CONFIG = { "_name_or_path": "clip-vit-large-patch14/", @@ -139,13 +161,13 @@ def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.dev clip = CLIPTextModel._from_config(config) logger.info(f"Loading state dict from {ckpt_path}") - sd = load_file(ckpt_path, device=str(device)) + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) info = clip.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded CLIP: {info}") return clip -def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> T5EncoderModel: +def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> T5EncoderModel: T5_CONFIG_JSON = """ { "architectures": [ @@ -185,7 +207,7 @@ def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.devi t5xxl = T5EncoderModel._from_config(config) logger.info(f"Loading state dict from {ckpt_path}") - sd = load_file(ckpt_path, device=str(device)) + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) info = t5xxl.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded T5xxl: {info}") return t5xxl diff --git a/library/strategy_flux.py b/library/strategy_flux.py index 737af390a..b3643cbfc 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -137,7 +137,7 @@ def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: t5_attn_mask = data["t5_attn_mask"] if self.apply_t5_attn_mask: - t5_out = self.mask_t5_attn(t5_out, t5_attn_mask) + t5_out = self.mask_t5_attn(t5_out, t5_attn_mask) # FIXME do not mask here!!! return [l_pooled, t5_out, txt_ids, t5_attn_mask] @@ -149,7 +149,7 @@ def cache_batch_outputs( tokens_and_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): - # attn_mask is not applied when caching to disk: it is applied when loading from disk + # attn_mask is not applied when caching to disk: it is applied when loading from disk FIXME apply mask when loading l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens( tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk ) diff --git a/library/train_util.py b/library/train_util.py index 8929c192f..989758ad5 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1104,10 +1104,6 @@ def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: boo caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy() batch_size = caching_strategy.batch_size or self.batch_size - # if cache to disk, don't cache TE outputs in non-main process - if caching_strategy.cache_to_disk and not is_main_process: - return - logger.info("caching Text Encoder outputs with caching strategy.") image_infos = list(self.image_data.values()) @@ -1120,9 +1116,9 @@ def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: boo # check disk cache exists and size of latents if caching_strategy.cache_to_disk: - info.text_encoder_outputs_npz = te_out_npz + info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability/main process cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz) - if cache_available: # do not add to batch + if cache_available or not is_main_process: # do not add to batch continue batch.append(info) @@ -2638,7 +2634,7 @@ def load_arbitrary_dataset(args, tokenizer=None) -> MinimalDataset: return train_dataset_group -def load_image(image_path, alpha=False): +def load_image(image_path, alpha=False): try: with Image.open(image_path) as image: if alpha: diff --git a/library/utils.py b/library/utils.py index 7de22d5a9..a16209979 100644 --- a/library/utils.py +++ b/library/utils.py @@ -153,6 +153,95 @@ def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]: v.contiguous().view(torch.uint8).numpy().tofile(f) +class MemoryEfficientSafeOpen: + # does not support metadata loading + def __init__(self, filename): + self.filename = filename + self.header, self.header_size = self._read_header() + self.file = open(filename, "rb") + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.file.close() + + def keys(self): + return [k for k in self.header.keys() if k != "__metadata__"] + + def get_tensor(self, key): + if key not in self.header: + raise KeyError(f"Tensor '{key}' not found in the file") + + metadata = self.header[key] + offset_start, offset_end = metadata["data_offsets"] + + if offset_start == offset_end: + tensor_bytes = None + else: + # adjust offset by header size + self.file.seek(self.header_size + 8 + offset_start) + tensor_bytes = self.file.read(offset_end - offset_start) + + return self._deserialize_tensor(tensor_bytes, metadata) + + def _read_header(self): + with open(self.filename, "rb") as f: + header_size = struct.unpack(" Date: Thu, 22 Aug 2024 19:55:31 +0900 Subject: [PATCH 099/348] Fix --debug_dataset to work. --- flux_train.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/flux_train.py b/flux_train.py index e7d45e04d..410728d44 100644 --- a/flux_train.py +++ b/flux_train.py @@ -142,6 +142,12 @@ def train(args): args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False ) ) + name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" + t5xxl_max_token_length = ( + args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if name == "schnell" else 512) + ) + strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length)) + train_dataset_group.set_current_strategies() train_util.debug_dataset(train_dataset_group, True) return From 2d8fa3387a4adfdc2e36f2582e4ffc21864569f0 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 22 Aug 2024 19:56:27 +0900 Subject: [PATCH 100/348] Fix to remove zero pad for t5xxl output --- README.md | 5 +++++ library/strategy_flux.py | 23 +++++++++++------------ 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 5125c6631..33b3a9a99 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,11 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 22, 2024 (update 2): +Fixed a bug that the embedding was zero-padded when `--apply_t5_attn_mask` option was applied. Also, the cache file for text encoder outputs now records whether the mask is applied or not. Please note that the cache file will be recreated when switching the `--apply_t5_attn_mask` option. + +Added a script to extract LoRA from the difference between the two models of FLUX.1. Use `networks/flux_extract_lora.py`. See `--help` for details. Normally, more than 50GB of memory is required, but specifying the `--mem_eff_safe_open` option significantly reduces memory usage. However, this option is a custom implementation, so unexpected problems may occur. Please always check if the model is loaded correctly. + Aug 22, 2024: Fixed a bug in multi-GPU training. It should work with fine-tuning and LoRA training. `--double_blocks_to_swap` and `--single_blocks_to_swap` cannot be used in multi-GPU training. diff --git a/library/strategy_flux.py b/library/strategy_flux.py index b3643cbfc..d52b3b8dd 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -22,7 +22,7 @@ class FluxTokenizeStrategy(TokenizeStrategy): - def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None: + def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None: self.t5xxl_max_length = t5xxl_max_length self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) @@ -120,25 +120,24 @@ def is_disk_cached_outputs_expected(self, npz_path: str): return False if "t5_attn_mask" not in npz: return False + if "apply_t5_attn_mask" not in npz: + return False + npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"] + if npz_apply_t5_attn_mask != self.apply_t5_attn_mask: + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e return True - def mask_t5_attn(self, t5_out: np.ndarray, t5_attn_mask: np.ndarray) -> np.ndarray: - return t5_out * np.expand_dims(t5_attn_mask, -1) - def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: data = np.load(npz_path) l_pooled = data["l_pooled"] t5_out = data["t5_out"] txt_ids = data["txt_ids"] t5_attn_mask = data["t5_attn_mask"] - - if self.apply_t5_attn_mask: - t5_out = self.mask_t5_attn(t5_out, t5_attn_mask) # FIXME do not mask here!!! - + # apply_t5_attn_mask should be same as self.apply_t5_attn_mask return [l_pooled, t5_out, txt_ids, t5_attn_mask] def cache_batch_outputs( @@ -149,10 +148,8 @@ def cache_batch_outputs( tokens_and_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): - # attn_mask is not applied when caching to disk: it is applied when loading from disk FIXME apply mask when loading - l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens( - tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk - ) + # attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True + l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks) if l_pooled.dtype == torch.bfloat16: l_pooled = l_pooled.float() @@ -171,6 +168,7 @@ def cache_batch_outputs( t5_out_i = t5_out[i] txt_ids_i = txt_ids[i] t5_attn_mask_i = t5_attn_mask[i] + apply_t5_attn_mask_i = self.apply_t5_attn_mask if self.cache_to_disk: np.savez( @@ -179,6 +177,7 @@ def cache_batch_outputs( t5_out=t5_out_i, txt_ids=txt_ids_i, t5_attn_mask=t5_attn_mask_i, + apply_t5_attn_mask=apply_t5_attn_mask_i, ) else: info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i) From b0a980844a2e02b1b1ae4cf615ae489dbf8ece67 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 22 Aug 2024 19:57:29 +0900 Subject: [PATCH 101/348] added a script to extract LoRA --- networks/flux_extract_lora.py | 219 ++++++++++++++++++++++++++++++++++ 1 file changed, 219 insertions(+) create mode 100644 networks/flux_extract_lora.py diff --git a/networks/flux_extract_lora.py b/networks/flux_extract_lora.py new file mode 100644 index 000000000..3ee6e816d --- /dev/null +++ b/networks/flux_extract_lora.py @@ -0,0 +1,219 @@ +# extract approximating LoRA by svd from two FLUX models +# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py +# Thanks to cloneofsimo! + +import argparse +import json +import os +import time +import torch +from safetensors.torch import load_file, save_file +from safetensors import safe_open +from tqdm import tqdm +from library import flux_utils, sai_model_spec, model_util, sdxl_model_util +import lora +from library.utils import MemoryEfficientSafeOpen +from library.utils import setup_logging +from networks import lora_flux + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +# CLAMP_QUANTILE = 0.99 +# MIN_DIFF = 1e-1 + + +def save_to_file(file_name, state_dict, metadata, dtype): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + save_file(state_dict, file_name, metadata=metadata) + + +def svd( + model_org=None, + model_tuned=None, + save_to=None, + dim=4, + device=None, + save_precision=None, + clamp_quantile=0.99, + min_diff=0.01, + no_metadata=False, + mem_eff_safe_open=False, +): + def str_to_dtype(p): + if p == "float": + return torch.float + if p == "fp16": + return torch.float16 + if p == "bf16": + return torch.bfloat16 + return None + + calc_dtype = torch.float + save_dtype = str_to_dtype(save_precision) + store_device = "cpu" + + # open models + lora_weights = {} + if not mem_eff_safe_open: + # use original safetensors.safe_open + open_fn = lambda fn: safe_open(fn, framework="pt") + else: + logger.info("Using memory efficient safe_open") + open_fn = lambda fn: MemoryEfficientSafeOpen(fn) + + with open_fn(model_org) as fo: + # filter keys + keys = [] + for key in fo.keys(): + if not ("single_block" in key or "double_block" in key): + continue + if ".bias" in key: + continue + if "norm" in key: + continue + keys.append(key) + + with open_fn(model_tuned) as ft: + for key in tqdm(keys): + # get tensors and calculate difference + value_o = fo.get_tensor(key) + value_t = ft.get_tensor(key) + mat = value_t.to(calc_dtype) - value_o.to(calc_dtype) + del value_o, value_t + + # extract LoRA weights + if device: + mat = mat.to(device) + out_dim, in_dim = mat.size()[0:2] + rank = min(dim, in_dim, out_dim) # LoRA rank cannot exceed the original dim + + mat = mat.squeeze() + + U, S, Vh = torch.linalg.svd(mat) + + U = U[:, :rank] + S = S[:rank] + U = U @ torch.diag(S) + + Vh = Vh[:rank, :] + + dist = torch.cat([U.flatten(), Vh.flatten()]) + hi_val = torch.quantile(dist, clamp_quantile) + low_val = -hi_val + + U = U.clamp(low_val, hi_val) + Vh = Vh.clamp(low_val, hi_val) + + U = U.to(store_device, dtype=save_dtype).contiguous() + Vh = Vh.to(store_device, dtype=save_dtype).contiguous() + + print(f"key: {key}, U: {U.size()}, Vh: {Vh.size()}") + lora_weights[key] = (U, Vh) + del mat, U, S, Vh + + # make state dict for LoRA + lora_sd = {} + for key, (up_weight, down_weight) in lora_weights.items(): + lora_name = key.replace(".weight", "").replace(".", "_") + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + lora_name + lora_sd[lora_name + ".lora_up.weight"] = up_weight + lora_sd[lora_name + ".lora_down.weight"] = down_weight + lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0]) # same as rank + + # minimum metadata + net_kwargs = {} + metadata = { + "ss_v2": str(False), + "ss_base_model_version": flux_utils.MODEL_VERSION_FLUX_V1, + "ss_network_module": "networks.lora_flux", + "ss_network_dim": str(dim), + "ss_network_alpha": str(float(dim)), + "ss_network_args": json.dumps(net_kwargs), + } + + if not no_metadata: + title = os.path.splitext(os.path.basename(save_to))[0] + sai_metadata = sai_model_spec.build_metadata(lora_sd, False, False, False, True, False, time.time(), title, flux="dev") + metadata.update(sai_metadata) + + save_to_file(save_to, lora_sd, metadata, save_dtype) + + logger.info(f"LoRA weights saved to {save_to}") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat", + ) + parser.add_argument( + "--model_org", + type=str, + default=None, + required=True, + help="Original model: safetensors file / 元モデル、safetensors", + ) + parser.add_argument( + "--model_tuned", + type=str, + default=None, + required=True, + help="Tuned model, LoRA is difference of `original to tuned`: safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors", + ) + parser.add_argument( + "--mem_eff_safe_open", + action="store_true", + help="use memory efficient safe_open. This is an experimental feature, use only when memory is not enough." + " / メモリ効率の良いsafe_openを使用する。実装は実験的なものなので、メモリが足りない場合のみ使用してください。", + ) + parser.add_argument( + "--save_to", + type=str, + default=None, + required=True, + help="destination file name: safetensors file / 保存先のファイル名、safetensors", + ) + parser.add_argument( + "--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)" + ) + parser.add_argument( + "--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う" + ) + parser.add_argument( + "--clamp_quantile", + type=float, + default=0.99, + help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99", + ) + # parser.add_argument( + # "--min_diff", + # type=float, + # default=0.01, + # help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /" + # + "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01", + # ) + parser.add_argument( + "--no_metadata", + action="store_true", + help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " + + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + svd(**vars(args)) From bf9f798985dd75fc2dd1fbc8c8dc775c92176854 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 22 Aug 2024 19:59:38 +0900 Subject: [PATCH 102/348] chore: fix typos, remove debug print --- networks/flux_extract_lora.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/networks/flux_extract_lora.py b/networks/flux_extract_lora.py index 3ee6e816d..63ab2960c 100644 --- a/networks/flux_extract_lora.py +++ b/networks/flux_extract_lora.py @@ -68,10 +68,10 @@ def str_to_dtype(p): logger.info("Using memory efficient safe_open") open_fn = lambda fn: MemoryEfficientSafeOpen(fn) - with open_fn(model_org) as fo: + with open_fn(model_org) as f_org: # filter keys keys = [] - for key in fo.keys(): + for key in f_org.keys(): if not ("single_block" in key or "double_block" in key): continue if ".bias" in key: @@ -80,11 +80,11 @@ def str_to_dtype(p): continue keys.append(key) - with open_fn(model_tuned) as ft: + with open_fn(model_tuned) as f_tuned: for key in tqdm(keys): # get tensors and calculate difference - value_o = fo.get_tensor(key) - value_t = ft.get_tensor(key) + value_o = f_org.get_tensor(key) + value_t = f_tuned.get_tensor(key) mat = value_t.to(calc_dtype) - value_o.to(calc_dtype) del value_o, value_t @@ -114,7 +114,7 @@ def str_to_dtype(p): U = U.to(store_device, dtype=save_dtype).contiguous() Vh = Vh.to(store_device, dtype=save_dtype).contiguous() - print(f"key: {key}, U: {U.size()}, Vh: {Vh.size()}") + # print(f"key: {key}, U: {U.size()}, Vh: {Vh.size()}") lora_weights[key] = (U, Vh) del mat, U, S, Vh From 81411a398eb4ce28d84cc2da8238ff013d40d62f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 22 Aug 2024 22:02:29 +0900 Subject: [PATCH 103/348] speed up getting image sizes --- library/strategy_base.py | 7 ++++++- library/strategy_flux.py | 9 +++------ library/strategy_sd.py | 12 ++++-------- library/strategy_sd3.py | 9 +++------ library/train_util.py | 23 ++++++++++++++++++++++- 5 files changed, 38 insertions(+), 22 deletions(-) diff --git a/library/strategy_base.py b/library/strategy_base.py index e7d3a97ef..6a01c30a5 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -204,9 +204,14 @@ def cache_to_disk(self): def batch_size(self): return self._batch_size - def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: + @property + def cache_suffix(self): raise NotImplementedError + def get_image_size_from_disk_cache_path(self, absolute_path: str, npz_path: str) -> Tuple[Optional[int], Optional[int]]: + w, h = os.path.splitext(npz_path)[0].split("_")[-2].split("x") + return int(w), int(h) + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: raise NotImplementedError diff --git a/library/strategy_flux.py b/library/strategy_flux.py index d52b3b8dd..887113ca1 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -189,12 +189,9 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy): def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) - def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: - npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX) - if len(npz_file) == 0: - return None, None - w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x") - return int(w), int(h) + @property + def cache_suffix(self) -> str: + return FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: return ( diff --git a/library/strategy_sd.py b/library/strategy_sd.py index 83ffaa31b..af472e491 100644 --- a/library/strategy_sd.py +++ b/library/strategy_sd.py @@ -108,14 +108,10 @@ def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cac self.suffix = ( SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX ) - - def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: - # does not include old npz - npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + self.suffix) - if len(npz_file) == 0: - return None, None - w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x") - return int(w), int(h) + + @property + def cache_suffix(self) -> str: + return self.suffix def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: # support old .npz diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py index a22818903..9fde02084 100644 --- a/library/strategy_sd3.py +++ b/library/strategy_sd3.py @@ -222,12 +222,9 @@ class Sd3LatentsCachingStrategy(LatentsCachingStrategy): def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) - def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: - npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX) - if len(npz_file) == 0: - return None, None - w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x") - return int(w), int(h) + @property + def cache_suffix(self) -> str: + return Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: return ( diff --git a/library/train_util.py b/library/train_util.py index 989758ad5..dcc01f6f7 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1739,9 +1739,30 @@ def load_dreambooth_dir(subset: DreamBoothSubset): strategy = LatentsCachingStrategy.get_strategy() if strategy is not None: logger.info("get image size from name of cache files") + + # make image path to npz path mapping + npz_paths = glob.glob(os.path.join(subset.image_dir, "*" + strategy.cache_suffix)) + npz_paths.sort() + npz_path_index = 0 + size_set_count = 0 for i, img_path in enumerate(tqdm(img_paths)): - w, h = strategy.get_image_size_from_disk_cache_path(img_path) + l = len(os.path.splitext(img_path)[0]) # remove extension + found = False + while npz_path_index < len(npz_paths): # until found or end of npz_paths + # npz_paths are sorted, so if npz_path > img_path, img_path is not found + if npz_paths[npz_path_index][:l] > img_path[:l]: + break + if npz_paths[npz_path_index][:l] == img_path[:l]: # found + found = True + break + npz_path_index += 1 # next npz_path + + if found: + w, h = strategy.get_image_size_from_disk_cache_path(img_path, npz_paths[npz_path_index]) + else: + w, h = None, None + if w is not None and h is not None: sizes[i] = [w, h] size_set_count += 1 From 2e89cd2cc634c27add7a04c21fcb6d0e16716a2b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 24 Aug 2024 12:39:54 +0900 Subject: [PATCH 104/348] Fix issue with attention mask not being applied in single blocks --- README.md | 3 ++ flux_train_network.py | 4 +-- library/flux_models.py | 62 +++++++++++++++++++++--------------------- 3 files changed, 36 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index 33b3a9a99..4151bf44e 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,9 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 24, 2024: +Fixed an issue where the attention mask was not applied in single blocks when `--apply_t5_attn_mask` was specified. + Aug 22, 2024 (update 2): Fixed a bug that the embedding was zero-padded when `--apply_t5_attn_mask` option was applied. Also, the cache file for text encoder outputs now records whether the mask is applied or not. Please note that the cache file will be recreated when switching the `--apply_t5_attn_mask` option. diff --git a/flux_train_network.py b/flux_train_network.py index 3e2057e91..82f77a77e 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -243,7 +243,7 @@ def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_a self.flux_upper.to("cpu") clean_memory_on_device(self.target_device) self.flux_lower.to(self.target_device) - return self.flux_lower(img, txt, vec, pe) + return self.flux_lower(img, txt, vec, pe, txt_attention_mask) wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device) clean_memory_on_device(accelerator.device) @@ -352,7 +352,7 @@ def get_noise_pred_and_target( intermediate_txt.requires_grad_(True) vec.requires_grad_(True) pe.requires_grad_(True) - model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe) + model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) # unpack latents model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) diff --git a/library/flux_models.py b/library/flux_models.py index c045aef6b..b5726c298 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -752,18 +752,6 @@ def custom_forward(*inputs): else: return self._forward(img, txt, vec, pe, txt_attention_mask) - # def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor): - # if self.training and self.gradient_checkpointing: - # def create_custom_forward(func): - # def custom_forward(*inputs): - # return func(*inputs) - # return custom_forward - # return torch.utils.checkpoint.checkpoint( - # create_custom_forward(self._forward), img, txt, vec, pe, use_reentrant=USE_REENTRANT - # ) - # else: - # return self._forward(img, txt, vec, pe) - class SingleStreamBlock(nn.Module): """ @@ -809,7 +797,7 @@ def disable_gradient_checkpointing(self): self.gradient_checkpointing = False self.cpu_offload_checkpointing = False - def _forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: + def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor: mod, _ = self.modulation(vec) x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) @@ -817,16 +805,35 @@ def _forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) q, k = self.norm(q, k, v) + # make attention mask if not None + attn_mask = None + if txt_attention_mask is not None: + # F.scaled_dot_product_attention expects attn_mask to be bool for binary mask + attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len + attn_mask = torch.cat( + ( + attn_mask, + torch.ones( + attn_mask.shape[0], x.shape[1] - txt_attention_mask.shape[1], device=attn_mask.device, dtype=torch.bool + ), + ), + dim=1, + ) # b, seq_len + img_len = x_len + + # broadcast attn_mask to all heads + attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1) + # compute attention - attn = attention(q, k, v, pe=pe) + attn = attention(q, k, v, pe=pe, attn_mask=attn_mask) + # compute activation in mlp stream, cat again and run second linear layer output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) return x + mod.gate * output - def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: + def forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor: if self.training and self.gradient_checkpointing: if not self.cpu_offload_checkpointing: - return checkpoint(self._forward, x, vec, pe, use_reentrant=False) + return checkpoint(self._forward, x, vec, pe, txt_attention_mask, use_reentrant=False) # cpu offload checkpointing @@ -838,19 +845,11 @@ def custom_forward(*inputs): return custom_forward - return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=False) + return torch.utils.checkpoint.checkpoint( + create_custom_forward(self._forward), x, vec, pe, txt_attention_mask, use_reentrant=False + ) else: - return self._forward(x, vec, pe) - - # def forward(self, x: Tensor, vec: Tensor, pe: Tensor): - # if self.training and self.gradient_checkpointing: - # def create_custom_forward(func): - # def custom_forward(*inputs): - # return func(*inputs) - # return custom_forward - # return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=USE_REENTRANT) - # else: - # return self._forward(x, vec, pe) + return self._forward(x, vec, pe, txt_attention_mask) class LastLayer(nn.Module): @@ -1053,7 +1052,7 @@ def forward( if not self.single_blocks_to_swap: for block in self.single_blocks: - img = block(img, vec=vec, pe=pe) + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) else: # make sure first n blocks are on cuda, and last n blocks are on cpu at beginning for block_idx in range(self.single_blocks_to_swap): @@ -1075,7 +1074,7 @@ def forward( block.to(self.device) # move to cuda # print(f"Moved single block {block_idx} to cuda.") - img = block(img, vec=vec, pe=pe) + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) if moving: self.single_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True) @@ -1250,10 +1249,11 @@ def forward( txt: Tensor, vec: Tensor | None = None, pe: Tensor | None = None, + txt_attention_mask: Tensor | None = None, ) -> Tensor: img = torch.cat((txt, img), 1) for block in self.single_blocks: - img = block(img, vec=vec, pe=pe) + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) img = img[:, txt.shape[1] :, ...] img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) From cf689e7aa697877a0eee58622035ab702ce59d3e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 24 Aug 2024 16:35:43 +0900 Subject: [PATCH 105/348] feat: Add option to split projection layers and apply LoRA --- README.md | 14 ++ networks/check_lora_weights.py | 2 +- networks/convert_flux_lora.py | 51 ++++-- networks/lora_flux.py | 326 +++++++++++++++++++++++++++------ 4 files changed, 325 insertions(+), 68 deletions(-) diff --git a/README.md b/README.md index 4151bf44e..7d326a867 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,20 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 24, 2024 (update 2): + +__Experimental__ Added an option to split the projection layers of q/k/v/txt in the attention and apply LoRA to each of them in FLUX.1 LoRA training. Specify `"split_qkv=True"` in network_args like `--network_args "split_qkv=True"` (`train_blocks` is also available). + +The number of parameters may increase slightly, so the expressiveness may increase, but the training time may be longer. No detailed verification has been done. + +This implementation is experimental, so it may be deprecated or changed in the future. + +The .safetensors file of the trained model is compatible with the normal LoRA model of sd-scripts, so it should be usable in inference environments such as ComfyUI as it is. Also, converting it to AI-toolkit (Diffusers) format with `convert_flux_lora.py` will reduce the size. It should be no problem to convert it if you use it in the inference environment. + +Technical details: In the implementation of Black Forest Labs' model, the projection layers of q/k/v (and txt in single blocks) are concatenated into one. If LoRA is added there as it is, the LoRA module is only one, and the dimension is large. In contrast, in the implementation of Diffusers, the projection layers of q/k/v/txt are separated. Therefore, the LoRA module is applied to q/k/v/txt separately, and the dimension is smaller. This option is for training LoRA similar to the latter. + +The compatibility of the saved model (state dict) is ensured by concatenating the weights of multiple LoRAs. However, since there are zero weights in some parts, the model size will be large. + Aug 24, 2024: Fixed an issue where the attention mask was not applied in single blocks when `--apply_t5_attn_mask` was specified. diff --git a/networks/check_lora_weights.py b/networks/check_lora_weights.py index 794659c94..b5b5e61ae 100644 --- a/networks/check_lora_weights.py +++ b/networks/check_lora_weights.py @@ -18,7 +18,7 @@ def main(file): keys = list(sd.keys()) for key in keys: - if "lora_up" in key or "lora_down" in key: + if "lora_up" in key or "lora_down" in key or "lora_A" in key or "lora_B" in key: values.append((key, sd[key])) print(f"number of LoRA modules: {len(values)}") diff --git a/networks/convert_flux_lora.py b/networks/convert_flux_lora.py index e9743534d..bd4c1cf78 100644 --- a/networks/convert_flux_lora.py +++ b/networks/convert_flux_lora.py @@ -266,11 +266,12 @@ def convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None): if sds_key + ".lora_down.weight" not in sds_sd: return down_weight = sds_sd.pop(sds_key + ".lora_down.weight") + up_weight = sds_sd.pop(sds_key + ".lora_up.weight") + sd_lora_rank = down_weight.shape[0] # scale weight by alpha and dim - rank = down_weight.shape[0] alpha = sds_sd.pop(sds_key + ".alpha") - scale = alpha / rank + scale = alpha / sd_lora_rank # calculate scale_down and scale_up scale_down = scale @@ -279,23 +280,49 @@ def convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None): scale_down *= 2 scale_up /= 2 - ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] - ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] - - num_splits = len(ait_keys) - up_weight = sds_sd.pop(sds_key + ".lora_up.weight") - - # down_weight is copied to each split - ait_sd.update({k: down_weight * scale_down for k in ait_down_keys}) + down_weight = down_weight * scale_down + up_weight = up_weight * scale_up # calculate dims if not provided + num_splits = len(ait_keys) if dims is None: dims = [up_weight.shape[0] // num_splits] * num_splits else: assert sum(dims) == up_weight.shape[0] - # up_weight is split to each split - ait_sd.update({k: v * scale_up for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) + # check upweight is sparse or not + is_sparse = False + if sd_lora_rank % num_splits == 0: + ait_rank = sd_lora_rank // num_splits + is_sparse = True + i = 0 + for j in range(len(dims)): + for k in range(len(dims)): + if j == k: + continue + is_sparse = is_sparse and torch.all(up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0) + i += dims[j] + if is_sparse: + logger.info(f"weight is sparse: {sds_key}") + + # make ai-toolkit weight + ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] + ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] + if not is_sparse: + # down_weight is copied to each split + ait_sd.update({k: down_weight for k in ait_down_keys}) + + # up_weight is split to each split + ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) + else: + # down_weight is chunked to each split + ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) + + # up_weight is sparse: only non-zero values are copied to each split + i = 0 + for j in range(len(dims)): + ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous() + i += dims[j] def convert_sd_scripts_to_ai_toolkit(sds_sd): diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 4da33542f..efc7847ed 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -39,6 +39,7 @@ def __init__( dropout=None, rank_dropout=None, module_dropout=None, + split_dims: Optional[List[int]] = None, ): """if alpha == 0 or None, alpha is rank (no scaling).""" super().__init__() @@ -52,16 +53,34 @@ def __init__( out_dim = org_module.out_features self.lora_dim = lora_dim + self.split_dims = split_dims + + if split_dims is None: + if org_module.__class__.__name__ == "Conv2d": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) - if org_module.__class__.__name__ == "Conv2d": - kernel_size = org_module.kernel_size - stride = org_module.stride - padding = org_module.padding - self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) - self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) else: - self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) - self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + # conv2d not supported + assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim" + assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear" + # print(f"split_dims: {split_dims}") + self.lora_down = torch.nn.ModuleList( + [torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))] + ) + self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims]) + for lora_down in self.lora_down: + torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5)) + for lora_up in self.lora_up: + torch.nn.init.zeros_(lora_up.weight) if type(alpha) == torch.Tensor: alpha = alpha.detach().float().numpy() # without casting, bf16 causes error @@ -70,9 +89,6 @@ def __init__( self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える # same as microsoft's - torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) - torch.nn.init.zeros_(self.lora_up.weight) - self.multiplier = multiplier self.org_module = org_module # remove in applying self.dropout = dropout @@ -92,30 +108,56 @@ def forward(self, x): if torch.rand(1) < self.module_dropout: return org_forwarded - lx = self.lora_down(x) - - # normal dropout - if self.dropout is not None and self.training: - lx = torch.nn.functional.dropout(lx, p=self.dropout) + if self.split_dims is None: + lx = self.lora_down(x) + + # normal dropout + if self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) + + # rank dropout + if self.rank_dropout is not None and self.training: + mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout + if len(lx.size()) == 3: + mask = mask.unsqueeze(1) # for Text Encoder + elif len(lx.size()) == 4: + mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d + lx = lx * mask + + # scaling for rank dropout: treat as if the rank is changed + # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale - # rank dropout - if self.rank_dropout is not None and self.training: - mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout - if len(lx.size()) == 3: - mask = mask.unsqueeze(1) # for Text Encoder - elif len(lx.size()) == 4: - mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d - lx = lx * mask + lx = self.lora_up(lx) - # scaling for rank dropout: treat as if the rank is changed - # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる - scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + return org_forwarded + lx * self.multiplier * scale else: - scale = self.scale + lxs = [lora_down(x) for lora_down in self.lora_down] + + # normal dropout + if self.dropout is not None and self.training: + lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs] + + # rank dropout + if self.rank_dropout is not None and self.training: + masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs] + for i in range(len(lxs)): + if len(lx.size()) == 3: + masks[i] = masks[i].unsqueeze(1) + elif len(lx.size()) == 4: + masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1) + lxs[i] = lxs[i] * masks[i] + + # scaling for rank dropout: treat as if the rank is changed + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale - lx = self.lora_up(lx) + lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] - return org_forwarded + lx * self.multiplier * scale + return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale class LoRAInfModule(LoRAModule): @@ -152,31 +194,50 @@ def merge_to(self, sd, dtype, device): if device is None: device = org_device - # get up/down weight - up_weight = sd["lora_up.weight"].to(torch.float).to(device) - down_weight = sd["lora_down.weight"].to(torch.float).to(device) - - # merge weight - if len(weight.size()) == 2: - # linear - weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale - elif down_weight.size()[2:4] == (1, 1): - # conv2d 1x1 - weight = ( - weight - + self.multiplier - * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) - * self.scale - ) + if self.split_dims is None: + # get up/down weight + down_weight = sd["lora_down.weight"].to(torch.float).to(device) + up_weight = sd["lora_up.weight"].to(torch.float).to(device) + + # merge weight + if len(weight.size()) == 2: + # linear + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + self.multiplier * conved * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) else: - # conv2d 3x3 - conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # logger.info(conved.size(), weight.size(), module.stride, module.padding) - weight = weight + self.multiplier * conved * self.scale + # split_dims + total_dims = sum(self.split_dims) + for i in range(len(self.split_dims)): + # get up/down weight + down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device) # (rank, in_dim) + up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split dim, rank) + + # pad up_weight -> (total_dims, rank) + padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float) + padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight + + # merge weight + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale - # set weight to org_module - org_sd["weight"] = weight.to(dtype) - self.org_module.load_state_dict(org_sd) + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) # 復元できるマージのため、このモジュールのweightを返す def get_weight(self, multiplier=None): @@ -211,7 +272,14 @@ def set_region(self, region): def default_forward(self, x): # logger.info(f"default_forward {self.lora_name} {x.size()}") - return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + if self.split_dims is None: + lx = self.lora_down(x) + lx = self.lora_up(lx) + return self.org_forward(x) + lx * self.multiplier * self.scale + else: + lxs = [lora_down(x) for lora_down in self.lora_down] + lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] + return self.org_forward(x) + torch.cat(lxs, dim=-1) * self.multiplier * self.scale def forward(self, x): if not self.enabled: @@ -257,6 +325,11 @@ def create_network( if train_blocks is not None: assert train_blocks in ["all", "single", "double"], f"invalid train_blocks: {train_blocks}" + # split qkv + split_qkv = kwargs.get("split_qkv", False) + if split_qkv is not None: + split_qkv = True if split_qkv == "True" else False + # すごく引数が多いな ( ^ω^)・・・ network = LoRANetwork( text_encoders, @@ -270,6 +343,7 @@ def create_network( conv_lora_dim=conv_dim, conv_alpha=conv_alpha, train_blocks=train_blocks, + split_qkv=split_qkv, varbose=True, ) @@ -311,10 +385,34 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh modules_dim[lora_name] = dim # logger.info(lora_name, value.size(), dim) + # # split qkv + # double_qkv_rank = None + # single_qkv_rank = None + # rank = None + # for lora_name, dim in modules_dim.items(): + # if "double" in lora_name and "qkv" in lora_name: + # double_qkv_rank = dim + # elif "single" in lora_name and "linear1" in lora_name: + # single_qkv_rank = dim + # elif rank is None: + # rank = dim + # if double_qkv_rank is not None and single_qkv_rank is not None and rank is not None: + # break + # split_qkv = (double_qkv_rank is not None and double_qkv_rank != rank) or ( + # single_qkv_rank is not None and single_qkv_rank != rank + # ) + split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined + module_class = LoRAInfModule if for_inference else LoRAModule network = LoRANetwork( - text_encoders, flux, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class + text_encoders, + flux, + multiplier=multiplier, + modules_dim=modules_dim, + modules_alpha=modules_alpha, + module_class=module_class, + split_qkv=split_qkv, ) return network, weights_sd @@ -344,6 +442,7 @@ def __init__( modules_dim: Optional[Dict[str, int]] = None, modules_alpha: Optional[Dict[str, int]] = None, train_blocks: Optional[str] = None, + split_qkv: bool = False, varbose: Optional[bool] = False, ) -> None: super().__init__() @@ -357,6 +456,7 @@ def __init__( self.rank_dropout = rank_dropout self.module_dropout = module_dropout self.train_blocks = train_blocks if train_blocks is not None else "all" + self.split_qkv = split_qkv self.loraplus_lr_ratio = None self.loraplus_unet_lr_ratio = None @@ -373,6 +473,8 @@ def __init__( logger.info( f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" ) + if self.split_qkv: + logger.info(f"split qkv for LoRA") # create module instances def create_modules( @@ -420,6 +522,14 @@ def create_modules( skipped.append(lora_name) continue + # qkv split + split_dims = None + if is_flux and split_qkv: + if "double" in lora_name and "qkv" in lora_name: + split_dims = [3072] * 3 + elif "single" in lora_name and "linear1" in lora_name: + split_dims = [3072] * 3 + [12288] + lora = module_class( lora_name, child_module, @@ -429,6 +539,7 @@ def create_modules( dropout=dropout, rank_dropout=rank_dropout, module_dropout=module_dropout, + split_dims=split_dims, ) loras.append(lora) return loras, skipped @@ -492,6 +603,111 @@ def load_weights(self, file): info = self.load_state_dict(weights_sd, False) return info + def load_state_dict(self, state_dict, strict=True): + # override to convert original weight to splitted qkv weight + if not self.split_qkv: + return super().load_state_dict(state_dict, strict) + + # split qkv + for key in list(state_dict.keys()): + if "double" in key and "qkv" in key: + split_dims = [3072] * 3 + elif "single" in key and "linear1" in key: + split_dims = [3072] * 3 + [12288] + else: + continue + + weight = state_dict[key] + lora_name = key.split(".")[0] + if "lora_down" in key and "weight" in key: + # dense weight (rank*3, in_dim) + split_weight = torch.chunk(weight, len(split_dims), dim=0) + for i, split_w in enumerate(split_weight): + state_dict[f"{lora_name}.lora_down.{i}.weight"] = split_w + + del state_dict[key] + # print(f"split {key}: {weight.shape} to {[w.shape for w in split_weight]}") + elif "lora_up" in key and "weight" in key: + # sparse weight (out_dim=sum(split_dims), rank*3) + rank = weight.size(1) // len(split_dims) + i = 0 + for j in range(len(split_dims)): + state_dict[f"{lora_name}.lora_up.{j}.weight"] = weight[i : i + split_dims[j], j * rank : (j + 1) * rank] + i += split_dims[j] + del state_dict[key] + + # # check is sparse + # i = 0 + # is_zero = True + # for j in range(len(split_dims)): + # for k in range(len(split_dims)): + # if j == k: + # continue + # is_zero = is_zero and torch.all(weight[i : i + split_dims[j], k * rank : (k + 1) * rank] == 0) + # i += split_dims[j] + # if not is_zero: + # logger.warning(f"weight is not sparse: {key}") + # else: + # logger.info(f"weight is sparse: {key}") + + # print( + # f"split {key}: {weight.shape} to {[state_dict[k].shape for k in [f'{lora_name}.lora_up.{j}.weight' for j in range(len(split_dims))]]}" + # ) + + # alpha is unchanged + + return super().load_state_dict(state_dict, strict) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + if not self.split_qkv: + return super().state_dict(destination, prefix, keep_vars) + + # merge qkv + state_dict = super().state_dict(destination, prefix, keep_vars) + new_state_dict = {} + for key in list(state_dict.keys()): + if "double" in key and "qkv" in key: + split_dims = [3072] * 3 + elif "single" in key and "linear1" in key: + split_dims = [3072] * 3 + [12288] + else: + new_state_dict[key] = state_dict[key] + continue + + if key not in state_dict: + continue # already merged + + lora_name = key.split(".")[0] + + # (rank, in_dim) * 3 + down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(len(split_dims))] + # (split dim, rank) * 3 + up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))] + + alpha = state_dict.pop(f"{lora_name}.alpha") + + # merge down weight + down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim) + + # merge up weight (sum of split_dim, rank*3) + rank = up_weights[0].size(1) + up_weight = torch.zeros((sum(split_dims), down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype) + i = 0 + for j in range(len(split_dims)): + up_weight[i : i + split_dims[j], j * rank : (j + 1) * rank] = up_weights[j] + i += split_dims[j] + + new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight + new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight + new_state_dict[f"{lora_name}.alpha"] = alpha + + # print( + # f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}" + # ) + print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha") + + return new_state_dict + def apply_to(self, text_encoders, flux, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules") From 5639c2adc0085e2e995bb3eee5a278aace397e7a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 24 Aug 2024 16:37:49 +0900 Subject: [PATCH 106/348] fix typo --- networks/lora_flux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index efc7847ed..07a80f0bf 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -604,7 +604,7 @@ def load_weights(self, file): return info def load_state_dict(self, state_dict, strict=True): - # override to convert original weight to splitted qkv weight + # override to convert original weight to split qkv if not self.split_qkv: return super().load_state_dict(state_dict, strict) From 72287d39c76176c0e1c16e8da4f5ddc6f94ea7d6 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 25 Aug 2024 16:01:24 +0900 Subject: [PATCH 107/348] feat: Add `shift` option to `--timestep_sampling` in FLUX.1 fine-tuning and LoRA training --- README.md | 4 ++++ library/flux_train_utils.py | 15 +++++++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 282f3b3bd..562dcdb2a 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,10 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 25, 2024: +Added `shift` option to `--timestep_sampling` in FLUX.1 fine-tuning and LoRA training. Shifts timesteps according to the value of `--discrete_flow_shift` (shifts the value of sigmoid of normal distribution random number). It may be good to start with a value of 3.1582 (=e^1.15) for `--discrete_flow_shift`. +Sample command: `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0` + Aug 24, 2024 (update 2): __Experimental__ Added an option to split the projection layers of q/k/v/txt in the attention and apply LoRA to each of them in FLUX.1 LoRA training. Specify `"split_qkv=True"` in network_args like `--network_args "split_qkv=True"` (`train_blocks` is also available). diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 1d3f80d72..75f70a54f 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -380,9 +380,19 @@ def get_noisy_model_input_and_timesteps( t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) else: t = torch.rand((bsz,), device=device) + timesteps = t * 1000.0 t = t.view(-1, 1, 1, 1) noisy_model_input = (1 - t) * latents + t * noise + elif args.timestep_sampling == "shift": + shift = args.discrete_flow_shift + logits_norm = torch.randn(bsz, device=device) + timesteps = logits_norm.sigmoid() + timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) + + t = timesteps.view(-1, 1, 1, 1) + timesteps = timesteps * 1000.0 + noisy_model_input = (1 - t) * latents + t * noise else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -559,9 +569,10 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--timestep_sampling", - choices=["sigma", "uniform", "sigmoid"], + choices=["sigma", "uniform", "sigmoid", "shift"], default="sigma", - help="Method to sample timesteps: sigma-based, uniform random, or sigmoid of random normal. / タイムステップをサンプリングする方法:sigma、random uniform、またはrandom normalのsigmoid。", + help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid." + " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト。", ) parser.add_argument( "--sigmoid_scale", From 0087a46e14c8e568982cbe3a5d9b9c561b175abf Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 27 Aug 2024 19:59:40 +0900 Subject: [PATCH 108/348] FLUX.1 LoRA supports CLIP-L --- README.md | 8 ++++ flux_train_network.py | 40 +++++++++++++----- library/flux_train_utils.py | 8 ++-- library/strategy_flux.py | 3 +- networks/lora_flux.py | 4 +- train_network.py | 81 ++++++++++++++++++++++++------------- 6 files changed, 101 insertions(+), 43 deletions(-) diff --git a/README.md b/README.md index 562dcdb2a..1203b5ebc 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,14 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 27, 2024: + +- FLUX.1 LoRA training now supports CLIP-L LoRA. Please remove `--network_train_unet_only`. T5XXL is not trained. The output of T5XXL is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required. The trained LoRA can be used with ComfyUI. + - `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA. +- `--sigmoid_scale` is now effective even when `--timestep_sampling shift` is specified. Normally, leave it at 1.0. Larger values make the value before shift closer to a uniform distribution. + +- __Experimental__ `--fp8_base_unet` option is added to `flux_train_network.py`. Flux can be trained with fp8, and CLIP-L can be trained with bf16/fp16. When specifying this option, the `--fp8_base` option is not required (Flux is fp8, and CLIP-L is bf16/fp16, regardless of the `--fp8_base` option). + Aug 25, 2024: Added `shift` option to `--timestep_sampling` in FLUX.1 fine-tuning and LoRA training. Shifts timesteps according to the value of `--discrete_flow_shift` (shifts the value of sigmoid of normal distribution random number). It may be good to start with a value of 3.1582 (=e^1.15) for `--discrete_flow_shift`. Sample command: `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0` diff --git a/flux_train_network.py b/flux_train_network.py index 82f77a77e..1a40de61a 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -40,9 +40,13 @@ def assert_extra_args(self, args, train_dataset_group): train_dataset_group.is_text_encoder_output_cacheable() ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" - assert ( - args.network_train_unet_only or not args.cache_text_encoder_outputs - ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" + # assert ( + # args.network_train_unet_only or not args.cache_text_encoder_outputs + # ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" + if not args.network_train_unet_only: + logger.info( + "network for CLIP-L only will be trained. T5XXL will not be trained / CLIP-Lのネットワークのみが学習されます。T5XXLは学習されません" + ) if args.max_token_length is not None: logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") @@ -137,12 +141,25 @@ def get_text_encoding_strategy(self, args): return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask) def get_models_for_text_encoding(self, args, accelerator, text_encoders): - return text_encoders # + [accelerator.unwrap_model(text_encoders[-1])] + if args.cache_text_encoder_outputs: + if self.is_train_text_encoder(args): + return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached + else: + return text_encoders # ignored + else: + return text_encoders # both CLIP-L and T5XXL are needed for encoding + + def get_text_encoders_train_flags(self, args, text_encoders): + return [True, False] if self.is_train_text_encoder(args) else [False, False] def get_text_encoder_outputs_caching_strategy(self, args): if args.cache_text_encoder_outputs: return strategy_flux.FluxTextEncoderOutputsCachingStrategy( - args.cache_text_encoder_outputs_to_disk, None, False, apply_t5_attn_mask=args.apply_t5_attn_mask + args.cache_text_encoder_outputs_to_disk, + None, + False, + is_partial=self.is_train_text_encoder(args), + apply_t5_attn_mask=args.apply_t5_attn_mask, ) else: return None @@ -190,9 +207,11 @@ def cache_text_encoder_outputs_if_needed( accelerator.wait_for_everyone() # move back to cpu - logger.info("move text encoders back to cpu") - text_encoders[0].to("cpu") # , dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU - text_encoders[1].to("cpu") # , dtype=torch.float32) + if not self.is_train_text_encoder(args): + logger.info("move CLIP-L back to cpu") + text_encoders[0].to("cpu") + logger.info("move t5XXL back to cpu") + text_encoders[1].to("cpu") clean_memory_on_device(accelerator.device) if not args.lowram: @@ -297,7 +316,8 @@ def get_noise_pred_and_target( if args.gradient_checkpointing: noisy_model_input.requires_grad_(True) for t in text_encoder_conds: - t.requires_grad_(True) + if t.dtype.is_floating_point: + t.requires_grad_(True) img_ids.requires_grad_(True) guidance_vec.requires_grad_(True) @@ -384,7 +404,7 @@ def update_metadata(self, metadata, args): metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift def is_text_encoder_not_needed_for_training(self, args): - return args.cache_text_encoder_outputs + return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) def setup_parser() -> argparse.ArgumentParser: diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 75f70a54f..a8e94ac00 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -58,7 +58,7 @@ def sample_images( logger.info("") logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") - if not os.path.isfile(args.sample_prompts): + if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None: logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") return @@ -66,7 +66,8 @@ def sample_images( # unwrap unet and text_encoder(s) flux = accelerator.unwrap_model(flux) - text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] + if text_encoders is not None: + text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) prompts = load_prompts(args.sample_prompts) @@ -134,7 +135,7 @@ def sample_image_inference( accelerator: Accelerator, args: argparse.Namespace, flux: flux_models.Flux, - text_encoders: List[CLIPTextModel], + text_encoders: Optional[List[CLIPTextModel]], ae: flux_models.AutoEncoder, save_dir, prompt_dict, @@ -387,6 +388,7 @@ def get_noisy_model_input_and_timesteps( elif args.timestep_sampling == "shift": shift = args.discrete_flow_shift logits_norm = torch.randn(bsz, device=device) + logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling timesteps = logits_norm.sigmoid() timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) diff --git a/library/strategy_flux.py b/library/strategy_flux.py index d52b3b8dd..5d0839132 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -60,7 +60,7 @@ def encode_tokens( if apply_t5_attn_mask is None: apply_t5_attn_mask = self.apply_t5_attn_mask - clip_l, t5xxl = models + clip_l, t5xxl = models if len(models) == 2 else (models[0], None) l_tokens, t5_tokens = tokens[:2] t5_attn_mask = tokens[2] if len(tokens) > 2 else None @@ -81,6 +81,7 @@ def encode_tokens( else: t5_out = None txt_ids = None + t5_attn_mask = None # caption may be dropped/shuffled, so t5_attn_mask should not be used to make sure the mask is same as the cached one return [l_pooled, t5_out, txt_ids, t5_attn_mask] # returns t5_attn_mask for attention mask in transformer diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 07a80f0bf..fcb56a467 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -401,7 +401,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh # split_qkv = (double_qkv_rank is not None and double_qkv_rank != rank) or ( # single_qkv_rank is not None and single_qkv_rank != rank # ) - split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined + split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined module_class = LoRAInfModule if for_inference else LoRAModule @@ -421,7 +421,7 @@ class LoRANetwork(torch.nn.Module): # FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"] FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"] FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"] - TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"] LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1" LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2" diff --git a/train_network.py b/train_network.py index cab0ec52e..048c7e7bd 100644 --- a/train_network.py +++ b/train_network.py @@ -127,8 +127,15 @@ def get_text_encoder_outputs_caching_strategy(self, args): return None def get_models_for_text_encoding(self, args, accelerator, text_encoders): + """ + Returns a list of models that will be used for text encoding. SDXL uses wrapped and unwrapped models. + """ return text_encoders + # returns a list of bool values indicating whether each text encoder should be trained + def get_text_encoders_train_flags(self, args, text_encoders): + return [True] * len(text_encoders) if self.is_train_text_encoder(args) else [False] * len(text_encoders) + def is_train_text_encoder(self, args): return not args.network_train_unet_only @@ -136,11 +143,6 @@ def cache_text_encoder_outputs_if_needed(self, args, accelerator, unet, vae, tex for t_enc in text_encoders: t_enc.to(accelerator.device, dtype=weight_dtype) - def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], weight_dtype) - return encoder_hidden_states - def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample return noise_pred @@ -313,7 +315,7 @@ def train(self, args): collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) if args.debug_dataset: - train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly + train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: @@ -437,8 +439,10 @@ def train(self, args): if args.gradient_checkpointing: unet.enable_gradient_checkpointing() - for t_enc in text_encoders: - t_enc.gradient_checkpointing_enable() + for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)): + if flag: + if t_enc.supports_gradient_checkpointing: + t_enc.gradient_checkpointing_enable() del t_enc network.enable_gradient_checkpointing() # may have no effect @@ -522,14 +526,17 @@ def train(self, args): unet_weight_dtype = te_weight_dtype = weight_dtype # Experimental Feature: Put base model into fp8 to save vram - if args.fp8_base: + if args.fp8_base or args.fp8_base_unet: assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。" assert ( args.mixed_precision != "no" ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。" - accelerator.print("enable fp8 training.") + accelerator.print("enable fp8 training for U-Net.") unet_weight_dtype = torch.float8_e4m3fn - te_weight_dtype = torch.float8_e4m3fn + + if not args.fp8_base_unet: + accelerator.print("enable fp8 training for Text Encoder.") + te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn # unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM # unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory @@ -546,19 +553,18 @@ def train(self, args): t_enc.to(dtype=te_weight_dtype) if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"): # nn.Embedding not support FP8 - t_enc.text_model.embeddings.to( - dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) + t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"): - t_enc.encoder.embeddings.to( - dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) + t_enc.encoder.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good if args.deepspeed: + flags = self.get_text_encoders_train_flags(args, text_encoders) ds_model = deepspeed_utils.prepare_deepspeed_model( args, unet=unet if train_unet else None, - text_encoder1=text_encoders[0] if train_text_encoder else None, - text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None, + text_encoder1=text_encoders[0] if flags[0] else None, + text_encoder2=(text_encoders[1] if flags[1] else None) if len(text_encoders) > 1 else None, network=network, ) ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( @@ -571,11 +577,14 @@ def train(self, args): else: unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator if train_text_encoder: + text_encoders = [ + (accelerator.prepare(t_enc) if flag else t_enc) + for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)) + ] if len(text_encoders) > 1: - text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders] + text_encoder = text_encoders else: - text_encoder = accelerator.prepare(text_encoder) - text_encoders = [text_encoder] + text_encoder = text_encoders[0] else: pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set @@ -587,11 +596,11 @@ def train(self, args): if args.gradient_checkpointing: # according to TI example in Diffusers, train is required unet.train() - for t_enc in text_encoders: + for t_enc, frag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)): t_enc.train() # set top parameter requires_grad = True for gradient checkpointing works - if train_text_encoder: + if frag: t_enc.text_model.embeddings.requires_grad_(True) else: @@ -736,6 +745,7 @@ def load_model_hook(models, input_dir): "ss_huber_schedule": args.huber_schedule, "ss_huber_c": args.huber_c, "ss_fp8_base": args.fp8_base, + "ss_fp8_base_unet": args.fp8_base_unet, } self.update_metadata(metadata, args) # architecture specific metadata @@ -1004,6 +1014,7 @@ def remove_model(old_ckpt_name): for t_enc in text_encoders: del t_enc text_encoders = [] + text_encoder = None # For --sample_at_first self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) @@ -1018,7 +1029,7 @@ def remove_model(old_ckpt_name): # log device and dtype for each model logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}") for t_enc in text_encoders: - logger.info(f"text_encoder dtype: {te_weight_dtype}, device: {t_enc.device}") + logger.info(f"text_encoder dtype: {t_enc.dtype}, device: {t_enc.device}") clean_memory_on_device(accelerator.device) @@ -1073,12 +1084,17 @@ def remove_model(old_ckpt_name): text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs - else: + if ( + text_encoder_conds is None + or len(text_encoder_conds) == 0 + or text_encoder_conds[0] is None + or train_text_encoder + ): with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: # SD only - text_encoder_conds = get_weighted_text_embeddings( + encoded_text_encoder_conds = get_weighted_text_embeddings( tokenizers[0], text_encoder, batch["captions"], @@ -1088,13 +1104,18 @@ def remove_model(old_ckpt_name): ) else: input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] - text_encoder_conds = text_encoding_strategy.encode_tokens( + encoded_text_encoder_conds = text_encoding_strategy.encode_tokens( tokenize_strategy, self.get_models_for_text_encoding(args, accelerator, text_encoders), input_ids, ) if args.full_fp16: - text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] + encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds] + + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] # sample noise, call unet, get target noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target( @@ -1257,6 +1278,12 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") + parser.add_argument( + "--fp8_base_unet", + action="store_true", + help="use fp8 for U-Net (or DiT), Text Encoder is fp16 or bf16" + " / U-Net(またはDiT)にfp8を使用する。Text Encoderはfp16またはbf16", + ) parser.add_argument( "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み" From 3be712e3e011b0378fad389641cec0c1869555ab Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 27 Aug 2024 21:40:02 +0900 Subject: [PATCH 109/348] feat: Update direct loading fp8 ckpt for LoRA training --- README.md | 7 +++- flux_minimal_inference.py | 27 +----------- flux_train_network.py | 16 +++++++- library/flux_utils.py | 12 ++++-- library/utils.py | 62 +++++++++++++++++++++++++++- networks/flux_merge_lora.py | 82 ++++++++++++++++++++++++++----------- 6 files changed, 151 insertions(+), 55 deletions(-) diff --git a/README.md b/README.md index 1203b5ebc..0108ada59 100644 --- a/README.md +++ b/README.md @@ -9,13 +9,18 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 27, 2024 (update 2): +In FLUX.1 LoRA training, when `--fp8_base` is specified, the FLUX.1 model file with fp8 (`float8_e4m3fn` type) can be loaded directly. Also, in `flux_minimal_inference.py`, it is possible to load it by specifying `fp8 (float8_e4m3fn)` in `--flux_dtype`. + +In `flux_merge_lora.py`, you can now specify the precision at save time with `fp8` (see `--help` for details). Also, if you do not specify the merge model, only the model type conversion will be performed. + Aug 27, 2024: - FLUX.1 LoRA training now supports CLIP-L LoRA. Please remove `--network_train_unet_only`. T5XXL is not trained. The output of T5XXL is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required. The trained LoRA can be used with ComfyUI. - `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA. - `--sigmoid_scale` is now effective even when `--timestep_sampling shift` is specified. Normally, leave it at 1.0. Larger values make the value before shift closer to a uniform distribution. -- __Experimental__ `--fp8_base_unet` option is added to `flux_train_network.py`. Flux can be trained with fp8, and CLIP-L can be trained with bf16/fp16. When specifying this option, the `--fp8_base` option is not required (Flux is fp8, and CLIP-L is bf16/fp16, regardless of the `--fp8_base` option). +- __Experimental__ `--fp8_base_unet` option is added to `flux_train_network.py`. Flux can be trained with fp8, and CLIP-L can be trained with bf16/fp16. When specifying this option, the `--fp8_base` option is automatically enabled. Aug 25, 2024: Added `shift` option to `--timestep_sampling` in FLUX.1 fine-tuning and LoRA training. Shifts timesteps according to the value of `--discrete_flow_shift` (shifts the value of sigmoid of normal distribution random number). It may be good to start with a value of 3.1582 (=e^1.15) for `--discrete_flow_shift`. diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index 5b8aa2506..56c1b1982 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -10,7 +10,6 @@ import numpy as np import torch -from safetensors.torch import safe_open, load_file from tqdm import tqdm from PIL import Image import accelerate @@ -21,7 +20,7 @@ init_ipex() -from library.utils import setup_logging +from library.utils import setup_logging, str_to_dtype setup_logging() import logging @@ -288,28 +287,6 @@ def generate_image( name = "schnell" if "schnell" in args.ckpt_path else "dev" # TODO change this to a more robust way is_schnell = name == "schnell" - def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype: - if s is None: - return default_dtype - if s in ["bf16", "bfloat16"]: - return torch.bfloat16 - elif s in ["fp16", "float16"]: - return torch.float16 - elif s in ["fp32", "float32"]: - return torch.float32 - elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]: - return torch.float8_e4m3fn - elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]: - return torch.float8_e4m3fnuz - elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]: - return torch.float8_e5m2 - elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]: - return torch.float8_e5m2fnuz - elif s in ["fp8", "float8"]: - return torch.float8_e4m3fn # default fp8 - else: - raise ValueError(f"Unsupported dtype: {s}") - def is_fp8(dt): return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz] @@ -348,7 +325,7 @@ def is_fp8(dt): encoding_strategy = strategy_flux.FluxTextEncodingStrategy() # DiT - model = flux_utils.load_flow_model(name, args.ckpt_path, flux_dtype, loading_device) + model = flux_utils.load_flow_model(name, args.ckpt_path, None, loading_device) model.eval() logger.info(f"Casting model to {flux_dtype}") model.to(flux_dtype) # make sure model is dtype diff --git a/flux_train_network.py b/flux_train_network.py index 1a40de61a..4a63c2de4 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -29,6 +29,9 @@ def assert_extra_args(self, args, train_dataset_group): super().assert_extra_args(args, train_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) + if args.fp8_base_unet: + args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1 + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: logger.warning( "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" @@ -61,9 +64,20 @@ def load_target_model(self, args, weight_dtype, accelerator): name = self.get_flux_model_name(args) # if we load to cpu, flux.to(fp8) takes a long time + if args.fp8_base: + loading_dtype = None # as is + else: + loading_dtype = weight_dtype + model = flux_utils.load_flow_model( - name, args.pretrained_model_name_or_path, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors + name, args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors ) + if args.fp8_base: + # check dtype of model + if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") + elif model.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 FLUX model") if args.split_mode: model = self.prepare_split_model(model, weight_dtype, accelerator) diff --git a/library/flux_utils.py b/library/flux_utils.py index 37166933a..680836168 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -1,5 +1,5 @@ import json -from typing import Union +from typing import Optional, Union import einops import torch @@ -20,7 +20,9 @@ # temporary copy from sd3_utils TODO refactor -def load_safetensors(path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: torch.dtype = torch.float32): +def load_safetensors( + path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32 +): if disable_mmap: # return safetensors.torch.load(open(path, "rb").read()) # use experimental loader @@ -38,11 +40,13 @@ def load_safetensors(path: str, device: Union[str, torch.device], disable_mmap: def load_flow_model( - name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False + name: str, ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False ) -> flux_models.Flux: logger.info(f"Building Flux model {name}") with torch.device("meta"): - model = flux_models.Flux(flux_models.configs[name].params).to(dtype) + model = flux_models.Flux(flux_models.configs[name].params) + if dtype is not None: + model = model.to(dtype) # load_sft doesn't support torch.device logger.info(f"Loading state dict from {ckpt_path}") diff --git a/library/utils.py b/library/utils.py index a16209979..d355cb109 100644 --- a/library/utils.py +++ b/library/utils.py @@ -82,6 +82,66 @@ def setup_logging(args=None, log_level=None, reset=False): logger.info(msg_init) +def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype: + """ + Convert a string to a torch.dtype + + Args: + s: string representation of the dtype + default_dtype: default dtype to return if s is None + + Returns: + torch.dtype: the corresponding torch.dtype + + Raises: + ValueError: if the dtype is not supported + + Examples: + >>> str_to_dtype("float32") + torch.float32 + >>> str_to_dtype("fp32") + torch.float32 + >>> str_to_dtype("float16") + torch.float16 + >>> str_to_dtype("fp16") + torch.float16 + >>> str_to_dtype("bfloat16") + torch.bfloat16 + >>> str_to_dtype("bf16") + torch.bfloat16 + >>> str_to_dtype("fp8") + torch.float8_e4m3fn + >>> str_to_dtype("fp8_e4m3fn") + torch.float8_e4m3fn + >>> str_to_dtype("fp8_e4m3fnuz") + torch.float8_e4m3fnuz + >>> str_to_dtype("fp8_e5m2") + torch.float8_e5m2 + >>> str_to_dtype("fp8_e5m2fnuz") + torch.float8_e5m2fnuz + """ + if s is None: + return default_dtype + if s in ["bf16", "bfloat16"]: + return torch.bfloat16 + elif s in ["fp16", "float16"]: + return torch.float16 + elif s in ["fp32", "float32", "float"]: + return torch.float32 + elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]: + return torch.float8_e4m3fn + elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]: + return torch.float8_e4m3fnuz + elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]: + return torch.float8_e5m2 + elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]: + return torch.float8_e5m2fnuz + elif s in ["fp8", "float8"]: + return torch.float8_e4m3fn # default fp8 + else: + raise ValueError(f"Unsupported dtype: {s}") + + def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None): """ memory efficient save file @@ -198,7 +258,7 @@ def _deserialize_tensor(self, tensor_bytes, metadata): if tensor_bytes is None: byte_tensor = torch.empty(0, dtype=torch.uint8) else: - tensor_bytes = bytearray(tensor_bytes) # make it writable + tensor_bytes = bytearray(tensor_bytes) # make it writable byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8) # process float8 types diff --git a/networks/flux_merge_lora.py b/networks/flux_merge_lora.py index d5e82920d..2e0d4c297 100644 --- a/networks/flux_merge_lora.py +++ b/networks/flux_merge_lora.py @@ -8,7 +8,7 @@ from safetensors.torch import load_file, save_file from tqdm import tqdm -from library.utils import setup_logging +from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file setup_logging() import logging @@ -34,18 +34,23 @@ def load_state_dict(file_name, dtype): return sd, metadata -def save_to_file(file_name, state_dict, dtype, metadata): +def save_to_file(file_name, state_dict, dtype, metadata, mem_eff_save=False): if dtype is not None: logger.info(f"converting to {dtype}...") - for key in list(state_dict.keys()): + for key in tqdm(list(state_dict.keys())): if type(state_dict[key]) == torch.Tensor: state_dict[key] = state_dict[key].to(dtype) logger.info(f"saving to: {file_name}") - save_file(state_dict, file_name, metadata=metadata) + if mem_eff_save: + mem_eff_save_file(state_dict, file_name, metadata=metadata) + else: + save_file(state_dict, file_name, metadata=metadata) -def merge_to_flux_model(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype): +def merge_to_flux_model( + loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype, mem_eff_load_save=False +): # create module map without loading state_dict logger.info(f"loading keys from FLUX.1 model: {flux_model}") lora_name_to_module_key = {} @@ -57,7 +62,14 @@ def merge_to_flux_model(loading_device, working_device, flux_model, models, rati lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_") lora_name_to_module_key[lora_name] = key - flux_state_dict = load_file(flux_model, device=loading_device) + if mem_eff_load_save: + flux_state_dict = {} + with MemoryEfficientSafeOpen(flux_model) as flux_file: + for key in tqdm(flux_file.keys()): + flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed + else: + flux_state_dict = load_file(flux_model, device=loading_device) + for model, ratio in zip(models, ratios): logger.info(f"loading: {model}") lora_sd, _ = load_state_dict(model, merge_dtype) # loading on CPU @@ -120,9 +132,17 @@ def merge_to_flux_model(loading_device, working_device, flux_model, models, rati return flux_state_dict -def merge_to_flux_model_diffusers(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype): +def merge_to_flux_model_diffusers( + loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype, mem_eff_load_save=False +): logger.info(f"loading keys from FLUX.1 model: {flux_model}") - flux_state_dict = load_file(flux_model, device=loading_device) + if mem_eff_load_save: + flux_state_dict = {} + with MemoryEfficientSafeOpen(flux_model) as flux_file: + for key in tqdm(flux_file.keys()): + flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed + else: + flux_state_dict = load_file(flux_model, device=loading_device) def create_key_map(n_double_layers, n_single_layers): key_map = {} @@ -474,19 +494,15 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): def merge(args): + if args.models is None: + args.models = [] + if args.ratios is None: + args.ratios = [] + assert len(args.models) == len( args.ratios ), "number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" - def str_to_dtype(p): - if p == "float": - return torch.float - if p == "fp16": - return torch.float16 - if p == "bf16": - return torch.bfloat16 - return None - merge_dtype = str_to_dtype(args.precision) save_dtype = str_to_dtype(args.save_precision) if save_dtype is None: @@ -500,11 +516,25 @@ def str_to_dtype(p): if args.flux_model is not None: if not args.diffusers: state_dict = merge_to_flux_model( - args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype + args.loading_device, + args.working_device, + args.flux_model, + args.models, + args.ratios, + merge_dtype, + save_dtype, + args.mem_eff_load_save, ) else: state_dict = merge_to_flux_model_diffusers( - args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype + args.loading_device, + args.working_device, + args.flux_model, + args.models, + args.ratios, + merge_dtype, + save_dtype, + args.mem_eff_load_save, ) if args.no_metadata: @@ -517,7 +547,7 @@ def str_to_dtype(p): ) logger.info(f"saving FLUX model to: {args.save_to}") - save_to_file(args.save_to, state_dict, save_dtype, sai_metadata) + save_to_file(args.save_to, state_dict, save_dtype, sai_metadata, args.mem_eff_load_save) else: state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) @@ -546,14 +576,14 @@ def setup_parser() -> argparse.ArgumentParser: "--save_precision", type=str, default=None, - choices=[None, "float", "fp16", "bf16"], - help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ", + help="precision in saving, same to merging if omitted. supported types: " + "float32, fp16, bf16, fp8 (same as fp8_e4m3fn), fp8_e4m3fn, fp8_e4m3fnuz, fp8_e5m2, fp8_e5m2fnuz" + " / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ", ) parser.add_argument( "--precision", type=str, default="float", - choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", ) parser.add_argument( @@ -562,6 +592,12 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="FLUX.1 model to load, merge LoRA models if omitted / 読み込むモデル、指定しない場合はLoRAモデルをマージする", ) + parser.add_argument( + "--mem_eff_load_save", + action="store_true", + help="use custom memory efficient load and save functions for FLUX.1 model" + " / カスタムのメモリ効率の良い読み込みと保存関数をFLUX.1モデルに使用する", + ) parser.add_argument( "--loading_device", type=str, From a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 27 Aug 2024 21:44:10 +0900 Subject: [PATCH 110/348] update readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 0108ada59..7b1d9cc6c 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ The command to install PyTorch is as follows: Aug 27, 2024 (update 2): In FLUX.1 LoRA training, when `--fp8_base` is specified, the FLUX.1 model file with fp8 (`float8_e4m3fn` type) can be loaded directly. Also, in `flux_minimal_inference.py`, it is possible to load it by specifying `fp8 (float8_e4m3fn)` in `--flux_dtype`. -In `flux_merge_lora.py`, you can now specify the precision at save time with `fp8` (see `--help` for details). Also, if you do not specify the merge model, only the model type conversion will be performed. +In `flux_merge_lora.py`, you can now specify `fp8` for the save precision (see `--help` for details). Also, if you do not specify the merge model, only the dtype conversion will be performed. Aug 27, 2024: From 6c0e8a5a1740dbd50a0a45ec1f08983877605cd7 Mon Sep 17 00:00:00 2001 From: Akegarasu Date: Thu, 29 Aug 2024 14:50:29 +0800 Subject: [PATCH 111/348] make guidance_scale keep float in args --- flux_train_network.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flux_train_network.py b/flux_train_network.py index 4a63c2de4..354a8c6f3 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -324,7 +324,8 @@ def get_noise_pred_and_target( img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) # get guidance - guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device) + # ensure guidance_scale in args is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) # ensure the hidden state will require grad if args.gradient_checkpointing: From a0cfb0894c4be4ea27412e4c12ed13f68b57094b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 29 Aug 2024 21:20:33 +0900 Subject: [PATCH 112/348] Cleaned up README --- README.md | 281 +++++++++++++++++++++++++++--------------------------- 1 file changed, 143 insertions(+), 138 deletions(-) diff --git a/README.md b/README.md index 7b1d9cc6c..a73eead0b 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ This repository contains training, generation and utility scripts for Stable Diffusion. -## FLUX.1 LoRA training (WIP) +## FLUX.1 training (WIP) This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training. @@ -9,127 +9,24 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` -Aug 27, 2024 (update 2): -In FLUX.1 LoRA training, when `--fp8_base` is specified, the FLUX.1 model file with fp8 (`float8_e4m3fn` type) can be loaded directly. Also, in `flux_minimal_inference.py`, it is possible to load it by specifying `fp8 (float8_e4m3fn)` in `--flux_dtype`. +- [FLUX.1 LoRA training](#flux1-lora-training) + - [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training) + - [Inference for FLUX.1 LoRA model](#inference-for-flux1-lora-model) + - [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) +- [FLUX.1 fine-tuning](#flux1-fine-tuning) + - [Key Features for FLUX.1 fine-tuning](#key-features-for-flux1-fine-tuning) +- [Extract LoRA from FLUX.1 Models](#extract-lora-from-flux1-models) +- [Convert FLUX LoRA](#convert-flux-lora) +- [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) +- [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) -In `flux_merge_lora.py`, you can now specify `fp8` for the save precision (see `--help` for details). Also, if you do not specify the merge model, only the dtype conversion will be performed. - -Aug 27, 2024: - -- FLUX.1 LoRA training now supports CLIP-L LoRA. Please remove `--network_train_unet_only`. T5XXL is not trained. The output of T5XXL is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required. The trained LoRA can be used with ComfyUI. - - `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA. -- `--sigmoid_scale` is now effective even when `--timestep_sampling shift` is specified. Normally, leave it at 1.0. Larger values make the value before shift closer to a uniform distribution. - -- __Experimental__ `--fp8_base_unet` option is added to `flux_train_network.py`. Flux can be trained with fp8, and CLIP-L can be trained with bf16/fp16. When specifying this option, the `--fp8_base` option is automatically enabled. - -Aug 25, 2024: -Added `shift` option to `--timestep_sampling` in FLUX.1 fine-tuning and LoRA training. Shifts timesteps according to the value of `--discrete_flow_shift` (shifts the value of sigmoid of normal distribution random number). It may be good to start with a value of 3.1582 (=e^1.15) for `--discrete_flow_shift`. -Sample command: `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0` - -Aug 24, 2024 (update 2): - -__Experimental__ Added an option to split the projection layers of q/k/v/txt in the attention and apply LoRA to each of them in FLUX.1 LoRA training. Specify `"split_qkv=True"` in network_args like `--network_args "split_qkv=True"` (`train_blocks` is also available). - -The number of parameters may increase slightly, so the expressiveness may increase, but the training time may be longer. No detailed verification has been done. - -This implementation is experimental, so it may be deprecated or changed in the future. - -The .safetensors file of the trained model is compatible with the normal LoRA model of sd-scripts, so it should be usable in inference environments such as ComfyUI as it is. Also, converting it to AI-toolkit (Diffusers) format with `convert_flux_lora.py` will reduce the size. It should be no problem to convert it if you use it in the inference environment. - -Technical details: In the implementation of Black Forest Labs' model, the projection layers of q/k/v (and txt in single blocks) are concatenated into one. If LoRA is added there as it is, the LoRA module is only one, and the dimension is large. In contrast, in the implementation of Diffusers, the projection layers of q/k/v/txt are separated. Therefore, the LoRA module is applied to q/k/v/txt separately, and the dimension is smaller. This option is for training LoRA similar to the latter. - -The compatibility of the saved model (state dict) is ensured by concatenating the weights of multiple LoRAs. However, since there are zero weights in some parts, the model size will be large. - -Aug 24, 2024: -Fixed an issue where the attention mask was not applied in single blocks when `--apply_t5_attn_mask` was specified. - -Aug 22, 2024 (update 2): -Fixed a bug that the embedding was zero-padded when `--apply_t5_attn_mask` option was applied. Also, the cache file for text encoder outputs now records whether the mask is applied or not. Please note that the cache file will be recreated when switching the `--apply_t5_attn_mask` option. - -Added a script to extract LoRA from the difference between the two models of FLUX.1. Use `networks/flux_extract_lora.py`. See `--help` for details. Normally, more than 50GB of memory is required, but specifying the `--mem_eff_safe_open` option significantly reduces memory usage. However, this option is a custom implementation, so unexpected problems may occur. Please always check if the model is loaded correctly. - -Aug 22, 2024: -Fixed a bug in multi-GPU training. It should work with fine-tuning and LoRA training. `--double_blocks_to_swap` and `--single_blocks_to_swap` cannot be used in multi-GPU training. - -`--disable_mmap_load_safetensors` option now works in `flux_train.py`. It speeds up model loading during training in WSL2. It is also effective in reducing memory usage when loading models during multi-GPU training. Please always check if the model is loaded correctly, as it uses a custom implementation of safetensors loading. - - -Aug 21, 2024 (update 3): -- There is a bug that `--full_bf16` option is enabled even if it is not specified in `flux_train.py`. The bug will be fixed sooner. __Please specify the `--full_bf16` option explicitly, especially when training with 24GB VRAM.__ -- Stochastic rounding is now implemented when `--fused_backward_pass` is specified. The implementation is -based on the code provided by 2kpr. Thank you so much! - - With this change, `--fused_backward_pass` is recommended over `--blockwise_fused_optimizers` when `--full_bf16` is specified. - - Please note that `--fused_backward_pass` is only supported with Adafactor. -- The sample command in [FLUX.1 fine-tuning](#flux1-fine-tuning) is updated to reflect these changes. -- Fixed `--single_blocks_to_swap` is not working in `flux_train.py`. - -Aug 21, 2024 (update 2): -Fixed an error in applying mask in Attention. The attention mask was float, but it should be bool. - -Added a script `convert_flux_lora.py` to convert LoRA between sd-scripts format (BFL-based) and AI-toolkit format (Diffusers-based). See `--help` for details. BFL-based LoRA has a large module, so converting it to Diffusers format may reduce temporary memory usage in the inference environment. Note that re-conversion will increase the size of LoRA. - - -Aug 21, 2024: -The specification of `--apply_t5_attn_mask` has been changed. Previously, the T5 output was zero-padded, but now, two steps are taken: "1. Apply mask when encoding T5" and "2. Apply mask in the attention of Double Block". Fine tuning, LoRA training, and inference in `flux_mini_inference.py` have been changed. - -Aug 20, 2024 (update 3): -__Experimental__ The multi-resolution training is now supported with caching latents to disk. - -The cache files now hold latents for multiple resolutions. Since the latents are appended to the current cache file, it is recommended to delete the cache file in advance (if not, the old latents is kept in .npz file). - -See [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) for details. - -Aug 20, 2024 (update 2): -`flux_merge_lora.py` now supports LoRA from AI-toolkit (Diffusers based keys). Specify `--diffusers` option to merge LoRA with Diffusers based keys. Thanks to exveria1015! - -Aug 20, 2024: -FLUX.1 supports multi-resolution inference, so training at multiple resolutions may be possible and the results may be improved (like 1024x1024, 768x768 and 512x512 ... you can use any resolution). - -The script seems to support multi-resolution even in the current version, ~~if `--cache_latents_to_disk` is not specified~~ -> `--cache_latents_to_disk` is now supported for multi-resolution training. Please try if you are interested. See [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) for details. - -We will support multi-resolution caching to disk in the near future. - -Aug 19, 2024: -In `flux_train.py`, the memory consumption during model saving is reduced when `--save_precision` is set to the same value as `--mixed_precision` (about 22GB). Please set the same value unless there is a reason. - -An experimental option `--mem_eff_save` is also added. When specified, it can further reduce memory consumption (about 22GB), but since it is a custom implementation, unexpected problems may occur. We do not recommend using it unless you are familiar with the code. - -Aug 18, 2024: -Memory-efficient training based on 2kpr's implementation is implemented in `flux_train.py`. Thanks to 2kpr! See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. - -Aug 17, 2024: -Added a script `flux_train.py` to train FLUX.1. The script is experimental and not an optimized version. It needs >28GB VRAM for training. - -Aug 16, 2024: - -Added a script `networks/flux_merge_lora.py` to merge LoRA into FLUX.1 checkpoint. See [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) for details. - -FLUX.1 schnell model based training is now supported (but not tested). If the name of the model file contains `schnell`, the model is treated as a schnell model. - -Added `--t5xxl_max_token_length` option to specify the maximum token length of T5XXL. The default is 512 in dev and 256 in schnell. - -Previously, when `--max_token_length` was specified, that value was used, and 512 was used when omitted (default). Therefore, there is no impact if `--max_token_length` was not specified. If `--max_token_length` was specified, please specify `--t5xxl_max_token_length` instead. `--max_token_length` is ignored during FLUX.1 training. - -Aug 14, 2024: Sample image generation during training is now supported. Specify options such as `--sample_prompts` and `--sample_every_n_epochs`. It will be very slow when `--split_mode` is specified. - -Aug 13, 2024: - -__Experimental__ A network argument `train_blocks` is added to `lora_flux`. This is to select the target blocks of LoRA from FLUX double blocks and single blocks. Specify like `--network_args "train_blocks=single"`. `all` trains both double blocks and single blocks, `double` trains only double blocks, and `single` trains only single blocks. The default (omission) is `all`. - -This argument is available even if `--split_mode` is not specified. - -__Experimental__ `--split_mode` option is added to `flux_train_network.py`. This splits FLUX into double blocks and single blocks for training. By enabling gradients only for the single blocks part, memory usage is reduced. When this option is specified, you need to specify `"train_blocks=single"` in the network arguments. - -This option enables training with 12GB VRAM GPUs, but the training speed is 2-3 times slower than the default. - -Aug 11, 2024: Fix `--apply_t5_attn_mask` option to work. Please remove and re-generate the latents cache file if you have used the option before. - -Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI. +### FLUX.1 LoRA training +We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. -### FLUX.1 LoRA training +FLUX.1 model, CLIP-L, and T5XXL models are recommended to be in bf16/fp16 format. If you specify `--fp8_base`, you can use fp8 models for FLUX.1. The fp8 model is only compatible with `float8_e4m3fn` format. -We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below, settings are based on [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit). It will work with 24GB VRAM GPUs. +Sample command is below. It will work with 24GB VRAM GPUs. ``` accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py @@ -137,45 +34,106 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 ---network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base +--cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml ---output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigmoid ---model_prediction_type raw --guidance_scale 1.0 --loss_type l2 +--output_dir path/to/output/dir --output_name flux-lora-name +--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 ``` (The command is multi-line for readability. Please combine it into one line.) The training can be done with 16GB VRAM GPUs with Adafactor optimizer. Please use settings like below: ``` ---optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" +--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 ``` The training can be done with 12GB VRAM GPUs with Adafactor optimizer, `--split_mode` and `train_blocks=single` options. Please use settings like below: ``` ---optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --split_mode --network_args "train_blocks=single" +--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --split_mode --network_args "train_blocks=single" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 ``` -LoRAs for Text Encoders are not tested yet. +We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. + +The trained LoRA model can be used with ComfyUI. + +#### Key Options for FLUX.1 LoRA training -We have added some new options (Aug 10, 2024): `--time_sampling`, `--sigmoid_scale`, `--model_prediction_type` and `--discrete_flow_shift`. The options are as follows: +There are many unknown points in FLUX.1 training, so some settings can be specified by arguments. Here are the arguments. The arguments and sample settings are still experimental and may change in the future. Feedback on the settings is welcome. -- `--timestep_sampling` is the method to sample timesteps (0-1): `sigma` (sigma-based, same as SD3), `uniform` (uniform random), or `sigmoid` (sigmoid of random normal, same as x-flux). +- `--timestep_sampling` is the method to sample timesteps (0-1): + - `sigma`: sigma-based, same as SD3 + - `uniform`: uniform random + - `sigmoid`: sigmoid of random normal, same as x-flux, AI-toolkit etc. + - `shift`: shifts the value of sigmoid of normal distribution random number - `--sigmoid_scale` is the scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). The default is 1.0. Larger values will make the sampling more uniform. -- `--model_prediction_type` is how to interpret and process the model prediction: `raw` (use as is, same as x-flux), `additive` (add to noisy input), `sigma_scaled` (apply sigma scaling, same as SD3). + - This option is effective even when`--timestep_sampling shift` is specified. + - Normally, leave it at 1.0. Larger values make the value before shift closer to a uniform distribution. +- `--model_prediction_type` is how to interpret and process the model prediction: + - `raw`: use as is, same as x-flux + - `additive`: add to noisy input + - `sigma_scaled`: apply sigma scaling, same as SD3 - `--discrete_flow_shift` is the discrete flow shift for the Euler Discrete Scheduler, default is 3.0 (same as SD3). -`--loss_type` may be useful for FLUX.1 training. The default is `l2`. +The existing `--loss_type` option may be useful for FLUX.1 training. The default is `l2`. -In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted. +~~In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted. ~~ -additional note (Aug 11): A quick check shows that the settings in [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) seems to be equivalent to `--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type). This seems to be a good starting point. Thanks to Ostris for the great work! +In our experiments, `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0` with `--loss_type l2` seems to work better than other settings. + +The settings in [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) seems to be equivalent to `--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type). Other settings may work better, so please try different settings. -We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. +Other options are described below. -The trained LoRA model can be used with ComfyUI. +#### Distribution of timesteps + +`--timestep_sampling` and `--sigmoid_scale`, `--discrete_flow_shift` adjust the distribution of timesteps. The distribution is shown in the figures below. + +The effect of `--discrete_flow_shift` with `--timestep_sampling shift` (when `--sigmoid_scale` is not specified, the default is 1.0): + +The difference between `--timestep_sampling uniform` and `--timestep_sampling sigma`: + +The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--timestep_sampling sigmoid` is specified, `--discrete_flow_shift` is ignored): + +#### Key Features for FLUX.1 LoRA training + +1. CLIP-L LoRA Support: + - FLUX.1 LoRA training now supports CLIP-L LoRA. + - Remove `--network_train_unet_only` from your command. + - T5XXL is not trained. Its output is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required. + - The trained LoRA can be used with ComfyUI. + - Note: `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA. + +2. Experimental FP8/FP16 mixed training: + - `--fp8_base_unet` enables training with fp8 for FLUX and bf16/fp16 for CLIP-L. + - FLUX can be trained with fp8, and CLIP-L can be trained with bf16/fp16. + - When specifying this option, the `--fp8_base` option is automatically enabled. + +3. Split Q/K/V Projection Layers (Experimental): + - Added an option to split the projection layers of q/k/v/txt in the attention and apply LoRA to each of them. + - Specify `"split_qkv=True"` in network_args like `--network_args "split_qkv=True"` (`train_blocks` is also available). + - May increase expressiveness but also training time. + - The trained model is compatible with normal LoRA models in sd-scripts and can be used in environments like ComfyUI. + - Converting to AI-toolkit (Diffusers) format with `convert_flux_lora.py` will reduce the size. + +4. T5 Attention Mask Application: + - T5 attention mask is applied when `--apply_t5_attn_mask` is specified. + - Now applies mask when encoding T5 and in the attention of Double and Single Blocks + - Affects fine-tuning, LoRA training, and inference in `flux_minimal_inference.py`. + +5. Multi-resolution Training Support: + - FLUX.1 now supports multi-resolution training, even with caching latents to disk. + + +Technical details of Q/K/V split: + +In the implementation of Black Forest Labs' model, the projection layers of q/k/v (and txt in single blocks) are concatenated into one. If LoRA is added there as it is, the LoRA module is only one, and the dimension is large. In contrast, in the implementation of Diffusers, the projection layers of q/k/v/txt are separated. Therefore, the LoRA module is applied to q/k/v/txt separately, and the dimension is smaller. This option is for training LoRA similar to the latter. + +The compatibility of the saved model (state dict) is ensured by concatenating the weights of multiple LoRAs. However, since there are zero weights in some parts, the model size will be large. + +### Inference for FLUX.1 with LoRA model The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options. @@ -185,6 +143,8 @@ python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safete ### FLUX.1 fine-tuning +The memory-efficient training with block swap is based on 2kpr's implementation. Thanks to 2kpr! + Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GPUs, and 64GB main memory is recommended. ``` @@ -195,15 +155,13 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t --dataset_config dataset_1024_bs1.toml --output_dir path/to/output/dir --output_name output-name --learning_rate 5e-5 --max_train_epochs 4 --sdpa --highvram --cache_text_encoder_outputs_to_disk --cache_latents_to_disk --save_every_n_epochs 1 --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" ---timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 +--lr_scheduler constant_with_warmup --max_grad_norm 0.0 +--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 --fused_backward_pass --double_blocks_to_swap 6 --cpu_offload_checkpointing --full_bf16 ``` +(The command is multi-line for readability. Please combine it into one line.) -(Combine the command into one line.) - -Sample image generation during training is not tested yet. - -Options are almost the same as LoRA training. The difference is `--blockwise_fused_optimizers`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available. +Options are almost the same as LoRA training. The difference is `--full_bf16`, `--blockwise_fused_optimizers`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available. `--full_bf16` enables the training with bf16 (weights and gradients). @@ -223,6 +181,53 @@ Swap 6 double blocks and use cpu offload checkpointing may be a good starting po The learning rate and the number of epochs are not optimized yet. Please adjust them according to the training results. +#### Key Features for FLUX.1 fine-tuning + +1. Sample Image Generation: + - Sample image generation during training is now supported. + - The prompts are cached and used for generation if `--cache_latents` is specified. So changing the prompts during training will not affect the generated images. + - Specify options such as `--sample_prompts` and `--sample_every_n_epochs`. + - Note: It will be very slow when `--split_mode` is specified. + +2. Experimental Memory-Efficient Saving: + - `--mem_eff_save` option can further reduce memory consumption during model saving (about 22GB). + - This is a custom implementation and may cause unexpected issues. Use with caution. + +3. T5XXL Token Length Control: + - Added `--t5xxl_max_token_length` option to specify the maximum token length of T5XXL. + - Default is 512 in dev and 256 in schnell models. + +4. Multi-GPU Training Support: + - Note: `--double_blocks_to_swap` and `--single_blocks_to_swap` cannot be used in multi-GPU training. + +5. Disable mmap Load for Safetensors: + - `--disable_mmap_load_safetensors` option now works in `flux_train.py`. + - Speeds up model loading during training in WSL2. + - Effective in reducing memory usage when loading models during multi-GPU training. + + +### Extract LoRA from FLUX.1 Models + +Script: `networks/flux_extract_lora.py` + +Extracts LoRA from the difference between two FLUX.1 models. + +Offers memory-efficient option with `--mem_eff_safe_open`. + +CLIP-L LoRA is not supported. + +### Convert FLUX LoRA + +Script: `convert_flux_lora.py` + +Converts LoRA between sd-scripts format (BFL-based) and AI-toolkit format (Diffusers-based). + +If you use LoRA in the inference environment, converting it to AI-toolkit format may reduce temporary memory usage. + +Note that re-conversion will increase the size of LoRA. + +CLIP-L LoRA is not supported. + ### Merge LoRA to FLUX.1 checkpoint `networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint. __The script is experimental.__ From daa6ad516581872aa6acaa15c0d24aad4f998838 Mon Sep 17 00:00:00 2001 From: "Kohya S." <52813779+kohya-ss@users.noreply.github.com> Date: Thu, 29 Aug 2024 21:25:30 +0900 Subject: [PATCH 113/348] Update README.md --- README.md | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index a73eead0b..6e2ae3376 100644 --- a/README.md +++ b/README.md @@ -77,9 +77,9 @@ There are many unknown points in FLUX.1 training, so some settings can be specif The existing `--loss_type` option may be useful for FLUX.1 training. The default is `l2`. -~~In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted. ~~ +~~In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted.~~ -In our experiments, `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0` with `--loss_type l2` seems to work better than other settings. +In our experiments, `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type) seems to work better. The settings in [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) seems to be equivalent to `--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type). @@ -92,10 +92,13 @@ Other options are described below. `--timestep_sampling` and `--sigmoid_scale`, `--discrete_flow_shift` adjust the distribution of timesteps. The distribution is shown in the figures below. The effect of `--discrete_flow_shift` with `--timestep_sampling shift` (when `--sigmoid_scale` is not specified, the default is 1.0): +![Figure_2](https://github.com/user-attachments/assets/d9de42f9-f17d-40da-b88d-d964402569c6) -The difference between `--timestep_sampling uniform` and `--timestep_sampling sigma`: +The difference between `--timestep_sampling sigmoid` and `--timestep_sampling uniform` (when `--timestep_sampling sigmoid` or `uniform` is specified, `--discrete_flow_shift` is ignored): +![Figure_3](https://github.com/user-attachments/assets/27029009-1f5d-4dc0-bb24-13d02ac4fdad) The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--timestep_sampling sigmoid` is specified, `--discrete_flow_shift` is ignored): +![Figure_4](https://github.com/user-attachments/assets/08a2267c-e47e-48b7-826e-f9a080787cdc) #### Key Features for FLUX.1 LoRA training From 8ecf0fc4bfd1b03cfc6fd4055af0b3363f5d1f38 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 29 Aug 2024 22:10:57 +0900 Subject: [PATCH 114/348] Refactor code to ensure args.guidance_scale is always a float #1525 --- flux_train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flux_train.py b/flux_train.py index 410728d44..32a36f036 100644 --- a/flux_train.py +++ b/flux_train.py @@ -688,8 +688,8 @@ def optimizer_hook(parameter: torch.Tensor): packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) - # get guidance - guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device) + # get guidance: ensure args.guidance_scale is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) # call model l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds From 8fdfd8c857a88aaa78ac9c2488432ef8115982f2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 29 Aug 2024 22:26:29 +0900 Subject: [PATCH 115/348] Update safetensors to version 0.4.4 in requirements.txt #1524 --- README.md | 7 +++++++ requirements.txt | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 6e2ae3376..30264e738 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,13 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +### Recent Updates + +Aug 29, 2024: +Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. `requirements.txt` is updated. + +### Contents + - [FLUX.1 LoRA training](#flux1-lora-training) - [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training) - [Inference for FLUX.1 LoRA model](#inference-for-flux1-lora-model) diff --git a/requirements.txt b/requirements.txt index 4ee19b3ee..4c1bc3922 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,7 @@ bitsandbytes==0.43.3 prodigyopt==1.0 lion-pytorch==0.0.6 tensorboard -safetensors==0.4.2 +safetensors==0.4.4 # gradio==3.16.2 altair==4.2.2 easygui==0.98.3 From 34f2315047f8d5b89b7a8a6093bb56679bff13c3 Mon Sep 17 00:00:00 2001 From: Akegarasu Date: Thu, 29 Aug 2024 22:33:37 +0800 Subject: [PATCH 116/348] fix: text_encoder_conds referenced before assignment --- train_network.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_network.py b/train_network.py index 048c7e7bd..628c421cb 100644 --- a/train_network.py +++ b/train_network.py @@ -1081,12 +1081,12 @@ def remove_model(old_ckpt_name): # print(f"set multiplier: {multipliers}") accelerator.unwrap_model(network).set_multiplier(multipliers) + text_encoder_conds = [] text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs if ( - text_encoder_conds is None - or len(text_encoder_conds) == 0 + len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder ): From 35882f8d5bbd076a97622cf6193c988621481803 Mon Sep 17 00:00:00 2001 From: Akegarasu Date: Thu, 29 Aug 2024 23:03:43 +0800 Subject: [PATCH 117/348] fix --- train_network.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/train_network.py b/train_network.py index 628c421cb..4204bce34 100644 --- a/train_network.py +++ b/train_network.py @@ -1112,10 +1112,14 @@ def remove_model(old_ckpt_name): if args.full_fp16: encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds] - # if encoded_text_encoder_conds is not None, update cached text_encoder_conds - for i in range(len(encoded_text_encoder_conds)): - if encoded_text_encoder_conds[i] is not None: - text_encoder_conds[i] = encoded_text_encoder_conds[i] + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] # sample noise, call unet, get target noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target( From 25c9040f4fbbcbddc0297895369337846152fea4 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Sat, 31 Aug 2024 03:05:19 +0800 Subject: [PATCH 118/348] Update flux_train_utils.py --- library/flux_train_utils.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index a8e94ac00..735bcced7 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -371,7 +371,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): def get_noisy_model_input_and_timesteps( args, noise_scheduler, latents, noise, device, dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - bsz = latents.shape[0] + bsz, _, H, W = latents.shape sigmas = None if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": @@ -392,6 +392,16 @@ def get_noisy_model_input_and_timesteps( timesteps = logits_norm.sigmoid() timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) + t = timesteps.view(-1, 1, 1, 1) + timesteps = timesteps * 1000.0 + noisy_model_input = (1 - t) * latents + t * noise + elif args.timestep_sampling == "flux_shift": + logits_norm = torch.randn(bsz, device=device) + logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling + timesteps = logits_norm.sigmoid() + mu=get_lin_function(y1=0.5, y2=1.15)((H//2) * (W//2)) + timesteps = time_shift(mu, 1.0, timesteps) + t = timesteps.view(-1, 1, 1, 1) timesteps = timesteps * 1000.0 noisy_model_input = (1 - t) * latents + t * noise @@ -571,7 +581,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--timestep_sampling", - choices=["sigma", "uniform", "sigmoid", "shift"], + choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"], default="sigma", help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid." " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト。", From ef510b3cb94427d72df681389e1214251813b1a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= <865105819@qq.com> Date: Sun, 1 Sep 2024 17:41:01 +0800 Subject: [PATCH 119/348] Sd3 freeze x_block (#1417) * Update sd3_train.py * add freeze block lr * Update train_util.py * update --- library/train_util.py | 21 +++++++++++++++++++++ sd3_train.py | 9 ++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 989758ad5..74aae0a79 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3246,6 +3246,12 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser): default=None, help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", ) + parser.add_argument( + "--num_last_block_to_freeze", + type=int, + default=None, + help="num_last_block_to_freeze", + ) def add_optimizer_arguments(parser: argparse.ArgumentParser): @@ -5758,6 +5764,21 @@ def sample_image_inference( pass +def freeze_blocks(model, num_last_block_to_freeze, block_name="x_block"): + + filtered_blocks = [(name, param) for name, param in model.named_parameters() if block_name in name] + print(f"filtered_blocks: {len(filtered_blocks)}") + + num_blocks_to_freeze = min(len(filtered_blocks), num_last_block_to_freeze) + + print(f"freeze_blocks: {num_blocks_to_freeze}") + + start_freezing_from = max(0, len(filtered_blocks) - num_blocks_to_freeze) + + for i in range(start_freezing_from, len(filtered_blocks)): + _, param = filtered_blocks[i] + param.requires_grad = False + # endregion diff --git a/sd3_train.py b/sd3_train.py index 3b6c8a118..ce9500b0b 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -368,12 +368,19 @@ def train(args): vae.eval() vae.to(accelerator.device, dtype=vae_dtype) + mmdit.requires_grad_(train_mmdit) + if not train_mmdit: + mmdit.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared + + if args.num_last_block_to_freeze: + train_util.freeze_blocks(mmdit,num_last_block_to_freeze=args.num_last_block_to_freeze) + training_models = [] params_to_optimize = [] # if train_unet: training_models.append(mmdit) # if block_lrs is None: - params_to_optimize.append({"params": list(mmdit.parameters()), "lr": args.learning_rate}) + params_to_optimize.append({"params": list(filter(lambda p: p.requires_grad, mmdit.parameters())), "lr": args.learning_rate}) # else: # params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs)) From 92e7600cc2fea604321004f260e7db76c764f388 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 1 Sep 2024 18:57:07 +0900 Subject: [PATCH 120/348] Move freeze_blocks to sd3_train because it's only for sd3 --- README.md | 3 +++ library/train_util.py | 21 --------------------- sd3_train.py | 22 ++++++++++++++++++++-- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 30264e738..d96367194 100644 --- a/README.md +++ b/README.md @@ -309,6 +309,9 @@ resolution = [512, 512] SD3 training is done with `sd3_train.py`. +__Sep 1, 2024__: +- `--num_last_block_to_freeze` is added to `sd3_train.py`. This option is to freeze the last n blocks of the MMDiT. See [#1417](https://github.com/kohya-ss/sd-scripts/pull/1417) for details. Thanks to sdbds! + __Jul 27, 2024__: - Latents and text encoder outputs caching mechanism is refactored significantly. - Existing cache files for SD3 need to be recreated. Please delete the previous cache files. diff --git a/library/train_util.py b/library/train_util.py index 74aae0a79..989758ad5 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3246,12 +3246,6 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser): default=None, help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", ) - parser.add_argument( - "--num_last_block_to_freeze", - type=int, - default=None, - help="num_last_block_to_freeze", - ) def add_optimizer_arguments(parser: argparse.ArgumentParser): @@ -5764,21 +5758,6 @@ def sample_image_inference( pass -def freeze_blocks(model, num_last_block_to_freeze, block_name="x_block"): - - filtered_blocks = [(name, param) for name, param in model.named_parameters() if block_name in name] - print(f"filtered_blocks: {len(filtered_blocks)}") - - num_blocks_to_freeze = min(len(filtered_blocks), num_last_block_to_freeze) - - print(f"freeze_blocks: {num_blocks_to_freeze}") - - start_freezing_from = max(0, len(filtered_blocks) - num_blocks_to_freeze) - - for i in range(start_freezing_from, len(filtered_blocks)): - _, param = filtered_blocks[i] - param.requires_grad = False - # endregion diff --git a/sd3_train.py b/sd3_train.py index ce9500b0b..87011b215 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -373,7 +373,20 @@ def train(args): mmdit.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared if args.num_last_block_to_freeze: - train_util.freeze_blocks(mmdit,num_last_block_to_freeze=args.num_last_block_to_freeze) + # freeze last n blocks of MM-DIT + block_name = "x_block" + filtered_blocks = [(name, param) for name, param in mmdit.named_parameters() if block_name in name] + accelerator.print(f"filtered_blocks: {len(filtered_blocks)}") + + num_blocks_to_freeze = min(len(filtered_blocks), args.num_last_block_to_freeze) + + accelerator.print(f"freeze_blocks: {num_blocks_to_freeze}") + + start_freezing_from = max(0, len(filtered_blocks) - num_blocks_to_freeze) + + for i in range(start_freezing_from, len(filtered_blocks)): + _, param = filtered_blocks[i] + param.requires_grad = False training_models = [] params_to_optimize = [] @@ -1033,12 +1046,17 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", ) - parser.add_argument( "--skip_latents_validity_check", action="store_true", help="skip latents validity check / latentsの正当性チェックをスキップする", ) + parser.add_argument( + "--num_last_block_to_freeze", + type=int, + default=None, + help="freeze last n blocks of MM-DIT / MM-DITの最後のnブロックを凍結する", + ) return parser From 4f6d915d15262447b1049a78a55678b2825784a3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 1 Sep 2024 19:12:29 +0900 Subject: [PATCH 121/348] update help and README --- README.md | 5 +++++ library/flux_train_utils.py | 8 ++++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index d96367194..331951ef4 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,10 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 1, 2024: +- `--timestamp_sampling` has `flux_shift` option. Thanks to sdbds! + - This is the same shift as FLUX.1 dev inference, adjusting the timestep sampling depending on the resolution. `--discrete_flow_shift` is ignored when `flux_shift` is specified. It is not verified which is better, `shift` or `flux_shift`. + Aug 29, 2024: Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. `requirements.txt` is updated. @@ -73,6 +77,7 @@ There are many unknown points in FLUX.1 training, so some settings can be specif - `uniform`: uniform random - `sigmoid`: sigmoid of random normal, same as x-flux, AI-toolkit etc. - `shift`: shifts the value of sigmoid of normal distribution random number + - `flux_shift`: shifts the value of sigmoid of normal distribution random number, depending on the resolution (same as FLUX.1 dev inference). `--discrete_flow_shift` is ignored when `flux_shift` is specified. - `--sigmoid_scale` is the scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). The default is 1.0. Larger values will make the sampling more uniform. - This option is effective even when`--timestep_sampling shift` is specified. - Normally, leave it at 1.0. Larger values make the value before shift closer to a uniform distribution. diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 735bcced7..9dad4baa2 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -371,7 +371,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): def get_noisy_model_input_and_timesteps( args, noise_scheduler, latents, noise, device, dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - bsz, _, H, W = latents.shape + bsz, _, h, w = latents.shape sigmas = None if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": @@ -399,7 +399,7 @@ def get_noisy_model_input_and_timesteps( logits_norm = torch.randn(bsz, device=device) logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling timesteps = logits_norm.sigmoid() - mu=get_lin_function(y1=0.5, y2=1.15)((H//2) * (W//2)) + mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) timesteps = time_shift(mu, 1.0, timesteps) t = timesteps.view(-1, 1, 1, 1) @@ -583,8 +583,8 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): "--timestep_sampling", choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"], default="sigma", - help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid." - " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト。", + help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting." + " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。", ) parser.add_argument( "--sigmoid_scale", From 6abacf04da756808ffca567f6660445ecdf478bd Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 2 Sep 2024 13:05:26 +0900 Subject: [PATCH 122/348] update README --- README.md | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 331951ef4..5dd916aa0 100644 --- a/README.md +++ b/README.md @@ -184,7 +184,7 @@ Options are almost the same as LoRA training. The difference is `--full_bf16`, ` `--blockwise_fused_optimizers` enables the fusing of the optimizer step into the backward pass for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--blockwise_fused_optimizers` cannot be used with `--fused_backward_pass`. Stochastic rounding is not supported for now. -`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. `--double_blocks_to_swap` can be specified with `--single_blocks_to_swap`. The recommended maximum number of blocks to swap is 9 for double blocks and 18 for single blocks. +`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. `--double_blocks_to_swap` can be specified with `--single_blocks_to_swap`. The recommended maximum number of blocks to swap is 9 for double blocks and 18 for single blocks. Please see the next chapter for details. `--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage. @@ -198,24 +198,32 @@ The learning rate and the number of epochs are not optimized yet. Please adjust #### Key Features for FLUX.1 fine-tuning -1. Sample Image Generation: +1. Technical details of double/single block swap: + - Reduce memory usage by transferring double and single blocks of FLUX.1 from GPU to CPU when they are not needed. + - During forward pass, the weights of the blocks that have finished calculation are transferred to CPU, and the weights of the blocks to be calculated are transferred to GPU. + - The same is true for the backward pass, but the order is reversed. The gradients remain on the GPU. + - Since the transfer between CPU and GPU takes time, the training will be slower. + - `--double_blocks_to_swap` and `--single_blocks_to_swap` specify the number of blocks to swap. For example, `--double_blocks_to_swap 6` swaps 6 blocks at each step of training, but the remaining 13 blocks are always on the GPU. + - About 640MB of memory can be saved per double block, and about 320MB of memory can be saved per single block. + +2. Sample Image Generation: - Sample image generation during training is now supported. - The prompts are cached and used for generation if `--cache_latents` is specified. So changing the prompts during training will not affect the generated images. - Specify options such as `--sample_prompts` and `--sample_every_n_epochs`. - Note: It will be very slow when `--split_mode` is specified. -2. Experimental Memory-Efficient Saving: +3. Experimental Memory-Efficient Saving: - `--mem_eff_save` option can further reduce memory consumption during model saving (about 22GB). - This is a custom implementation and may cause unexpected issues. Use with caution. -3. T5XXL Token Length Control: +4. T5XXL Token Length Control: - Added `--t5xxl_max_token_length` option to specify the maximum token length of T5XXL. - Default is 512 in dev and 256 in schnell models. -4. Multi-GPU Training Support: +5. Multi-GPU Training Support: - Note: `--double_blocks_to_swap` and `--single_blocks_to_swap` cannot be used in multi-GPU training. -5. Disable mmap Load for Safetensors: +6. Disable mmap Load for Safetensors: - `--disable_mmap_load_safetensors` option now works in `flux_train.py`. - Speeds up model loading during training in WSL2. - Effective in reducing memory usage when loading models during multi-GPU training. From b65ae9b439e4324359014d6d720aa01def3a19fc Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 4 Sep 2024 21:33:17 +0900 Subject: [PATCH 123/348] T5XXL LoRA training, fp8 T5XXL support --- README.md | 45 +++++++++++---- flux_train_network.py | 112 +++++++++++++++++++++++++++++------- library/flux_train_utils.py | 23 ++++++-- library/flux_utils.py | 9 ++- library/strategy_flux.py | 13 ++++- networks/lora_flux.py | 39 ++++++++++--- train_network.py | 48 ++++++++++------ 7 files changed, 222 insertions(+), 67 deletions(-) diff --git a/README.md b/README.md index 5dd916aa0..840655705 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,11 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 4, 2024: +- T5XXL LoRA is supported in LoRA training. Remove `--network_train_unet_only` and add `train_t5xxl=True` to `--network_args`. CLIP-L is also trained at the same time (T5XXL only cannot be trained). The trained model can be used with ComfyUI. +- In LoRA training, when `--fp8_base` is specified, you can specify `t5xxl_fp8_e4m3fn.safetensors` as the T5XXL weights. However, it is recommended to use fp16 weights for caching. +- Fixed an issue where the training CLIP-L LoRA was not used in sample image generation during LoRA training. + Sep 1, 2024: - `--timestamp_sampling` has `flux_shift` option. Thanks to sdbds! - This is the same shift as FLUX.1 dev inference, adjusting the timestep sampling depending on the resolution. `--discrete_flow_shift` is ignored when `flux_shift` is specified. It is not verified which is better, `shift` or `flux_shift`. @@ -41,8 +46,8 @@ Sample command is below. It will work with 24GB VRAM GPUs. ``` accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py ---pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors ---ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers +--pretrained_model_name_or_path flux1-dev.safetensors --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors +--ae ae.safetensors --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base @@ -72,6 +77,11 @@ The trained LoRA model can be used with ComfyUI. There are many unknown points in FLUX.1 training, so some settings can be specified by arguments. Here are the arguments. The arguments and sample settings are still experimental and may change in the future. Feedback on the settings is welcome. +- `--pretrained_model_name_or_path` is the path to the pretrained model (FLUX.1). bf16 (original BFL model) is recommended (`flux1-dev.safetensors` or `flux1-dev.sft`). If you specify `--fp8_base`, you can use fp8 models for FLUX.1. The fp8 model is only compatible with `float8_e4m3fn` format. +- `--clip_l` is the path to the CLIP-L model. +- `--t5xxl` is the path to the T5XXL model. If you specify `--fp8_base`, you can use fp8 (float8_e4m3fn) models for T5XXL. However, it is recommended to use fp16 models for caching. +- `--ae` is the path to the autoencoder model (`ae.safetensors` or `ae.sft`). + - `--timestep_sampling` is the method to sample timesteps (0-1): - `sigma`: sigma-based, same as SD3 - `uniform`: uniform random @@ -114,16 +124,29 @@ The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--times #### Key Features for FLUX.1 LoRA training -1. CLIP-L LoRA Support: - - FLUX.1 LoRA training now supports CLIP-L LoRA. +1. CLIP-L and T5XXL LoRA Support: + - FLUX.1 LoRA training now supports CLIP-L and T5XXL LoRA training. - Remove `--network_train_unet_only` from your command. - - T5XXL is not trained. Its output is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required. + - Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L is also trained at the same time. + - T5XXL output can be cached for CLIP-L LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available. - The trained LoRA can be used with ComfyUI. - - Note: `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA. + - Note: `flux_extract_lora.py`, `convert_flux_lora.py`and `merge_flux_lora.py` do not support CLIP-L and T5XXL LoRA yet. + + | trained LoRA|option|network_args|cache_text_encoder_outputs (*1)| + |---|---|---|---| + |FLUX.1|`--network_train_unet_only`|-|o| + |FLUX.1 + CLIP-L|-|-|o (*2)| + |FLUX.1 + CLIP-L + T5XXL|-|`train_t5xxl=True`|-| + |CLIP-L (*3)|`--network_train_text_encoder_only`|-|o (*2)| + |CLIP-L + T5XXL (*3)|`--network_train_text_encoder_only`|`train_t5xxl=True`|-| + + - *1: `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available. + - *2: T5XXL output can be cached for CLIP-L LoRA training. + - *3: Not tested yet. 2. Experimental FP8/FP16 mixed training: - - `--fp8_base_unet` enables training with fp8 for FLUX and bf16/fp16 for CLIP-L. - - FLUX can be trained with fp8, and CLIP-L can be trained with bf16/fp16. + - `--fp8_base_unet` enables training with fp8 for FLUX and bf16/fp16 for CLIP-L/T5XXL. + - FLUX can be trained with fp8, and CLIP-L/T5XXL can be trained with bf16/fp16. - When specifying this option, the `--fp8_base` option is automatically enabled. 3. Split Q/K/V Projection Layers (Experimental): @@ -153,7 +176,7 @@ The compatibility of the saved model (state dict) is ensured by concatenating th The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options. ``` -python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0 +python flux_minimal_inference.py --ckpt flux1-dev.safetensors --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.safetensors --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0 ``` ### FLUX.1 fine-tuning @@ -164,7 +187,7 @@ Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GP ``` accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train.py ---pretrained_model_name_or_path flux1-dev.sft --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.sft +--pretrained_model_name_or_path flux1-dev.safetensors --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.safetensors --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --dataset_config dataset_1024_bs1.toml --output_dir path/to/output/dir --output_name output-name @@ -256,7 +279,7 @@ CLIP-L LoRA is not supported. `networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint. __The script is experimental.__ ``` -python networks/flux_merge_lora.py --flux_model flux1-dev.sft --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu +python networks/flux_merge_lora.py --flux_model flux1-dev.safetensors --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu ``` You can also merge multiple LoRA models into a FLUX.1 model. Specify multiple LoRA models in `--models`. Specify the same number of ratios in `--ratios`. diff --git a/flux_train_network.py b/flux_train_network.py index 354a8c6f3..2fc0f3234 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -43,13 +43,9 @@ def assert_extra_args(self, args, train_dataset_group): train_dataset_group.is_text_encoder_output_cacheable() ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" - # assert ( - # args.network_train_unet_only or not args.cache_text_encoder_outputs - # ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" - if not args.network_train_unet_only: - logger.info( - "network for CLIP-L only will be trained. T5XXL will not be trained / CLIP-Lのネットワークのみが学習されます。T5XXLは学習されません" - ) + # prepare CLIP-L/T5XXL training flags + self.train_clip_l = not args.network_train_unet_only + self.train_t5xxl = False # default is False even if args.network_train_unet_only is False if args.max_token_length is not None: logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") @@ -63,12 +59,10 @@ def load_target_model(self, args, weight_dtype, accelerator): # currently offload to cpu for some models name = self.get_flux_model_name(args) - # if we load to cpu, flux.to(fp8) takes a long time - if args.fp8_base: - loading_dtype = None # as is - else: - loading_dtype = weight_dtype + # if the file is fp8 and we are using fp8_base, we can load it as is (fp8) + loading_dtype = None if args.fp8_base else weight_dtype + # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future model = flux_utils.load_flow_model( name, args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors ) @@ -85,9 +79,21 @@ def load_target_model(self, args, weight_dtype, accelerator): clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) clip_l.eval() + # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8) + if args.fp8_base and not args.fp8_base_unet: + loading_dtype = None # as is + else: + loading_dtype = weight_dtype + # loading t5xxl to cpu takes a long time, so we should load to gpu in future - t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) t5xxl.eval() + if args.fp8_base and not args.fp8_base_unet: + # check dtype of model + if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") + elif t5xxl.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 T5XXL model") ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) @@ -154,25 +160,35 @@ def get_latents_caching_strategy(self, args): def get_text_encoding_strategy(self, args): return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask) + def post_process_network(self, args, accelerator, network, text_encoders, unet): + # check t5xxl is trained or not + self.train_t5xxl = network.train_t5xxl + + if self.train_t5xxl and args.cache_text_encoder_outputs: + raise ValueError( + "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません" + ) + def get_models_for_text_encoding(self, args, accelerator, text_encoders): if args.cache_text_encoder_outputs: - if self.is_train_text_encoder(args): + if self.train_clip_l and not self.train_t5xxl: return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached else: - return text_encoders # ignored + return None # no text encoders are needed for encoding because both are cached else: return text_encoders # both CLIP-L and T5XXL are needed for encoding def get_text_encoders_train_flags(self, args, text_encoders): - return [True, False] if self.is_train_text_encoder(args) else [False, False] + return [self.train_clip_l, self.train_t5xxl] def get_text_encoder_outputs_caching_strategy(self, args): if args.cache_text_encoder_outputs: + # if the text encoders is trained, we need tokenization, so is_partial is True return strategy_flux.FluxTextEncoderOutputsCachingStrategy( args.cache_text_encoder_outputs_to_disk, None, False, - is_partial=self.is_train_text_encoder(args), + is_partial=self.train_clip_l or self.train_t5xxl, apply_t5_attn_mask=args.apply_t5_attn_mask, ) else: @@ -193,8 +209,16 @@ def cache_text_encoder_outputs_if_needed( # When TE is not be trained, it will not be prepared so we need to use explicit autocast logger.info("move text encoders to gpu") - text_encoders[0].to(accelerator.device, dtype=weight_dtype) - text_encoders[1].to(accelerator.device, dtype=weight_dtype) + text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 + text_encoders[1].to(accelerator.device) + + if text_encoders[1].dtype == torch.float8_e4m3fn: + # if we load fp8 weights, the model is already fp8, so we use it as is + self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype) + else: + # otherwise, we need to convert it to target dtype + text_encoders[1].to(weight_dtype) + with accelerator.autocast(): dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process) @@ -235,7 +259,7 @@ def cache_text_encoder_outputs_if_needed( else: # Text Encoderから毎回出力を取得するので、GPUに乗せておく text_encoders[0].to(accelerator.device, dtype=weight_dtype) - text_encoders[1].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device) # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -255,9 +279,12 @@ def cache_text_encoder_outputs_if_needed( # return noise_pred def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux): + text_encoders = text_encoder # for compatibility + text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) + if not args.split_mode: flux_train_utils.sample_images( - accelerator, args, epoch, global_step, flux, ae, text_encoder, self.sample_prompts_te_outputs + accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs ) return @@ -281,7 +308,7 @@ def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_a wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device) clean_memory_on_device(accelerator.device) flux_train_utils.sample_images( - accelerator, args, epoch, global_step, wrapper, ae, text_encoder, self.sample_prompts_te_outputs + accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs ) clean_memory_on_device(accelerator.device) @@ -421,6 +448,47 @@ def update_metadata(self, metadata, args): def is_text_encoder_not_needed_for_training(self, args): return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + if index == 0: # CLIP-L + return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder) + else: # T5XXL + text_encoder.encoder.embed_tokens.requires_grad_(True) + + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + if index == 0: # CLIP-L + logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}") + text_encoder.to(te_weight_dtype) # fp8 + text_encoder.text_model.embeddings.to(dtype=weight_dtype) + else: # T5XXL + + def prepare_fp8(text_encoder, target_dtype): + def forward_hook(module): + def forward(hidden_states): + hidden_gelu = module.act(module.wi_0(hidden_states)) + hidden_linear = module.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = module.dropout(hidden_states) + + hidden_states = module.wo(hidden_states) + return hidden_states + + return forward + + for module in text_encoder.modules(): + if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["T5DenseGatedActDense"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype: + logger.info(f"T5XXL already prepared for fp8") + else: + logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks") + text_encoder.to(te_weight_dtype) # fp8 + prepare_fp8(text_encoder, weight_dtype) + def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 9dad4baa2..0b5d4d90e 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -85,7 +85,7 @@ def sample_images( if distributed_state.num_processes <= 1: # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. - with torch.no_grad(): + with torch.no_grad(), accelerator.autocast(): for prompt_dict in prompts: sample_image_inference( accelerator, @@ -187,14 +187,27 @@ def sample_image_inference( tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + text_encoder_conds = [] if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs: - te_outputs = sample_prompts_te_outputs[prompt] - else: + text_encoder_conds = sample_prompts_te_outputs[prompt] + print(f"Using cached text encoder outputs for prompt: {prompt}") + if text_encoders is not None: + print(f"Encoding prompt: {prompt}") tokens_and_masks = tokenize_strategy.tokenize(prompt) # strategy has apply_t5_attn_mask option - te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) + encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) + print([x.shape if x is not None else None for x in encoded_text_encoder_conds]) + + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] - l_pooled, t5_out, txt_ids, t5_attn_mask = te_outputs + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds # sample image weight_dtype = ae.dtype # TOFO give dtype as argument diff --git a/library/flux_utils.py b/library/flux_utils.py index 680836168..7b0a41a8a 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -171,7 +171,9 @@ def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.dev return clip -def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> T5EncoderModel: +def load_t5xxl( + ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False +) -> T5EncoderModel: T5_CONFIG_JSON = """ { "architectures": [ @@ -217,6 +219,11 @@ def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.devi return t5xxl +def get_t5xxl_actual_dtype(t5xxl: T5EncoderModel) -> torch.dtype: + # nn.Embedding is the first layer, but it could be casted to bfloat16 or float32 + return t5xxl.encoder.block[0].layer[0].SelfAttention.q.weight.dtype + + def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int): img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3) img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None] diff --git a/library/strategy_flux.py b/library/strategy_flux.py index 5d0839132..6c9ef5e4a 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -5,8 +5,7 @@ import numpy as np from transformers import CLIPTokenizer, T5TokenizerFast -from library import sd3_utils, train_util -from library import sd3_models +from library import flux_utils, train_util from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy from library.utils import setup_logging @@ -100,6 +99,8 @@ def __init__( super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) self.apply_t5_attn_mask = apply_t5_attn_mask + self.warn_fp8_weights = False + def get_outputs_npz_path(self, image_abs_path: str) -> str: return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX @@ -144,6 +145,14 @@ def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: def cache_batch_outputs( self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List ): + if not self.warn_fp8_weights: + if flux_utils.get_t5xxl_actual_dtype(models[1]) == torch.float8_e4m3fn: + logger.warning( + "T5 model is using fp8 weights for caching. This may affect the quality of the cached outputs." + " / T5モデルはfp8の重みを使用しています。これはキャッシュの品質に影響を与える可能性があります。" + ) + self.warn_fp8_weights = True + flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy captions = [info.caption for info in infos] diff --git a/networks/lora_flux.py b/networks/lora_flux.py index fcb56a467..295267beb 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -330,6 +330,11 @@ def create_network( if split_qkv is not None: split_qkv = True if split_qkv == "True" else False + # train T5XXL + train_t5xxl = kwargs.get("train_t5xxl", False) + if train_t5xxl is not None: + train_t5xxl = True if train_t5xxl == "True" else False + # すごく引数が多いな ( ^ω^)・・・ network = LoRANetwork( text_encoders, @@ -344,6 +349,7 @@ def create_network( conv_alpha=conv_alpha, train_blocks=train_blocks, split_qkv=split_qkv, + train_t5xxl=train_t5xxl, varbose=True, ) @@ -370,9 +376,10 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh else: weights_sd = torch.load(file, map_location="cpu") - # get dim/alpha mapping + # get dim/alpha mapping, and train t5xxl modules_dim = {} modules_alpha = {} + train_t5xxl = None for key, value in weights_sd.items(): if "." not in key: continue @@ -385,6 +392,12 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh modules_dim[lora_name] = dim # logger.info(lora_name, value.size(), dim) + if train_t5xxl is None: + train_t5xxl = "lora_te3" in lora_name + + if train_t5xxl is None: + train_t5xxl = False + # # split qkv # double_qkv_rank = None # single_qkv_rank = None @@ -413,6 +426,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh modules_alpha=modules_alpha, module_class=module_class, split_qkv=split_qkv, + train_t5xxl=train_t5xxl, ) return network, weights_sd @@ -421,10 +435,10 @@ class LoRANetwork(torch.nn.Module): # FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"] FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"] FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"] - TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP", "T5Attention", "T5DenseGatedActDense"] LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1" - LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2" + LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3" # make ComfyUI compatible def __init__( self, @@ -443,6 +457,7 @@ def __init__( modules_alpha: Optional[Dict[str, int]] = None, train_blocks: Optional[str] = None, split_qkv: bool = False, + train_t5xxl: bool = False, varbose: Optional[bool] = False, ) -> None: super().__init__() @@ -457,6 +472,7 @@ def __init__( self.module_dropout = module_dropout self.train_blocks = train_blocks if train_blocks is not None else "all" self.split_qkv = split_qkv + self.train_t5xxl = train_t5xxl self.loraplus_lr_ratio = None self.loraplus_unet_lr_ratio = None @@ -469,12 +485,16 @@ def __init__( logger.info( f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" ) - if self.conv_lora_dim is not None: - logger.info( - f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" - ) + # if self.conv_lora_dim is not None: + # logger.info( + # f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + # ) if self.split_qkv: logger.info(f"split qkv for LoRA") + if self.train_blocks is not None: + logger.info(f"train {self.train_blocks} blocks only") + if train_t5xxl: + logger.info(f"train T5XXL as well") # create module instances def create_modules( @@ -550,12 +570,15 @@ def create_modules( skipped_te = [] for i, text_encoder in enumerate(text_encoders): index = i + if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False + break + logger.info(f"create LoRA for Text Encoder {index+1}:") text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.") self.text_encoder_loras.extend(text_encoder_loras) skipped_te += skipped - logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") # create LoRA for U-Net if self.train_blocks == "all": diff --git a/train_network.py b/train_network.py index 4204bce34..a68ccfcc4 100644 --- a/train_network.py +++ b/train_network.py @@ -157,6 +157,9 @@ def sample_images(self, accelerator, args, epoch, global_step, device, vae, toke # region SD/SDXL + def post_process_network(self, args, accelerator, network, text_encoders, unet): + pass + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: noise_scheduler = DDPMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False @@ -237,6 +240,13 @@ def update_metadata(self, metadata, args): def is_text_encoder_not_needed_for_training(self, args): return False # use for sample images + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + # set top parameter requires_grad = True for gradient checkpointing works + text_encoder.text_model.embeddings.requires_grad_(True) + + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + text_encoder.text_model.embeddings.to(dtype=weight_dtype) + # endregion def train(self, args): @@ -329,7 +339,7 @@ def train(self, args): train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - self.assert_extra_args(args, train_dataset_group) + self.assert_extra_args(args, train_dataset_group) # may change some args # acceleratorを準備する logger.info("preparing accelerator") @@ -428,12 +438,15 @@ def train(self, args): ) args.scale_weight_norms = False + self.post_process_network(args, accelerator, network, text_encoders, unet) + + # apply network to unet and text_encoder train_unet = not args.network_train_text_encoder_only train_text_encoder = self.is_train_text_encoder(args) network.apply_to(text_encoder, unet, train_text_encoder, train_unet) if args.network_weights is not None: - # FIXME consider alpha of weights + # FIXME consider alpha of weights: this assumes that the alpha is not changed info = network.load_weights(args.network_weights) accelerator.print(f"load network weights from {args.network_weights}: {info}") @@ -533,7 +546,7 @@ def train(self, args): ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。" accelerator.print("enable fp8 training for U-Net.") unet_weight_dtype = torch.float8_e4m3fn - + if not args.fp8_base_unet: accelerator.print("enable fp8 training for Text Encoder.") te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn @@ -545,17 +558,16 @@ def train(self, args): unet.requires_grad_(False) unet.to(dtype=unet_weight_dtype) - for t_enc in text_encoders: + for i, t_enc in enumerate(text_encoders): t_enc.requires_grad_(False) # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16 if t_enc.device.type != "cpu": t_enc.to(dtype=te_weight_dtype) - if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"): - # nn.Embedding not support FP8 - t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) - elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"): - t_enc.encoder.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) + + # nn.Embedding not support FP8 + if te_weight_dtype != weight_dtype: + self.prepare_text_encoder_fp8(i, t_enc, te_weight_dtype, weight_dtype) # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good if args.deepspeed: @@ -596,12 +608,12 @@ def train(self, args): if args.gradient_checkpointing: # according to TI example in Diffusers, train is required unet.train() - for t_enc, frag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)): + for i, (t_enc, frag) in enumerate(zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))): t_enc.train() # set top parameter requires_grad = True for gradient checkpointing works if frag: - t_enc.text_model.embeddings.requires_grad_(True) + self.prepare_text_encoder_grad_ckpt_workaround(i, t_enc) else: unet.eval() @@ -1028,8 +1040,12 @@ def remove_model(old_ckpt_name): # log device and dtype for each model logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}") - for t_enc in text_encoders: - logger.info(f"text_encoder dtype: {t_enc.dtype}, device: {t_enc.device}") + for i, t_enc in enumerate(text_encoders): + params_itr = t_enc.parameters() + params_itr.__next__() # skip the first parameter + params_itr.__next__() # skip the second parameter. because CLIP first two parameters are embeddings + param_3rd = params_itr.__next__() + logger.info(f"text_encoder [{i}] dtype: {param_3rd.dtype}, device: {t_enc.device}") clean_memory_on_device(accelerator.device) @@ -1085,11 +1101,7 @@ def remove_model(old_ckpt_name): text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs - if ( - len(text_encoder_conds) == 0 - or text_encoder_conds[0] is None - or train_text_encoder - ): + if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder: with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: From b7cff0a7548e5e33f735f06293ba24119fdaa585 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 4 Sep 2024 21:35:47 +0900 Subject: [PATCH 124/348] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 840655705..c0acfa1d2 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ The command to install PyTorch is as follows: ### Recent Updates Sep 4, 2024: -- T5XXL LoRA is supported in LoRA training. Remove `--network_train_unet_only` and add `train_t5xxl=True` to `--network_args`. CLIP-L is also trained at the same time (T5XXL only cannot be trained). The trained model can be used with ComfyUI. +- T5XXL LoRA is supported in LoRA training. Remove `--network_train_unet_only` and add `train_t5xxl=True` to `--network_args`. CLIP-L is also trained at the same time (T5XXL only cannot be trained). The trained model can be used with ComfyUI. See [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) for details. - In LoRA training, when `--fp8_base` is specified, you can specify `t5xxl_fp8_e4m3fn.safetensors` as the T5XXL weights. However, it is recommended to use fp16 weights for caching. - Fixed an issue where the training CLIP-L LoRA was not used in sample image generation during LoRA training. From 56cb2fc885d818e9c4493fb2843870d7a141db1c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 4 Sep 2024 23:15:27 +0900 Subject: [PATCH 125/348] support T5XXL LoRA, reduce peak memory usage #1560 --- flux_minimal_inference.py | 73 +++++++++++++++++++++++++++++++-------- networks/lora_flux.py | 2 +- 2 files changed, 59 insertions(+), 16 deletions(-) diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index 56c1b1982..1c194e7c1 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -5,7 +5,7 @@ import math import os import random -from typing import Callable, List, Optional, Tuple +from typing import Callable, List, Optional import einops import numpy as np @@ -13,6 +13,7 @@ from tqdm import tqdm from PIL import Image import accelerate +from transformers import CLIPTextModel from library import device_utils from library.device_utils import init_ipex, get_preferred_device @@ -125,7 +126,7 @@ def do_sample( def generate_image( model, - clip_l, + clip_l: CLIPTextModel, t5xxl, ae, prompt: str, @@ -141,12 +142,13 @@ def generate_image( # make first noise with packed shape # original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2 packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16) + noise_dtype = torch.float32 if is_fp8(dtype) else dtype noise = torch.randn( 1, packed_latent_height * packed_latent_width, 16 * 2 * 2, device=device, - dtype=dtype, + dtype=noise_dtype, generator=torch.Generator(device=device).manual_seed(seed), ) @@ -166,9 +168,48 @@ def generate_image( clip_l = clip_l.to(device) t5xxl = t5xxl.to(device) with torch.no_grad(): - if is_fp8(clip_l_dtype) or is_fp8(t5xxl_dtype): - clip_l.to(clip_l_dtype) - t5xxl.to(t5xxl_dtype) + if is_fp8(clip_l_dtype): + param_itr = clip_l.parameters() + param_itr.__next__() # skip first + param_2nd = param_itr.__next__() + if param_2nd.dtype != clip_l_dtype: + logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}") + clip_l.to(clip_l_dtype) # fp8 + clip_l.text_model.embeddings.to(dtype=torch.bfloat16) + + with accelerator.autocast(): + l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + + if is_fp8(t5xxl_dtype): + if flux_utils.get_t5xxl_actual_dtype(t5xxl) != t5xxl_dtype or not hasattr(t5xxl, "fp8_prepared"): + logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}") + + def prepare_fp8(text_encoder, target_dtype): + def forward_hook(module): + def forward(hidden_states): + hidden_gelu = module.act(module.wi_0(hidden_states)) + hidden_linear = module.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = module.dropout(hidden_states) + + hidden_states = module.wo(hidden_states) + return hidden_states + + return forward + + for module in text_encoder.modules(): + if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["T5DenseGatedActDense"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + text_encoder.fp8_prepared = True + + t5xxl.to(t5xxl_dtype) + prepare_fp8(t5xxl.encoder, torch.bfloat16) + with accelerator.autocast(): _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask @@ -315,10 +356,10 @@ def is_fp8(dt): t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device) t5xxl.eval() - if is_fp8(clip_l_dtype): - clip_l = accelerator.prepare(clip_l) - if is_fp8(t5xxl_dtype): - t5xxl = accelerator.prepare(t5xxl) + # if is_fp8(clip_l_dtype): + # clip_l = accelerator.prepare(clip_l) + # if is_fp8(t5xxl_dtype): + # t5xxl = accelerator.prepare(t5xxl) t5xxl_max_length = 256 if is_schnell else 512 tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length) @@ -329,14 +370,16 @@ def is_fp8(dt): model.eval() logger.info(f"Casting model to {flux_dtype}") model.to(flux_dtype) # make sure model is dtype - if is_fp8(flux_dtype): - model = accelerator.prepare(model) + # if is_fp8(flux_dtype): + # model = accelerator.prepare(model) + # if args.offload: + # model = model.to("cpu") # AE ae = flux_utils.load_ae(name, args.ae, ae_dtype, loading_device) ae.eval() - if is_fp8(ae_dtype): - ae = accelerator.prepare(ae) + # if is_fp8(ae_dtype): + # ae = accelerator.prepare(ae) # LoRA lora_models: List[lora_flux.LoRANetwork] = [] @@ -360,7 +403,7 @@ def is_fp8(dt): lora_model.to(device) lora_models.append(lora_model) - + if not args.interactive: generate_image(model, clip_l, t5xxl, ae, args.prompt, args.seed, args.width, args.height, args.steps, args.guidance) else: diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 295267beb..ab9ccc4d8 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -392,7 +392,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh modules_dim[lora_name] = dim # logger.info(lora_name, value.size(), dim) - if train_t5xxl is None: + if train_t5xxl is None or train_t5xxl is False: train_t5xxl = "lora_te3" in lora_name if train_t5xxl is None: From 90ed2dfb526168b2e77b8d367e928d8cc44b4278 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 5 Sep 2024 08:39:29 +0900 Subject: [PATCH 126/348] feat: Add support for merging CLIP-L and T5XXL LoRA models --- README.md | 22 ++++- networks/flux_merge_lora.py | 182 ++++++++++++++++++++++++++++-------- 2 files changed, 163 insertions(+), 41 deletions(-) diff --git a/README.md b/README.md index c0acfa1d2..fa81f6c0f 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,9 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 5, 2024: +The LoRA merge script now supports CLIP-L and T5XXL LoRA. Please specify `--clip_l` and `--t5xxl`. `--clip_l_save_to` and `--t5xxl_save_to` specify the save destination for CLIP-L and T5XXL. See [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) for details. + Sep 4, 2024: - T5XXL LoRA is supported in LoRA training. Remove `--network_train_unet_only` and add `train_t5xxl=True` to `--network_args`. CLIP-L is also trained at the same time (T5XXL only cannot be trained). The trained model can be used with ComfyUI. See [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) for details. - In LoRA training, when `--fp8_base` is specified, you can specify `t5xxl_fp8_e4m3fn.safetensors` as the T5XXL weights. However, it is recommended to use fp16 weights for caching. @@ -276,7 +279,7 @@ CLIP-L LoRA is not supported. ### Merge LoRA to FLUX.1 checkpoint -`networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint. __The script is experimental.__ +`networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint, CLIP-L or T5XXL models. __The script is experimental.__ ``` python networks/flux_merge_lora.py --flux_model flux1-dev.safetensors --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu @@ -284,13 +287,24 @@ python networks/flux_merge_lora.py --flux_model flux1-dev.safetensors --save_to You can also merge multiple LoRA models into a FLUX.1 model. Specify multiple LoRA models in `--models`. Specify the same number of ratios in `--ratios`. -`--loading_device` is the device to load the LoRA models. `--working_device` is the device to merge (calculate) the models. Default is `cpu` for both. Loading / working device examples are below (in the case of `--save_precision fp16` or `--save_precision bf16`): +CLIP-L and T5XXL LoRA are supported. `--clip_l` and `--clip_l_save_to` are for CLIP-L, `--t5xxl` and `--t5xxl_save_to` are for T5XXL. Sample command is below. + +``` +--clip_l clip_l.safetensors --clip_l_save_to merged_clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --t5xxl_save_to merged_t5xxl.safetensors +``` + +FLUX.1, CLIP-L, and T5XXL can be merged together or separately for memory efficiency. + +An experimental option `--mem_eff_load_save` is available. This option is for memory-efficient loading and saving. It may also speed up loading and saving. + +`--loading_device` is the device to load the LoRA models. `--working_device` is the device to merge (calculate) the models. Default is `cpu` for both. Loading / working device examples are below (in the case of `--save_precision fp16` or `--save_precision bf16`, `float32` will consume more memory): - 'cpu' / 'cpu': Uses >50GB of RAM, but works on any machine. - 'cuda' / 'cpu': Uses 24GB of VRAM, but requires 30GB of RAM. -- 'cuda' / 'cuda': Uses 30GB of VRAM, but requires 30GB of RAM, faster than 'cuda' / 'cpu'. +- 'cpu' / 'cuda': Uses 4GB of VRAM, but requires 50GB of RAM, faster than 'cpu' / 'cpu' or 'cuda' / 'cpu'. +- 'cuda' / 'cuda': Uses 30GB of VRAM, but requires 30GB of RAM, faster than 'cpu' / 'cpu' or 'cuda' / 'cpu'. -In the case of LoRA models are trained with `bf16`, we are not sure which is better, `fp16` or `bf16` for `--save_precision`. +`--save_precision` is the precision to save the merged model. In the case of LoRA models are trained with `bf16`, we are not sure which is better, `fp16` or `bf16` for `--save_precision`. The script can merge multiple LoRA models. If you want to merge multiple LoRA models, specify `--concat` option to work the merged LoRA model properly. diff --git a/networks/flux_merge_lora.py b/networks/flux_merge_lora.py index 2e0d4c297..5e100a3ba 100644 --- a/networks/flux_merge_lora.py +++ b/networks/flux_merge_lora.py @@ -2,6 +2,7 @@ import math import os import time +from typing import Any, Dict, Union import torch from safetensors import safe_open @@ -34,11 +35,11 @@ def load_state_dict(file_name, dtype): return sd, metadata -def save_to_file(file_name, state_dict, dtype, metadata, mem_eff_save=False): +def save_to_file(file_name, state_dict: Dict[str, Union[Any, torch.Tensor]], dtype, metadata, mem_eff_save=False): if dtype is not None: logger.info(f"converting to {dtype}...") for key in tqdm(list(state_dict.keys())): - if type(state_dict[key]) == torch.Tensor: + if type(state_dict[key]) == torch.Tensor and state_dict[key].dtype.is_floating_point: state_dict[key] = state_dict[key].to(dtype) logger.info(f"saving to: {file_name}") @@ -49,26 +50,76 @@ def save_to_file(file_name, state_dict, dtype, metadata, mem_eff_save=False): def merge_to_flux_model( - loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype, mem_eff_load_save=False + loading_device, + working_device, + flux_path: str, + clip_l_path: str, + t5xxl_path: str, + models, + ratios, + merge_dtype, + save_dtype, + mem_eff_load_save=False, ): # create module map without loading state_dict - logger.info(f"loading keys from FLUX.1 model: {flux_model}") lora_name_to_module_key = {} - with safe_open(flux_model, framework="pt", device=loading_device) as flux_file: - keys = list(flux_file.keys()) - for key in keys: - if key.endswith(".weight"): - module_name = ".".join(key.split(".")[:-1]) - lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_") - lora_name_to_module_key[lora_name] = key - + if flux_path is not None: + logger.info(f"loading keys from FLUX.1 model: {flux_path}") + with safe_open(flux_path, framework="pt", device=loading_device) as flux_file: + keys = list(flux_file.keys()) + for key in keys: + if key.endswith(".weight"): + module_name = ".".join(key.split(".")[:-1]) + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_") + lora_name_to_module_key[lora_name] = key + + lora_name_to_clip_l_key = {} + if clip_l_path is not None: + logger.info(f"loading keys from clip_l model: {clip_l_path}") + with safe_open(clip_l_path, framework="pt", device=loading_device) as clip_l_file: + keys = list(clip_l_file.keys()) + for key in keys: + if key.endswith(".weight"): + module_name = ".".join(key.split(".")[:-1]) + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP + "_" + module_name.replace(".", "_") + lora_name_to_clip_l_key[lora_name] = key + + lora_name_to_t5xxl_key = {} + if t5xxl_path is not None: + logger.info(f"loading keys from t5xxl model: {t5xxl_path}") + with safe_open(t5xxl_path, framework="pt", device=loading_device) as t5xxl_file: + keys = list(t5xxl_file.keys()) + for key in keys: + if key.endswith(".weight"): + module_name = ".".join(key.split(".")[:-1]) + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5 + "_" + module_name.replace(".", "_") + lora_name_to_t5xxl_key[lora_name] = key + + flux_state_dict = {} + clip_l_state_dict = {} + t5xxl_state_dict = {} if mem_eff_load_save: - flux_state_dict = {} - with MemoryEfficientSafeOpen(flux_model) as flux_file: - for key in tqdm(flux_file.keys()): - flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed + if flux_path is not None: + with MemoryEfficientSafeOpen(flux_path) as flux_file: + for key in tqdm(flux_file.keys()): + flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed + + if clip_l_path is not None: + with MemoryEfficientSafeOpen(clip_l_path) as clip_l_file: + for key in tqdm(clip_l_file.keys()): + clip_l_state_dict[key] = clip_l_file.get_tensor(key).to(loading_device) + + if t5xxl_path is not None: + with MemoryEfficientSafeOpen(t5xxl_path) as t5xxl_file: + for key in tqdm(t5xxl_file.keys()): + t5xxl_state_dict[key] = t5xxl_file.get_tensor(key).to(loading_device) else: - flux_state_dict = load_file(flux_model, device=loading_device) + if flux_path is not None: + flux_state_dict = load_file(flux_path, device=loading_device) + if clip_l_path is not None: + clip_l_state_dict = load_file(clip_l_path, device=loading_device) + if t5xxl_path is not None: + t5xxl_state_dict = load_file(t5xxl_path, device=loading_device) for model, ratio in zip(models, ratios): logger.info(f"loading: {model}") @@ -81,8 +132,20 @@ def merge_to_flux_model( up_key = key.replace("lora_down", "lora_up") alpha_key = key[: key.index("lora_down")] + "alpha" - if lora_name not in lora_name_to_module_key: - logger.warning(f"no module found for LoRA weight: {key}. LoRA for Text Encoder is not supported yet.") + if lora_name in lora_name_to_module_key: + module_weight_key = lora_name_to_module_key[lora_name] + state_dict = flux_state_dict + elif lora_name in lora_name_to_clip_l_key: + module_weight_key = lora_name_to_clip_l_key[lora_name] + state_dict = clip_l_state_dict + elif lora_name in lora_name_to_t5xxl_key: + module_weight_key = lora_name_to_t5xxl_key[lora_name] + state_dict = t5xxl_state_dict + else: + logger.warning( + f"no module found for LoRA weight: {key}. Skipping..." + f"LoRAの重みに対応するモジュールが見つかりませんでした。スキップします。" + ) continue down_weight = lora_sd.pop(key) @@ -93,11 +156,7 @@ def merge_to_flux_model( scale = alpha / dim # W <- W + U * D - module_weight_key = lora_name_to_module_key[lora_name] - if module_weight_key not in flux_state_dict: - weight = flux_file.get_tensor(module_weight_key) - else: - weight = flux_state_dict[module_weight_key] + weight = state_dict[module_weight_key] weight = weight.to(working_device, merge_dtype) up_weight = up_weight.to(working_device, merge_dtype) @@ -121,7 +180,7 @@ def merge_to_flux_model( # logger.info(conved.size(), weight.size(), module.stride, module.padding) weight = weight + ratio * conved * scale - flux_state_dict[module_weight_key] = weight.to(loading_device, save_dtype) + state_dict[module_weight_key] = weight.to(loading_device, save_dtype) del up_weight del down_weight del weight @@ -129,7 +188,7 @@ def merge_to_flux_model( if len(lora_sd) > 0: logger.warning(f"Unused keys in LoRA model: {list(lora_sd.keys())}") - return flux_state_dict + return flux_state_dict, clip_l_state_dict, t5xxl_state_dict def merge_to_flux_model_diffusers( @@ -508,17 +567,28 @@ def merge(args): if save_dtype is None: save_dtype = merge_dtype - dest_dir = os.path.dirname(args.save_to) + assert ( + args.save_to or args.clip_l_save_to or args.t5xxl_save_to + ), "save_to or clip_l_save_to or t5xxl_save_to must be specified / save_toまたはclip_l_save_toまたはt5xxl_save_toを指定してください" + dest_dir = os.path.dirname(args.save_to or args.clip_l_save_to or args.t5xxl_save_to) if not os.path.exists(dest_dir): logger.info(f"creating directory: {dest_dir}") os.makedirs(dest_dir) - if args.flux_model is not None: + if args.flux_model is not None or args.clip_l is not None or args.t5xxl is not None: if not args.diffusers: - state_dict = merge_to_flux_model( + assert (args.clip_l is None and args.clip_l_save_to is None) or ( + args.clip_l is not None and args.clip_l_save_to is not None + ), "clip_l_save_to must be specified if clip_l is specified / clip_lが指定されている場合はclip_l_save_toも指定してください" + assert (args.t5xxl is None and args.t5xxl_save_to is None) or ( + args.t5xxl is not None and args.t5xxl_save_to is not None + ), "t5xxl_save_to must be specified if t5xxl is specified / t5xxlが指定されている場合はt5xxl_save_toも指定してください" + flux_state_dict, clip_l_state_dict, t5xxl_state_dict = merge_to_flux_model( args.loading_device, args.working_device, args.flux_model, + args.clip_l, + args.t5xxl, args.models, args.ratios, merge_dtype, @@ -526,7 +596,10 @@ def merge(args): args.mem_eff_load_save, ) else: - state_dict = merge_to_flux_model_diffusers( + assert ( + args.clip_l is None and args.t5xxl is None + ), "clip_l and t5xxl are not supported with --diffusers / clip_l、t5xxlはDiffusersではサポートされていません" + flux_state_dict = merge_to_flux_model_diffusers( args.loading_device, args.working_device, args.flux_model, @@ -536,8 +609,10 @@ def merge(args): save_dtype, args.mem_eff_load_save, ) + clip_l_state_dict = None + t5xxl_state_dict = None - if args.no_metadata: + if args.no_metadata or (flux_state_dict is None or len(flux_state_dict) == 0): sai_metadata = None else: merged_from = sai_model_spec.build_merged_from([args.flux_model] + args.models) @@ -546,15 +621,24 @@ def merge(args): None, False, False, False, False, False, time.time(), title=title, merged_from=merged_from, flux="dev" ) - logger.info(f"saving FLUX model to: {args.save_to}") - save_to_file(args.save_to, state_dict, save_dtype, sai_metadata, args.mem_eff_load_save) + if flux_state_dict is not None and len(flux_state_dict) > 0: + logger.info(f"saving FLUX model to: {args.save_to}") + save_to_file(args.save_to, flux_state_dict, save_dtype, sai_metadata, args.mem_eff_load_save) + + if clip_l_state_dict is not None and len(clip_l_state_dict) > 0: + logger.info(f"saving clip_l model to: {args.clip_l_save_to}") + save_to_file(args.clip_l_save_to, clip_l_state_dict, save_dtype, None, args.mem_eff_load_save) + + if t5xxl_state_dict is not None and len(t5xxl_state_dict) > 0: + logger.info(f"saving t5xxl model to: {args.t5xxl_save_to}") + save_to_file(args.t5xxl_save_to, t5xxl_state_dict, save_dtype, None, args.mem_eff_load_save) else: - state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) + flux_state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) logger.info("calculating hashes and creating metadata...") - model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(flux_state_dict, metadata) metadata["sshs_model_hash"] = model_hash metadata["sshs_legacy_hash"] = legacy_hash @@ -562,12 +646,12 @@ def merge(args): merged_from = sai_model_spec.build_merged_from(args.models) title = os.path.splitext(os.path.basename(args.save_to))[0] sai_metadata = sai_model_spec.build_metadata( - state_dict, False, False, False, True, False, time.time(), title=title, merged_from=merged_from, flux="dev" + flux_state_dict, False, False, False, True, False, time.time(), title=title, merged_from=merged_from, flux="dev" ) metadata.update(sai_metadata) logger.info(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, save_dtype, metadata) + save_to_file(args.save_to, flux_state_dict, save_dtype, metadata) def setup_parser() -> argparse.ArgumentParser: @@ -592,6 +676,18 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="FLUX.1 model to load, merge LoRA models if omitted / 読み込むモデル、指定しない場合はLoRAモデルをマージする", ) + parser.add_argument( + "--clip_l", + type=str, + default=None, + help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス(*.sftまたは*.safetensors)", + ) + parser.add_argument( + "--t5xxl", + type=str, + default=None, + help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)", + ) parser.add_argument( "--mem_eff_load_save", action="store_true", @@ -617,6 +713,18 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="destination file name: safetensors file / 保存先のファイル名、safetensorsファイル", ) + parser.add_argument( + "--clip_l_save_to", + type=str, + default=None, + help="destination file name for clip_l: safetensors file / clip_lの保存先のファイル名、safetensorsファイル", + ) + parser.add_argument( + "--t5xxl_save_to", + type=str, + default=None, + help="destination file name for t5xxl: safetensors file / t5xxlの保存先のファイル名、safetensorsファイル", + ) parser.add_argument( "--models", type=str, From d9129522a6effea7077f18cdea0ee733a5ac7cb0 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 5 Sep 2024 12:20:07 +0900 Subject: [PATCH 127/348] set dtype before calling ae closes #1562 --- flux_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flux_train.py b/flux_train.py index 32a36f036..0293b7be3 100644 --- a/flux_train.py +++ b/flux_train.py @@ -651,7 +651,7 @@ def optimizer_hook(parameter: torch.Tensor): else: with torch.no_grad(): # encode images to latents. images are [-1, 1] - latents = ae.encode(batch["images"]) + latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): From 2889108d858880589d362e06e98eeadf4682476a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 5 Sep 2024 20:58:33 +0900 Subject: [PATCH 128/348] feat: Add --cpu_offload_checkpointing option to LoRA training --- README.md | 7 +++++++ flux_train.py | 2 +- flux_train_network.py | 5 +++++ train_network.py | 12 +++++++++++- 4 files changed, 24 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index fa81f6c0f..e8a12089f 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,12 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 5, 2024 (update 1): + +Added `--cpu_offload_checkpointing` option to LoRA training script. Offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--split_mode`. + Sep 5, 2024: + The LoRA merge script now supports CLIP-L and T5XXL LoRA. Please specify `--clip_l` and `--t5xxl`. `--clip_l_save_to` and `--t5xxl_save_to` specify the save destination for CLIP-L and T5XXL. See [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) for details. Sep 4, 2024: @@ -72,6 +77,8 @@ The training can be done with 12GB VRAM GPUs with Adafactor optimizer, `--split_ --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --split_mode --network_args "train_blocks=single" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 ``` +`--cpu_offload_checkpointing` offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--split_mode`. + We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. The trained LoRA model can be used with ComfyUI. diff --git a/flux_train.py b/flux_train.py index 0293b7be3..0edc83a9f 100644 --- a/flux_train.py +++ b/flux_train.py @@ -261,7 +261,7 @@ def train(args): ) if args.gradient_checkpointing: - flux.enable_gradient_checkpointing(args.cpu_offload_checkpointing) + flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing) flux.requires_grad_(True) diff --git a/flux_train_network.py b/flux_train_network.py index 2fc0f3234..a6e57eede 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -50,6 +50,11 @@ def assert_extra_args(self, args, train_dataset_group): if args.max_token_length is not None: logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") + assert not args.split_mode or not args.cpu_offload_checkpointing, ( + "split_mode and cpu_offload_checkpointing cannot be used together" + " / split_modeとcpu_offload_checkpointingは同時に使用できません" + ) + train_dataset_group.verify_bucket_reso_steps(32) # TODO check this def get_flux_model_name(self, args): diff --git a/train_network.py b/train_network.py index a68ccfcc4..ad97491df 100644 --- a/train_network.py +++ b/train_network.py @@ -451,7 +451,11 @@ def train(self, args): accelerator.print(f"load network weights from {args.network_weights}: {info}") if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() + if args.cpu_offload_checkpointing: + unet.enable_gradient_checkpointing(cpu_offload=True) + else: + unet.enable_gradient_checkpointing() + for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)): if flag: if t_enc.supports_gradient_checkpointing: @@ -1281,6 +1285,12 @@ def setup_parser() -> argparse.ArgumentParser: config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) + parser.add_argument( + "--cpu_offload_checkpointing", + action="store_true", + help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing for U-Net or DiT, if supported" + " / 勾配チェックポイント時にテンソルをCPUにオフロードする(U-NetまたはDiTのみ、サポートされている場合)", + ) parser.add_argument( "--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない" ) From d29af146b8d4c4d028f8752657bd1349c8cd3509 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 9 Sep 2024 23:01:15 +0900 Subject: [PATCH 129/348] add negative prompt for flux inference script --- README.md | 3 + flux_minimal_inference.py | 289 ++++++++++++++++++++++++++------------ 2 files changed, 206 insertions(+), 86 deletions(-) diff --git a/README.md b/README.md index 2f010f499..126516f95 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,9 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 9, 2024: +Added `--negative_prompt` and `--cfg_scale` to `flux_minimal_inference.py`. Negative prompts can be used. + Sep 5, 2024 (update 1): Added `--cpu_offload_checkpointing` option to LoRA training script. Offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--split_mode`. diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index 1c194e7c1..de607c52a 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -71,22 +71,57 @@ def denoise( timesteps: list[float], guidance: float = 4.0, t5_attn_mask: Optional[torch.Tensor] = None, + neg_txt: Optional[torch.Tensor] = None, + neg_vec: Optional[torch.Tensor] = None, + neg_t5_attn_mask: Optional[torch.Tensor] = None, + cfg_scale: Optional[float] = None, ): # this is ignored for schnell + logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}") guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + + # prepare classifier free guidance + if neg_txt is not None and neg_vec is not None: + b_img_ids = torch.cat([img_ids, img_ids], dim=0) + b_txt_ids = torch.cat([txt_ids, txt_ids], dim=0) + b_txt = torch.cat([neg_txt, txt], dim=0) + b_vec = torch.cat([neg_vec, vec], dim=0) + if t5_attn_mask is not None and neg_t5_attn_mask is not None: + b_t5_attn_mask = torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0) + else: + b_t5_attn_mask = None + else: + b_img_ids = img_ids + b_txt_ids = txt_ids + b_txt = txt + b_vec = vec + b_t5_attn_mask = t5_attn_mask + for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): - t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + t_vec = torch.full((b_img_ids.shape[0],), t_curr, dtype=img.dtype, device=img.device) + + # classifier free guidance + if neg_txt is not None and neg_vec is not None: + b_img = torch.cat([img, img], dim=0) + else: + b_img = img + pred = model( - img=img, - img_ids=img_ids, - txt=txt, - txt_ids=txt_ids, - y=vec, + img=b_img, + img_ids=b_img_ids, + txt=b_txt, + txt_ids=b_txt_ids, + y=b_vec, timesteps=t_vec, guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, + txt_attention_mask=b_t5_attn_mask, ) + # classifier free guidance + if neg_txt is not None and neg_vec is not None: + pred_uncond, pred = torch.chunk(pred, 2, dim=0) + pred = pred_uncond + cfg_scale * (pred - pred_uncond) + img = img + (t_prev - t_curr) * pred return img @@ -106,19 +141,48 @@ def do_sample( is_schnell: bool, device: torch.device, flux_dtype: torch.dtype, + neg_l_pooled: Optional[torch.Tensor] = None, + neg_t5_out: Optional[torch.Tensor] = None, + neg_t5_attn_mask: Optional[torch.Tensor] = None, + cfg_scale: Optional[float] = None, ): + logger.info(f"num_steps: {num_steps}") timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell) # denoise initial noise if accelerator: with accelerator.autocast(), torch.no_grad(): x = denoise( - model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask + model, + img, + img_ids, + t5_out, + txt_ids, + l_pooled, + timesteps, + guidance, + t5_attn_mask, + neg_t5_out, + neg_l_pooled, + neg_t5_attn_mask, + cfg_scale, ) else: with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad(): x = denoise( - model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask + model, + img, + img_ids, + t5_out, + txt_ids, + l_pooled, + timesteps, + guidance, + t5_attn_mask, + neg_t5_out, + neg_l_pooled, + neg_t5_attn_mask, + cfg_scale, ) return x @@ -135,6 +199,8 @@ def generate_image( image_height: int, steps: Optional[int], guidance: float, + negative_prompt: Optional[str], + cfg_scale: float, ): seed = seed if seed is not None else random.randint(0, 2**32 - 1) logger.info(f"Seed: {seed}") @@ -162,65 +228,73 @@ def generate_image( # txt2img only needs img_ids img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width) + # prepare fp8 models + if is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared): + logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}") + clip_l.to(clip_l_dtype) # fp8 + clip_l.text_model.embeddings.to(dtype=torch.bfloat16) + clip_l.fp8_prepared = True + + if is_fp8(t5xxl_dtype) and (not hasattr(t5xxl, "fp8_prepared") or not t5xxl.fp8_prepared): + logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}") + + def prepare_fp8(text_encoder, target_dtype): + def forward_hook(module): + def forward(hidden_states): + hidden_gelu = module.act(module.wi_0(hidden_states)) + hidden_linear = module.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = module.dropout(hidden_states) + + hidden_states = module.wo(hidden_states) + return hidden_states + + return forward + + for module in text_encoder.modules(): + if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["T5DenseGatedActDense"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + t5xxl.to(t5xxl_dtype) + prepare_fp8(t5xxl.encoder, torch.bfloat16) + t5xxl.fp8_prepared = True + # prepare embeddings logger.info("Encoding prompts...") - tokens_and_masks = tokenize_strategy.tokenize(prompt) clip_l = clip_l.to(device) t5xxl = t5xxl.to(device) - with torch.no_grad(): - if is_fp8(clip_l_dtype): - param_itr = clip_l.parameters() - param_itr.__next__() # skip first - param_2nd = param_itr.__next__() - if param_2nd.dtype != clip_l_dtype: - logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}") - clip_l.to(clip_l_dtype) # fp8 - clip_l.text_model.embeddings.to(dtype=torch.bfloat16) - - with accelerator.autocast(): - l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) - if is_fp8(t5xxl_dtype): - if flux_utils.get_t5xxl_actual_dtype(t5xxl) != t5xxl_dtype or not hasattr(t5xxl, "fp8_prepared"): - logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}") - - def prepare_fp8(text_encoder, target_dtype): - def forward_hook(module): - def forward(hidden_states): - hidden_gelu = module.act(module.wi_0(hidden_states)) - hidden_linear = module.wi_1(hidden_states) - hidden_states = hidden_gelu * hidden_linear - hidden_states = module.dropout(hidden_states) - - hidden_states = module.wo(hidden_states) - return hidden_states - - return forward - - for module in text_encoder.modules(): - if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: - # print("set", module.__class__.__name__, "to", target_dtype) - module.to(target_dtype) - if module.__class__.__name__ in ["T5DenseGatedActDense"]: - # print("set", module.__class__.__name__, "hooks") - module.forward = forward_hook(module) - - text_encoder.fp8_prepared = True - - t5xxl.to(t5xxl_dtype) - prepare_fp8(t5xxl.encoder, torch.bfloat16) - - with accelerator.autocast(): - _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( - tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask - ) - else: - with torch.autocast(device_type=device.type, dtype=clip_l_dtype): - l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) - with torch.autocast(device_type=device.type, dtype=t5xxl_dtype): - _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( - tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask - ) + def encode(prpt: str): + tokens_and_masks = tokenize_strategy.tokenize(prpt) + with torch.no_grad(): + if is_fp8(clip_l_dtype): + with accelerator.autocast(): + l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + else: + with torch.autocast(device_type=device.type, dtype=clip_l_dtype): + l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + + if is_fp8(t5xxl_dtype): + with accelerator.autocast(): + _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + else: + with torch.autocast(device_type=device.type, dtype=t5xxl_dtype): + _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( + tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + return l_pooled, t5_out, txt_ids, t5_attn_mask + + l_pooled, t5_out, txt_ids, t5_attn_mask = encode(prompt) + if negative_prompt: + neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode(negative_prompt) + else: + neg_l_pooled, neg_t5_out, neg_t5_attn_mask = None, None, None # NaN check if torch.isnan(l_pooled).any(): @@ -244,7 +318,23 @@ def forward(hidden_states): t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None x = do_sample( - accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance, t5_attn_mask, is_schnell, device, flux_dtype + accelerator, + model, + noise, + img_ids, + l_pooled, + t5_out, + txt_ids, + steps, + guidance, + t5_attn_mask, + is_schnell, + device, + flux_dtype, + neg_l_pooled, + neg_t5_out, + neg_t5_attn_mask, + cfg_scale, ) if args.offload: model = model.cpu() @@ -307,6 +397,8 @@ def forward(hidden_states): parser.add_argument("--seed", type=int, default=None) parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev") parser.add_argument("--guidance", type=float, default=3.5) + parser.add_argument("--negative_prompt", type=str, default=None) + parser.add_argument("--cfg_scale", type=float, default=1.0) parser.add_argument("--offload", action="store_true", help="Offload to CPU") parser.add_argument( "--lora_weights", @@ -403,19 +495,34 @@ def is_fp8(dt): lora_model.to(device) lora_models.append(lora_model) - + if not args.interactive: - generate_image(model, clip_l, t5xxl, ae, args.prompt, args.seed, args.width, args.height, args.steps, args.guidance) + generate_image( + model, + clip_l, + t5xxl, + ae, + args.prompt, + args.seed, + args.width, + args.height, + args.steps, + args.guidance, + args.negative_prompt, + args.cfg_scale, + ) else: # loop for interactive width = target_width height = target_height steps = None guidance = args.guidance + cfg_scale = args.cfg_scale while True: print( "Enter prompt (empty to exit). Options: --w --h --s --d --g --m " + " --n , `-` for empty negative prompt --c " ) prompt = input() if prompt == "": @@ -425,26 +532,36 @@ def is_fp8(dt): options = prompt.split("--") prompt = options[0].strip() seed = None + negative_prompt = None for opt in options[1:]: - opt = opt.strip() - if opt.startswith("w"): - width = int(opt[1:].strip()) - elif opt.startswith("h"): - height = int(opt[1:].strip()) - elif opt.startswith("s"): - steps = int(opt[1:].strip()) - elif opt.startswith("d"): - seed = int(opt[1:].strip()) - elif opt.startswith("g"): - guidance = float(opt[1:].strip()) - elif opt.startswith("m"): - mutipliers = opt[1:].strip().split(",") - if len(mutipliers) != len(lora_models): - logger.error(f"Invalid number of multipliers, expected {len(lora_models)}") - continue - for i, lora_model in enumerate(lora_models): - lora_model.set_multiplier(float(mutipliers[i])) - - generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance) + try: + opt = opt.strip() + if opt.startswith("w"): + width = int(opt[1:].strip()) + elif opt.startswith("h"): + height = int(opt[1:].strip()) + elif opt.startswith("s"): + steps = int(opt[1:].strip()) + elif opt.startswith("d"): + seed = int(opt[1:].strip()) + elif opt.startswith("g"): + guidance = float(opt[1:].strip()) + elif opt.startswith("m"): + mutipliers = opt[1:].strip().split(",") + if len(mutipliers) != len(lora_models): + logger.error(f"Invalid number of multipliers, expected {len(lora_models)}") + continue + for i, lora_model in enumerate(lora_models): + lora_model.set_multiplier(float(mutipliers[i])) + elif opt.startswith("n"): + negative_prompt = opt[1:].strip() + if negative_prompt == "-": + negative_prompt = "" + elif opt.startswith("c"): + cfg_scale = float(opt[1:].strip()) + except ValueError as e: + logger.error(f"Invalid option: {opt}, {e}") + + generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance, negative_prompt, cfg_scale) logger.info("Done!") From d10ff62a78b15d0bb55f443cc2849c460300131b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 10 Sep 2024 20:32:09 +0900 Subject: [PATCH 130/348] support individual LR for CLIP-L/T5XXL --- README.md | 4 +++ networks/lora_flux.py | 71 +++++++++++++++---------------------------- train_network.py | 32 ++++++++++++------- 3 files changed, 49 insertions(+), 58 deletions(-) diff --git a/README.md b/README.md index 126516f95..b5799dd6f 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,9 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 10, 2024: +In FLUX.1 LoRA training, individual learning rates can be specified for CLIP-L and T5XXL. By specifying multiple numbers in `--text_encoder_lr`, you can set the learning rates for CLIP-L and T5XXL separately. Specify like `--text_encoder_lr 1e-4 1e-5`. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. + Sep 9, 2024: Added `--negative_prompt` and `--cfg_scale` to `flux_minimal_inference.py`. Negative prompts can be used. @@ -142,6 +145,7 @@ The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--times - Remove `--network_train_unet_only` from your command. - Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L is also trained at the same time. - T5XXL output can be cached for CLIP-L LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available. + - The learning rates for CLIP-L and T5XXL can be specified separately. Multiple numbers can be specified in `--text_encoder_lr`. For example, `--text_encoder_lr 1e-4 1e-5`. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. If `--text_encoder_lr` is not specified, the default learning rate `--learning_rate` is used for both CLIP-L and T5XXL. - The trained LoRA can be used with ComfyUI. - Note: `flux_extract_lora.py`, `convert_flux_lora.py`and `merge_flux_lora.py` do not support CLIP-L and T5XXL LoRA yet. diff --git a/networks/lora_flux.py b/networks/lora_flux.py index ab9ccc4d8..d540c2215 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -786,28 +786,23 @@ def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, lorap logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") - # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): - # TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?) - # if ( - # self.loraplus_lr_ratio is not None - # or self.loraplus_text_encoder_lr_ratio is not None - # or self.loraplus_unet_lr_ratio is not None - # ): - # assert ( - # optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower() - # ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません" + def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr): + # make sure text_encoder_lr as list of two elements + if text_encoder_lr is None or len(text_encoder_lr) == 0: + text_encoder_lr = [default_lr, default_lr] + elif len(text_encoder_lr) == 1: + text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]] self.requires_grad_(True) all_params = [] lr_descriptions = [] - def assemble_params(loras, lr, ratio): + def assemble_params(loras, lr, loraplus_ratio): param_groups = {"lora": {}, "plus": {}} for lora in loras: for name, param in lora.named_parameters(): - if ratio is not None and "lora_up" in name: + if loraplus_ratio is not None and "lora_up" in name: param_groups["plus"][f"{lora.lora_name}.{name}"] = param else: param_groups["lora"][f"{lora.lora_name}.{name}"] = param @@ -822,7 +817,7 @@ def assemble_params(loras, lr, ratio): if lr is not None: if key == "plus": - param_data["lr"] = lr * ratio + param_data["lr"] = lr * loraplus_ratio else: param_data["lr"] = lr @@ -836,41 +831,23 @@ def assemble_params(loras, lr, ratio): return params, descriptions if self.text_encoder_loras: - params, descriptions = assemble_params( - self.text_encoder_loras, - text_encoder_lr if text_encoder_lr is not None else default_lr, - self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio, - ) - all_params.extend(params) - lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions]) + loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio + + # split text encoder loras for te1 and te3 + te1_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP)] + te3_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_T5)] + if len(te1_loras) > 0: + logger.info(f"Text Encoder 1 (CLIP-L): {len(te1_loras)} modules, LR {text_encoder_lr[0]}") + params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions]) + if len(te3_loras) > 0: + logger.info(f"Text Encoder 2 (T5XXL): {len(te3_loras)} modules, LR {text_encoder_lr[1]}") + params, descriptions = assemble_params(te3_loras, text_encoder_lr[1], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 2 " + (" " + d if d else "") for d in descriptions]) if self.unet_loras: - # if self.block_lr: - # is_sdxl = False - # for lora in self.unet_loras: - # if "input_blocks" in lora.lora_name or "output_blocks" in lora.lora_name: - # is_sdxl = True - # break - - # # 学習率のグラフをblockごとにしたいので、blockごとにloraを分類 - # block_idx_to_lora = {} - # for lora in self.unet_loras: - # idx = get_block_index(lora.lora_name, is_sdxl) - # if idx not in block_idx_to_lora: - # block_idx_to_lora[idx] = [] - # block_idx_to_lora[idx].append(lora) - - # # blockごとにパラメータを設定する - # for idx, block_loras in block_idx_to_lora.items(): - # params, descriptions = assemble_params( - # block_loras, - # (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(idx), - # self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, - # ) - # all_params.extend(params) - # lr_descriptions.extend([f"unet_block{idx}" + (" " + d if d else "") for d in descriptions]) - - # else: params, descriptions = assemble_params( self.unet_loras, unet_lr if unet_lr is not None else default_lr, diff --git a/train_network.py b/train_network.py index ad97491df..e45db0525 100644 --- a/train_network.py +++ b/train_network.py @@ -466,9 +466,17 @@ def train(self, args): # 学習に必要なクラスを準備する accelerator.print("prepare optimizer, data loader etc.") - # 後方互換性を確保するよ + # make backward compatibility for text_encoder_lr + support_multiple_lrs = hasattr(network, "prepare_optimizer_params_with_multiple_te_lrs") + if support_multiple_lrs: + text_encoder_lr = args.text_encoder_lr + else: + text_encoder_lr = None if args.text_encoder_lr is None or len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0] try: - results = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) + if support_multiple_lrs: + results = network.prepare_optimizer_params_with_multiple_te_lrs(text_encoder_lr, args.unet_lr, args.learning_rate) + else: + results = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr, args.learning_rate) if type(results) is tuple: trainable_params = results[0] lr_descriptions = results[1] @@ -476,11 +484,7 @@ def train(self, args): trainable_params = results lr_descriptions = None except TypeError as e: - # logger.warning(f"{e}") - # accelerator.print( - # "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" - # ) - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) + trainable_params = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr) lr_descriptions = None # if len(trainable_params) == 0: @@ -713,7 +717,7 @@ def load_model_hook(models, input_dir): "ss_training_started_at": training_started_at, # unix timestamp "ss_output_name": args.output_name, "ss_learning_rate": args.learning_rate, - "ss_text_encoder_lr": args.text_encoder_lr, + "ss_text_encoder_lr": text_encoder_lr, "ss_unet_lr": args.unet_lr, "ss_num_train_images": train_dataset_group.num_train_images, "ss_num_reg_images": train_dataset_group.num_reg_images, @@ -760,8 +764,8 @@ def load_model_hook(models, input_dir): "ss_loss_type": args.loss_type, "ss_huber_schedule": args.huber_schedule, "ss_huber_c": args.huber_c, - "ss_fp8_base": args.fp8_base, - "ss_fp8_base_unet": args.fp8_base_unet, + "ss_fp8_base": bool(args.fp8_base), + "ss_fp8_base_unet": bool(args.fp8_base_unet), } self.update_metadata(metadata, args) # architecture specific metadata @@ -1303,7 +1307,13 @@ def setup_parser() -> argparse.ArgumentParser: ) parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") - parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") + parser.add_argument( + "--text_encoder_lr", + type=float, + default=None, + nargs="*", + help="learning rate for Text Encoder, can be multiple / Text Encoderの学習率、複数指定可能", + ) parser.add_argument( "--fp8_base_unet", action="store_true", From 65b8a064f6bb9a403374d4b08f4003037df42f8d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 10 Sep 2024 21:20:38 +0900 Subject: [PATCH 131/348] update README --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index b5799dd6f..caea59b7e 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ The command to install PyTorch is as follows: ### Recent Updates Sep 10, 2024: -In FLUX.1 LoRA training, individual learning rates can be specified for CLIP-L and T5XXL. By specifying multiple numbers in `--text_encoder_lr`, you can set the learning rates for CLIP-L and T5XXL separately. Specify like `--text_encoder_lr 1e-4 1e-5`. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. +In FLUX.1 LoRA training, individual learning rates can be specified for CLIP-L and T5XXL. By specifying multiple numbers in `--text_encoder_lr`, you can set the learning rates for CLIP-L and T5XXL separately. Specify like `--text_encoder_lr 1e-4 1e-5`. The first value is the learning rate for CLIP-L, and the second value is for T5XXL. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. Sep 9, 2024: Added `--negative_prompt` and `--cfg_scale` to `flux_minimal_inference.py`. Negative prompts can be used. @@ -145,7 +145,7 @@ The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--times - Remove `--network_train_unet_only` from your command. - Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L is also trained at the same time. - T5XXL output can be cached for CLIP-L LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available. - - The learning rates for CLIP-L and T5XXL can be specified separately. Multiple numbers can be specified in `--text_encoder_lr`. For example, `--text_encoder_lr 1e-4 1e-5`. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. If `--text_encoder_lr` is not specified, the default learning rate `--learning_rate` is used for both CLIP-L and T5XXL. + - The learning rates for CLIP-L and T5XXL can be specified separately. Multiple numbers can be specified in `--text_encoder_lr`. For example, `--text_encoder_lr 1e-4 1e-5`. The first value is the learning rate for CLIP-L, and the second value is for T5XXL. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. If `--text_encoder_lr` is not specified, the default learning rate `--learning_rate` is used for both CLIP-L and T5XXL. - The trained LoRA can be used with ComfyUI. - Note: `flux_extract_lora.py`, `convert_flux_lora.py`and `merge_flux_lora.py` do not support CLIP-L and T5XXL LoRA yet. From 8311e88225fef377591e5be19eb1f50fe7a2941f Mon Sep 17 00:00:00 2001 From: cocktailpeanut Date: Wed, 11 Sep 2024 09:02:29 -0400 Subject: [PATCH 132/348] typo fix --- library/train_util.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index c38864fe6..f682dcbfb 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3355,15 +3355,14 @@ def int_or_float(value): type=int, default=None, help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`" - " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`", - , + + " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`", ) parser.add_argument( "--lr_scheduler_min_lr_ratio", type=float, default=None, help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler" - " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効", + + " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効", ) From a823fd9fb8d219b5b4c57df12eed41ae34fdf843 Mon Sep 17 00:00:00 2001 From: Plat <60182057+p1atdev@users.noreply.github.com> Date: Wed, 11 Sep 2024 22:21:16 +0900 Subject: [PATCH 133/348] Improve wandb logging (#1576) * fix: wrong training steps were recorded to wandb, and no log was sent when logging_dir was not specified * fix: checking of whether wandb is enabled * feat: log images to wandb with their positive prompt as captions * feat: logging sample images' caption for sd3 and flux * fix: import wandb before use --- fine_tune.py | 7 +++++-- flux_train.py | 7 +++++-- library/flux_train_utils.py | 20 +++++++++++--------- library/sd3_train_utils.py | 20 +++++++++++--------- library/train_util.py | 20 +++++++++++--------- sd3_train.py | 7 +++++-- sdxl_train.py | 7 +++++-- sdxl_train_control_net_lllite.py | 4 ++-- sdxl_train_control_net_lllite_old.py | 4 ++-- train_controlnet.py | 7 +++++-- train_db.py | 7 +++++-- train_network.py | 7 +++++-- train_textual_inversion.py | 8 ++++++-- train_textual_inversion_XTI.py | 4 ++-- 14 files changed, 80 insertions(+), 49 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index c9102f6c0..fb6b3ed69 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -337,6 +337,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): train_util.sample_images( accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) loss_recorder = train_util.LossRecorder() for epoch in range(num_train_epochs): @@ -456,7 +459,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): ) current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss} train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) accelerator.log(logs, step=global_step) @@ -469,7 +472,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/flux_train.py b/flux_train.py index 0edc83a9f..33481df8f 100644 --- a/flux_train.py +++ b/flux_train.py @@ -629,6 +629,9 @@ def optimizer_hook(parameter: torch.Tensor): # For --sample_at_first flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) loss_recorder = train_util.LossRecorder() epoch = 0 # avoid error when max_train_steps is 0 @@ -777,7 +780,7 @@ def optimizer_hook(parameter: torch.Tensor): ) current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss} train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) @@ -791,7 +794,7 @@ def optimizer_hook(parameter: torch.Tensor): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 0b5d4d90e..f77d4b585 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -254,17 +254,19 @@ def sample_image_inference( img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" image.save(os.path.join(save_dir, img_filename)) - # wandb有効時のみログを送信 - try: + # send images to wandb if enabled + if "wandb" in [tracker.name for tracker in accelerator.trackers]: wandb_tracker = accelerator.get_tracker("wandb") - try: - import wandb - except ImportError: # 事前に一度確認するのでここはエラー出ないはず - raise ImportError("No wandb / wandb がインストールされていないようです") - wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) - except: # wandb 無効時 - pass + import wandb + # not to commit images to avoid inconsistency between training and logging steps + wandb_tracker.log( + {f"sample_{i}": wandb.Image( + image, + caption=prompt # positive prompt as a caption + )}, + commit=False + ) def time_shift(mu: float, sigma: float, t: torch.Tensor): diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index da0729506..e819d440c 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -604,17 +604,19 @@ def sample_image_inference( img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" image.save(os.path.join(save_dir, img_filename)) - # wandb有効時のみログを送信 - try: + # send images to wandb if enabled + if "wandb" in [tracker.name for tracker in accelerator.trackers]: wandb_tracker = accelerator.get_tracker("wandb") - try: - import wandb - except ImportError: # 事前に一度確認するのでここはエラー出ないはず - raise ImportError("No wandb / wandb がインストールされていないようです") - wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) - except: # wandb 無効時 - pass + import wandb + # not to commit images to avoid inconsistency between training and logging steps + wandb_tracker.log( + {f"sample_{i}": wandb.Image( + image, + caption=prompt # positive prompt as a caption + )}, + commit=False + ) # region Diffusers diff --git a/library/train_util.py b/library/train_util.py index f682dcbfb..742d057e0 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5832,17 +5832,19 @@ def sample_image_inference( img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" image.save(os.path.join(save_dir, img_filename)) - # wandb有効時のみログを送信 - try: + # send images to wandb if enabled + if "wandb" in [tracker.name for tracker in accelerator.trackers]: wandb_tracker = accelerator.get_tracker("wandb") - try: - import wandb - except ImportError: # 事前に一度確認するのでここはエラー出ないはず - raise ImportError("No wandb / wandb がインストールされていないようです") - wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) - except: # wandb 無効時 - pass + import wandb + # not to commit images to avoid inconsistency between training and logging steps + wandb_tracker.log( + {f"sample_{i}": wandb.Image( + image, + caption=prompt # positive prompt as a caption + )}, + commit=False + ) # endregion diff --git a/sd3_train.py b/sd3_train.py index 87011b215..5120105f2 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -682,6 +682,9 @@ def optimizer_hook(parameter: torch.Tensor): # For --sample_at_first sd3_train_utils.sample_images(accelerator, args, 0, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) # following function will be moved to sd3_train_utils @@ -901,7 +904,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): ) current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss} train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_mmdit) @@ -915,7 +918,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/sdxl_train.py b/sdxl_train.py index b2c62dd11..7291ddd2f 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -617,6 +617,9 @@ def optimizer_hook(parameter: torch.Tensor): sdxl_train_util.sample_images( accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, [text_encoder1, text_encoder2], unet ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) loss_recorder = train_util.LossRecorder() for epoch in range(num_train_epochs): @@ -797,7 +800,7 @@ def optimizer_hook(parameter: torch.Tensor): ) current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss} if block_lrs is None: train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_unet) @@ -814,7 +817,7 @@ def optimizer_hook(parameter: torch.Tensor): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 0eaec29b8..9d1cfc63e 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -541,14 +541,14 @@ def remove_model(old_ckpt_name): logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 292a0463a..6fa1d6096 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -480,14 +480,14 @@ def remove_model(old_ckpt_name): logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/train_controlnet.py b/train_controlnet.py index c9ac6c5a8..57f0d263f 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -409,6 +409,9 @@ def remove_model(old_ckpt_name): train_util.sample_images( accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, controlnet=controlnet ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) # training loop for epoch in range(num_train_epochs): @@ -542,14 +545,14 @@ def remove_model(old_ckpt_name): logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/train_db.py b/train_db.py index 7caee6647..d42afd89a 100644 --- a/train_db.py +++ b/train_db.py @@ -315,6 +315,9 @@ def train(args): train_util.sample_images( accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) loss_recorder = train_util.LossRecorder() for epoch in range(num_train_epochs): @@ -445,7 +448,7 @@ def train(args): ) current_loss = loss.detach().item() - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss} train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) accelerator.log(logs, step=global_step) @@ -458,7 +461,7 @@ def train(args): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/train_network.py b/train_network.py index e45db0525..34385ae08 100644 --- a/train_network.py +++ b/train_network.py @@ -1038,6 +1038,9 @@ def remove_model(old_ckpt_name): # For --sample_at_first self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) # training loop if initial_step > 0: # only if skip_until_initial_step is specified @@ -1224,7 +1227,7 @@ def remove_model(old_ckpt_name): if args.scale_weight_norms: progress_bar.set_postfix(**{**max_mean_logs, **logs}) - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = self.generate_step_logs( args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm ) @@ -1233,7 +1236,7 @@ def remove_model(old_ckpt_name): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 9044f50df..956c78603 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -550,6 +550,9 @@ def remove_model(old_ckpt_name): unet, prompt_replacement, ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) # training loop for epoch in range(num_train_epochs): @@ -684,7 +687,7 @@ def remove_model(old_ckpt_name): remove_model(remove_ckpt_name) current_loss = loss.detach().item() - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} if ( args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() @@ -702,7 +705,7 @@ def remove_model(old_ckpt_name): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_total / len(train_dataloader)} accelerator.log(logs, step=epoch + 1) @@ -739,6 +742,7 @@ def remove_model(old_ckpt_name): unet, prompt_replacement, ) + accelerator.log({}) # end of epoch diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index efb59137b..ca0b603fb 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -538,7 +538,7 @@ def remove_model(old_ckpt_name): remove_model(remove_ckpt_name) current_loss = loss.detach().item() - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} if ( args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() @@ -556,7 +556,7 @@ def remove_model(old_ckpt_name): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_total / len(train_dataloader)} accelerator.log(logs, step=epoch + 1) From 237317fffd060bcfb078b770ccd2df18bc4dd3a6 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 11 Sep 2024 22:23:43 +0900 Subject: [PATCH 134/348] update README --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 2b3d0d5a8..d3481b6ae 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,9 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 11, 2024: +Logging to wandb is improved. See PR [#1576](https://github.com/kohya-ss/sd-scripts/pull/1576) for details. Thanks to p1atdev! + Sep 10, 2024: In FLUX.1 LoRA training, individual learning rates can be specified for CLIP-L and T5XXL. By specifying multiple numbers in `--text_encoder_lr`, you can set the learning rates for CLIP-L and T5XXL separately. Specify like `--text_encoder_lr 1e-4 1e-5`. The first value is the learning rate for CLIP-L, and the second value is for T5XXL. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. From cefe52629e1901dd8192b0487afd5e9f089e3519 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 12 Sep 2024 12:36:07 +0900 Subject: [PATCH 135/348] fix to work old notation for TE LR in .toml --- networks/lora_flux.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index d540c2215..dd267de0f 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -788,8 +788,11 @@ def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, lorap def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr): # make sure text_encoder_lr as list of two elements - if text_encoder_lr is None or len(text_encoder_lr) == 0: + # if float, use the same value for both text encoders + if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0): text_encoder_lr = [default_lr, default_lr] + elif isinstance(text_encoder_lr, float): + text_encoder_lr = [text_encoder_lr, text_encoder_lr] elif len(text_encoder_lr) == 1: text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]] From 2d8ee3c28007393386528cfeec0a9b714dafd85b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 14 Sep 2024 15:48:16 +0900 Subject: [PATCH 136/348] OFT for FLUX.1 --- flux_minimal_inference.py | 20 +- networks/lora_flux.py | 6 +- networks/oft.py | 2 +- networks/oft_flux.py | 482 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 504 insertions(+), 6 deletions(-) create mode 100644 networks/oft_flux.py diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index de607c52a..2f1b9a377 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -14,9 +14,11 @@ from PIL import Image import accelerate from transformers import CLIPTextModel +from safetensors.torch import load_file from library import device_utils from library.device_utils import init_ipex, get_preferred_device +from networks import oft_flux init_ipex() @@ -405,7 +407,7 @@ def encode(prpt: str): type=str, nargs="*", default=[], - help="LoRA weights, only supports networks.lora_flux, each argument is a `path;multiplier` (semi-colon separated)", + help="LoRA weights, only supports networks.lora_flux and lora_oft, each argument is a `path;multiplier` (semi-colon separated)", ) parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model") parser.add_argument("--width", type=int, default=target_width) @@ -482,9 +484,19 @@ def is_fp8(dt): else: multiplier = 1.0 - lora_model, weights_sd = lora_flux.create_network_from_weights( - multiplier, weights_file, ae, [clip_l, t5xxl], model, None, True - ) + weights_sd = load_file(weights_file) + is_lora = is_oft = False + for key in weights_sd.keys(): + if key.startswith("lora"): + is_lora = True + if key.startswith("oft"): + is_oft = True + if is_lora or is_oft: + break + + module = lora_flux if is_lora else oft_flux + lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True) + if args.merge_lora_weights: lora_model.merge_to([clip_l, t5xxl], model, weights_sd) else: diff --git a/networks/lora_flux.py b/networks/lora_flux.py index dd267de0f..ea7df8b4d 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -41,7 +41,11 @@ def __init__( module_dropout=None, split_dims: Optional[List[int]] = None, ): - """if alpha == 0 or None, alpha is rank (no scaling).""" + """ + if alpha == 0 or None, alpha is rank (no scaling). + + split_dims is used to mimic the split qkv of FLUX as same as Diffusers + """ super().__init__() self.lora_name = lora_name diff --git a/networks/oft.py b/networks/oft.py index 6321def3b..0c3a5393f 100644 --- a/networks/oft.py +++ b/networks/oft.py @@ -51,7 +51,7 @@ def __init__( alpha = alpha.detach().numpy() # constraint in original paper is alpha * out_dim * out_dim, but we use alpha * out_dim for backward compatibility - # original alpha is 1e-6, so we use 1e-3 or 1e-4 for alpha + # original alpha is 1e-5, so we use 1e-2 or 1e-4 for alpha self.constraint = alpha * out_dim self.register_buffer("alpha", torch.tensor(alpha)) diff --git a/networks/oft_flux.py b/networks/oft_flux.py new file mode 100644 index 000000000..27b8b637a --- /dev/null +++ b/networks/oft_flux.py @@ -0,0 +1,482 @@ +# OFT network module + +import math +import os +from typing import Dict, List, Optional, Tuple, Type, Union +from diffusers import AutoencoderKL +import einops +from transformers import CLIPTextModel +import numpy as np +import torch +import torch.nn.functional as F +import re +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class OFTModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__( + self, + oft_name, + org_module: torch.nn.Module, + multiplier=1.0, + dim=4, + alpha=1, + split_dims: Optional[List[int]] = None, + ): + """ + dim -> num blocks + alpha -> constraint + + split_dims is used to mimic the split qkv of FLUX as same as Diffusers + """ + super().__init__() + self.oft_name = oft_name + self.num_blocks = dim + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().numpy() + self.register_buffer("alpha", torch.tensor(alpha)) + + # No conv2d in FLUX + # if "Linear" in org_module.__class__.__name__: + self.out_dim = org_module.out_features + # elif "Conv" in org_module.__class__.__name__: + # out_dim = org_module.out_channels + + if split_dims is None: + split_dims = [self.out_dim] + else: + assert sum(split_dims) == self.out_dim, "sum of split_dims must be equal to out_dim" + self.split_dims = split_dims + + # assert all dim is divisible by num_blocks + for split_dim in self.split_dims: + assert split_dim % self.num_blocks == 0, "split_dim must be divisible by num_blocks" + + self.constraint = [alpha * split_dim for split_dim in self.split_dims] + self.block_size = [split_dim // self.num_blocks for split_dim in self.split_dims] + self.oft_blocks = torch.nn.ParameterList( + [torch.nn.Parameter(torch.zeros(self.num_blocks, block_size, block_size)) for block_size in self.block_size] + ) + self.I = [torch.eye(block_size).unsqueeze(0).repeat(self.num_blocks, 1, 1) for block_size in self.block_size] + + self.shape = org_module.weight.shape + self.multiplier = multiplier + self.org_module = [org_module] # moduleにならないようにlistに入れる + + def apply_to(self): + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward + + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + if self.I[0].device != self.oft_blocks[0].device: + self.I = [I.to(self.oft_blocks[0].device) for I in self.I] + + block_R_weighted_list = [] + for i in range(len(self.oft_blocks)): + block_Q = self.oft_blocks[i] - self.oft_blocks[i].transpose(1, 2) + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint[i]) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + + I = self.I[i] + block_R = torch.matmul(I + block_Q, (I - block_Q).float().inverse()) + block_R_weighted = self.multiplier * (block_R - I) + I + + block_R_weighted_list.append(block_R_weighted) + + return block_R_weighted_list + + def forward(self, x, scale=None): + if self.multiplier == 0.0: + return self.org_forward(x) + + org_module = self.org_module[0] + org_dtype = x.dtype + + R = self.get_weight() + W = org_module.weight.to(torch.float32) + B = org_module.bias.to(torch.float32) + + # split W to match R + results = [] + d2 = 0 + for i in range(len(R)): + d1 = d2 + d2 += self.split_dims[i] + + W1 = W[d1:d2] + W_reshaped = einops.rearrange(W1, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size[i]) + RW_1 = torch.einsum("k n m, k n p -> k m p", R[i], W_reshaped) + RW_1 = einops.rearrange(RW_1, "k m p -> (k m) p") + + B1 = B[d1:d2] + result = F.linear(x, RW_1.to(org_dtype), B1.to(org_dtype)) + results.append(result) + + result = torch.cat(results, dim=-1) + return result + + +class OFTInfModule(OFTModule): + def __init__( + self, + oft_name, + org_module: torch.nn.Module, + multiplier=1.0, + dim=4, + alpha=1, + split_dims: Optional[List[int]] = None, + **kwargs, + ): + # no dropout for inference + super().__init__(oft_name, org_module, multiplier, dim, alpha, split_dims) + self.enabled = True + self.network: OFTNetwork = None + + def set_network(self, network): + self.network = network + + def forward(self, x, scale=None): + if not self.enabled: + return self.org_forward(x) + return super().forward(x, scale) + + def merge_to(self, multiplier=None): + # get org weight + org_sd = self.org_module[0].state_dict() + W = org_sd["weight"].to(torch.float32) + R = self.get_weight(multiplier).to(torch.float32) + + d2 = 0 + W_list = [] + for i in range(len(self.oft_blocks)): + d1 = d2 + d2 += self.split_dims[i] + + W1 = W[d1:d2] + W_reshaped = einops.rearrange(W1, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size[i]) + W1 = torch.einsum("k n m, k n p -> k m p", R[i], W_reshaped) + W1 = einops.rearrange(W1, "k m p -> (k m) p") + + W_list.append(W1) + + W = torch.cat(W_list, dim=-1) + + # convert back to original dtype + W = W.to(org_sd["weight"].dtype) + + # set weight to org_module + org_sd["weight"] = W + self.org_module[0].load_state_dict(org_sd) + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: AutoencoderKL, + text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], + unet, + neuron_dropout: Optional[float] = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: # should be set + logger.info( + "network_alpha is not set, use default value 1e-3 / network_alphaが設定されていないのでデフォルト値 1e-3 を使用します" + ) + network_alpha = 1e-3 + elif network_alpha >= 1: + logger.warning( + "network_alpha is too large (>=1, maybe default value is too large), please consider to set smaller value like 1e-3" + " / network_alphaが大きすぎるようです(>=1, デフォルト値が大きすぎる可能性があります)。1e-3のような小さな値を推奨" + ) + + # attn only or all linear (FFN) layers + enable_all_linear = kwargs.get("enable_all_linear", None) + # enable_conv = kwargs.get("enable_conv", None) + if enable_all_linear is not None: + enable_all_linear = bool(enable_all_linear) + # if enable_conv is not None: + # enable_conv = bool(enable_conv) + + network = OFTNetwork( + text_encoder, + unet, + multiplier=multiplier, + dim=network_dim, + alpha=network_alpha, + enable_all_linear=enable_all_linear, + varbose=True, + ) + return network + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs): + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # check dim, alpha and if weights have for conv2d + dim = None + alpha = None + all_linear = None + for name, param in weights_sd.items(): + if name.endswith(".alpha"): + if alpha is None: + alpha = param.item() + elif "qkv" in name: + continue # ignore qkv + else: + if dim is None: + dim = param.size()[0] + if all_linear is None and "_mlp" in name: + all_linear = True + if dim is not None and alpha is not None and all_linear is not None: + break + if all_linear is None: + all_linear = False + + module_class = OFTInfModule if for_inference else OFTModule + network = OFTNetwork( + text_encoder, + unet, + multiplier=multiplier, + dim=dim, + alpha=alpha, + enable_all_linear=all_linear, + module_class=module_class, + ) + return network, weights_sd + + +class OFTNetwork(torch.nn.Module): + FLUX_TARGET_REPLACE_MODULE_ALL_LINEAR = ["DoubleStreamBlock", "SingleStreamBlock"] + FLUX_TARGET_REPLACE_MODULE_ATTN_ONLY = ["SelfAttention"] + OFT_PREFIX_UNET = "oft_unet" + + def __init__( + self, + text_encoder: Union[List[CLIPTextModel], CLIPTextModel], + unet, + multiplier: float = 1.0, + dim: int = 4, + alpha: float = 1, + enable_all_linear: Optional[bool] = False, + module_class: Union[Type[OFTModule], Type[OFTInfModule]] = OFTModule, + varbose: Optional[bool] = False, + ) -> None: + super().__init__() + self.train_t5xxl = False # make compatible with LoRA + self.multiplier = multiplier + + self.dim = dim + self.alpha = alpha + + logger.info( + f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_all_linear: {enable_all_linear}" + ) + + # create module instances + def create_modules( + root_module: torch.nn.Module, + target_replace_modules: List[torch.nn.Module], + ) -> List[OFTModule]: + prefix = self.OFT_PREFIX_UNET + ofts = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = "Linear" in child_module.__class__.__name__ + + if is_linear: + oft_name = prefix + "." + name + "." + child_name + oft_name = oft_name.replace(".", "_") + # logger.info(oft_name) + + if "double" in oft_name and "qkv" in oft_name: + split_dims = [3072] * 3 + elif "single" in oft_name and "linear1" in oft_name: + split_dims = [3072] * 3 + [12288] + else: + split_dims = None + + oft = module_class(oft_name, child_module, self.multiplier, dim, alpha, split_dims) + ofts.append(oft) + return ofts + + # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights + if enable_all_linear: + target_modules = OFTNetwork.FLUX_TARGET_REPLACE_MODULE_ALL_LINEAR + else: + target_modules = OFTNetwork.FLUX_TARGET_REPLACE_MODULE_ATTN_ONLY + + self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules) + logger.info(f"create OFT for Flux: {len(self.unet_ofts)} modules.") + + # assertion + names = set() + for oft in self.unet_ofts: + assert oft.oft_name not in names, f"duplicated oft name: {oft.oft_name}" + names.add(oft.oft_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for oft in self.unet_ofts: + oft.multiplier = self.multiplier + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): + assert apply_unet, "apply_unet must be True" + + for oft in self.unet_ofts: + oft.apply_to() + self.add_module(oft.oft_name, oft) + + # マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoder, unet, weights_sd, dtype, device): + logger.info("enable OFT for U-Net") + + for oft in self.unet_ofts: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(oft.oft_name): + sd_for_lora[key[len(oft.oft_name) + 1 :]] = weights_sd[key] + oft.load_state_dict(sd_for_lora, False) + oft.merge_to() + + logger.info(f"weights are merged") + + # 二つのText Encoderに別々の学習率を設定できるようにするといいかも + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + self.requires_grad_(True) + all_params = [] + + def enumerate_params(ofts): + params = [] + for oft in ofts: + params.extend(oft.parameters()) + + # logger.info num of params + num_params = 0 + for p in params: + num_params += p.numel() + logger.info(f"OFT params: {num_params}") + return params + + param_data = {"params": enumerate_params(self.unet_ofts)} + if unet_lr is not None: + param_data["lr"] = unet_lr + all_params.append(param_data) + + return all_params + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from library import train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def backup_weights(self): + # 重みのバックアップを行う + ofts: List[OFTInfModule] = self.unet_ofts + for oft in ofts: + org_module = oft.org_module[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + ofts: List[OFTInfModule] = self.unet_ofts + for oft in ofts: + org_module = oft.org_module[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + ofts: List[OFTInfModule] = self.unet_ofts + for oft in ofts: + org_module = oft.org_module[0] + oft.merge_to() + # sd = org_module.state_dict() + # org_weight = sd["weight"] + # lora_weight = oft.get_weight().to(org_weight.device, dtype=org_weight.dtype) + # sd["weight"] = org_weight + lora_weight + # assert sd["weight"].shape == org_weight.shape + # org_module.load_state_dict(sd) + + org_module._lora_restored = False + oft.enabled = False From c9ff4de90597e933b441502d45c175fe46b99714 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 14 Sep 2024 22:17:52 +0900 Subject: [PATCH 137/348] Add support for specifying rank for each layer in FLUX.1 --- README.md | 61 ++++++++++++++++++++++++ networks/lora_flux.py | 107 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 161 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 6e32fa31d..9a9794796 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,10 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 14, 2024: +- You can now specify the rank for each layer in FLUX.1. See [Specify rank for each layer in FLUX.1](#specify-rank-for-each-layer-in-flux1) for details. +- OFT is now supported with FLUX.1. See [FLUX.1 OFT training](#flux1-oft-training) for details. + Sep 11, 2024: Logging to wandb is improved. See PR [#1576](https://github.com/kohya-ss/sd-scripts/pull/1576) for details. Thanks to p1atdev! @@ -46,6 +50,7 @@ Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. ` - [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training) - [Inference for FLUX.1 LoRA model](#inference-for-flux1-lora-model) - [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) +- [FLUX.1 OFT training](#flux1-oft-training) - [FLUX.1 fine-tuning](#flux1-fine-tuning) - [Key Features for FLUX.1 fine-tuning](#key-features-for-flux1-fine-tuning) - [Extract LoRA from FLUX.1 Models](#extract-lora-from-flux1-models) @@ -191,6 +196,62 @@ In the implementation of Black Forest Labs' model, the projection layers of q/k/ The compatibility of the saved model (state dict) is ensured by concatenating the weights of multiple LoRAs. However, since there are zero weights in some parts, the model size will be large. +#### Specify rank for each layer in FLUX.1 + +You can specify the rank for each layer in FLUX.1 by specifying the following network_args. If you specify `0`, LoRA will not be applied to that layer. + +When network_args is not specified, the default value (`network_dim`) is applied, same as before. + +|network_args|target layer| +|---|---| +|img_attn_dim|img_attn in DoubleStreamBlock| +|txt_attn_dim|txt_attn in DoubleStreamBlock| +|img_mlp_dim|img_mlp in DoubleStreamBlock| +|txt_mlp_dim|txt_mlp in DoubleStreamBlock| +|img_mod_dim|img_mod in DoubleStreamBlock| +|txt_mod_dim|txt_mod in DoubleStreamBlock| +|single_dim|linear1 and linear2 in SingleStreamBlock| +|single_mod_dim|modulation in SingleStreamBlock| + +example: +``` +--network_args "img_attn_dim=4" "img_mlp_dim=8" "txt_attn_dim=2" "txt_mlp_dim=2" +"img_mod_dim=2" "txt_mod_dim=2" "single_dim=4" "single_mod_dim=2" +``` + +You can apply LoRA to the conditioning layers of Flux by specifying `in_dims` in network_args. When specifying, be sure to specify 5 numbers in `[]` as a comma-separated list. + +example: +``` +--network_args "in_dims=[4,2,2,2,4]" +``` + +Each number corresponds to `img_in`, `time_in`, `vector_in`, `guidance_in`, `txt_in`. The above example applies LoRA to all conditioning layers, with rank 4 for `img_in`, 2 for `time_in`, `vector_in`, `guidance_in`, and 4 for `txt_in`. + +If you specify `0`, LoRA will not be applied to that layer. For example, `[4,0,0,0,4]` applies LoRA only to `img_in` and `txt_in`. + +### FLUX.1 OFT training + +You can train OFT with almost the same options as LoRA, such as `--timestamp_sampling`. The following points are different. + +- Change `--network_module` from `networks.lora_flux` to `networks.oft_flux`. +- `--network_dim` is the number of OFT blocks. Unlike LoRA rank, the smaller the dim, the larger the model. We recommend about 64 or 128. Please make the output dimension of the target layer of OFT divisible by the value of `--network_dim` (an error will occur if it is not divisible). Valid values are 64, 128, 256, 512, 1024, etc. +- `--network_alpha` is treated as a constraint for OFT. We recommend about 1e-2 to 1e-4. The default value when omitted is 1, which is too large, so be sure to specify it. +- CLIP/T5XXL is not supported. Specify `--network_train_unet_only`. +- `--network_args` specifies the hyperparameters of OFT. The following are valid: + - Specify `enable_all_linear=True` to target all linear connections in the MLP layer. The default is False, which targets only attention. + +Currently, there is no environment to infer FLUX.1 OFT. Inference is only possible with `flux_minimal_inference.py` (specify OFT model with `--lora`). + +Sample command is below. It will work with 24GB VRAM GPUs with the batch size of 1. + +``` +--network_module networks.oft_flux --network_dim 128 --network_alpha 1e-3 +--network_args "enable_all_linear=True" --learning_rate 1e-5 +``` + +The training can be done with 16GB VRAM GPUs without `--enable_all_linear` option and with Adafactor optimizer. + ### Inference for FLUX.1 with LoRA model The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options. diff --git a/networks/lora_flux.py b/networks/lora_flux.py index ea7df8b4d..a34cde1a8 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -316,6 +316,44 @@ def create_network( else: conv_alpha = float(conv_alpha) + # attn dim, mlp dim: only for DoubleStreamBlock. SingleStreamBlock is not supported because of combined qkv + img_attn_dim = kwargs.get("img_attn_dim", None) + txt_attn_dim = kwargs.get("txt_attn_dim", None) + img_mlp_dim = kwargs.get("img_mlp_dim", None) + txt_mlp_dim = kwargs.get("txt_mlp_dim", None) + img_mod_dim = kwargs.get("img_mod_dim", None) + txt_mod_dim = kwargs.get("txt_mod_dim", None) + single_dim = kwargs.get("single_dim", None) # SingleStreamBlock + single_mod_dim = kwargs.get("single_mod_dim", None) # SingleStreamBlock + if img_attn_dim is not None: + img_attn_dim = int(img_attn_dim) + if txt_attn_dim is not None: + txt_attn_dim = int(txt_attn_dim) + if img_mlp_dim is not None: + img_mlp_dim = int(img_mlp_dim) + if txt_mlp_dim is not None: + txt_mlp_dim = int(txt_mlp_dim) + if img_mod_dim is not None: + img_mod_dim = int(img_mod_dim) + if txt_mod_dim is not None: + txt_mod_dim = int(txt_mod_dim) + if single_dim is not None: + single_dim = int(single_dim) + if single_mod_dim is not None: + single_mod_dim = int(single_mod_dim) + type_dims = [img_attn_dim, txt_attn_dim, img_mlp_dim, txt_mlp_dim, img_mod_dim, txt_mod_dim, single_dim, single_mod_dim] + if all([d is None for d in type_dims]): + type_dims = None + + # in_dims [img, time, vector, guidance, txt] + in_dims = kwargs.get("in_dims", None) + if in_dims is not None: + in_dims = in_dims.strip() + if in_dims.startswith("[") and in_dims.endswith("]"): + in_dims = in_dims[1:-1] + in_dims = [int(d) for d in in_dims.split(",")] # is it better to use ast.literal_eval? + assert len(in_dims) == 5, f"invalid in_dims: {in_dims}, must be 5 dimensions (img, time, vector, guidance, txt)" + # rank/module dropout rank_dropout = kwargs.get("rank_dropout", None) if rank_dropout is not None: @@ -339,6 +377,11 @@ def create_network( if train_t5xxl is not None: train_t5xxl = True if train_t5xxl == "True" else False + # verbose + verbose = kwargs.get("verbose", False) + if verbose is not None: + verbose = True if verbose == "True" else False + # すごく引数が多いな ( ^ω^)・・・ network = LoRANetwork( text_encoders, @@ -354,7 +397,9 @@ def create_network( train_blocks=train_blocks, split_qkv=split_qkv, train_t5xxl=train_t5xxl, - varbose=True, + type_dims=type_dims, + in_dims=in_dims, + verbose=verbose, ) loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) @@ -462,7 +507,9 @@ def __init__( train_blocks: Optional[str] = None, split_qkv: bool = False, train_t5xxl: bool = False, - varbose: Optional[bool] = False, + type_dims: Optional[List[int]] = None, + in_dims: Optional[List[int]] = None, + verbose: Optional[bool] = False, ) -> None: super().__init__() self.multiplier = multiplier @@ -478,12 +525,17 @@ def __init__( self.split_qkv = split_qkv self.train_t5xxl = train_t5xxl + self.type_dims = type_dims + self.in_dims = in_dims + self.loraplus_lr_ratio = None self.loraplus_unet_lr_ratio = None self.loraplus_text_encoder_lr_ratio = None if modules_dim is not None: logger.info(f"create LoRA network from weights") + self.in_dims = [0] * 5 # create in_dims + # verbose = True else: logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") logger.info( @@ -502,7 +554,12 @@ def __init__( # create module instances def create_modules( - is_flux: bool, text_encoder_idx: Optional[int], root_module: torch.nn.Module, target_replace_modules: List[str] + is_flux: bool, + text_encoder_idx: Optional[int], + root_module: torch.nn.Module, + target_replace_modules: List[str], + filter: Optional[str] = None, + default_dim: Optional[int] = None, ) -> List[LoRAModule]: prefix = ( self.LORA_PREFIX_FLUX @@ -513,16 +570,22 @@ def create_modules( loras = [] skipped = [] for name, module in root_module.named_modules(): - if module.__class__.__name__ in target_replace_modules: + if target_replace_modules is None or module.__class__.__name__ in target_replace_modules: + if target_replace_modules is None: # dirty hack for all modules + module = root_module # search all modules + for child_name, child_module in module.named_modules(): is_linear = child_module.__class__.__name__ == "Linear" is_conv2d = child_module.__class__.__name__ == "Conv2d" is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) if is_linear or is_conv2d: - lora_name = prefix + "." + name + "." + child_name + lora_name = prefix + "." + (name + "." if name else "") + child_name lora_name = lora_name.replace(".", "_") + if filter is not None and not filter in lora_name: + continue + dim = None alpha = None @@ -534,8 +597,25 @@ def create_modules( else: # 通常、すべて対象とする if is_linear or is_conv2d_1x1: - dim = self.lora_dim + dim = default_dim if default_dim is not None else self.lora_dim alpha = self.alpha + + if type_dims is not None: + identifier = [ + ("img_attn",), + ("txt_attn",), + ("img_mlp",), + ("txt_mlp",), + ("img_mod",), + ("txt_mod",), + ("single_blocks", "linear"), + ("modulation",), + ] + for i, d in enumerate(type_dims): + if d is not None and all([id in lora_name for id in identifier[i]]): + dim = d + break + elif self.conv_lora_dim is not None: dim = self.conv_lora_dim alpha = self.conv_alpha @@ -566,6 +646,9 @@ def create_modules( split_dims=split_dims, ) loras.append(lora) + + if target_replace_modules is None: + break # all modules are searched return loras, skipped # create LoRA for text encoder @@ -594,10 +677,20 @@ def create_modules( self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] self.unet_loras, skipped_un = create_modules(True, None, unet, target_replace_modules) + + # img, time, vector, guidance, txt + if self.in_dims: + for filter, in_dim in zip(["_img_in", "_time_in", "_vector_in", "_guidance_in", "_txt_in"], self.in_dims): + loras, _ = create_modules(True, None, unet, None, filter=filter, default_dim=in_dim) + self.unet_loras.extend(loras) + logger.info(f"create LoRA for FLUX {self.train_blocks} blocks: {len(self.unet_loras)} modules.") + if verbose: + for lora in self.unet_loras: + logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}") skipped = skipped_te + skipped_un - if varbose and len(skipped) > 0: + if verbose and len(skipped) > 0: logger.warning( f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" ) From 6445bb2bc974cec51256ae38c1be0900e90e6f87 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 14 Sep 2024 22:37:26 +0900 Subject: [PATCH 138/348] update README --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 9a9794796..c94ea3598 100644 --- a/README.md +++ b/README.md @@ -213,10 +213,12 @@ When network_args is not specified, the default value (`network_dim`) is applied |single_dim|linear1 and linear2 in SingleStreamBlock| |single_mod_dim|modulation in SingleStreamBlock| +`"verbose=True"` is also available for debugging. It shows the rank of each layer. + example: ``` --network_args "img_attn_dim=4" "img_mlp_dim=8" "txt_attn_dim=2" "txt_mlp_dim=2" -"img_mod_dim=2" "txt_mod_dim=2" "single_dim=4" "single_mod_dim=2" +"img_mod_dim=2" "txt_mod_dim=2" "single_dim=4" "single_mod_dim=2" "verbose=True" ``` You can apply LoRA to the conditioning layers of Flux by specifying `in_dims` in network_args. When specifying, be sure to specify 5 numbers in `[]` as a comma-separated list. From 9f44ef133083c530874c6cf022a4de8fda3edae2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 15 Sep 2024 13:52:23 +0900 Subject: [PATCH 139/348] add diffusers to FLUX.1 conversion script --- README.md | 19 ++- tools/convert_diffusers_to_flux.py | 223 +++++++++++++++++++++++++++++ 2 files changed, 241 insertions(+), 1 deletion(-) create mode 100644 tools/convert_diffusers_to_flux.py diff --git a/README.md b/README.md index c94ea3598..7d6c336e6 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,12 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 15, 2024: + +Added a script `convert_diffusers_to_flux.py` to convert Diffusers format FLUX.1 models (checkpoints) to BFL format. See `--help` for usage. Only Flux models are supported. AE/CLIP/T5XXL are not supported. + +The implementation is based on 2kpr's code. Thanks to 2kpr! + Sep 14, 2024: - You can now specify the rank for each layer in FLUX.1. See [Specify rank for each layer in FLUX.1](#specify-rank-for-each-layer-in-flux1) for details. - OFT is now supported with FLUX.1. See [FLUX.1 OFT training](#flux1-oft-training) for details. @@ -57,6 +63,7 @@ Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. ` - [Convert FLUX LoRA](#convert-flux-lora) - [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) - [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) +- [Convert Diffusers to FLUX.1](#convert-diffusers-to-flux1) ### FLUX.1 LoRA training @@ -355,7 +362,7 @@ If you use LoRA in the inference environment, converting it to AI-toolkit format Note that re-conversion will increase the size of LoRA. -CLIP-L LoRA is not supported. +CLIP-L/T5XXL LoRA is not supported. ### Merge LoRA to FLUX.1 checkpoint @@ -435,6 +442,16 @@ resolution = [512, 512] num_repeats = 1 ``` +### Convert Diffusers to FLUX.1 + +Script: `convert_diffusers_to_flux1.py` + +Converts Diffusers models to FLUX.1 models. The script is experimental. See `--help` for options. schnell and dev models are supported. AE/CLIP/T5XXL are not supported. The diffusers folder is a parent folder of `transfomer` folder. + +``` +python tools/convert_diffusers_to_flux.py --diffusers_path path/to/diffusers_folder_or_00001_safetensors --save_to path/to/flux1.safetensors --mem_eff_load_save --save_precision bf16 +``` + ## SD3 training SD3 training is done with `sd3_train.py`. diff --git a/tools/convert_diffusers_to_flux.py b/tools/convert_diffusers_to_flux.py new file mode 100644 index 000000000..9d8f7c74b --- /dev/null +++ b/tools/convert_diffusers_to_flux.py @@ -0,0 +1,223 @@ +# This script converts the diffusers of a Flux model to a safetensors file of a Flux.1 model. +# It is based on the implementation by 2kpr. Thanks to 2kpr! +# Major changes: +# - Iterates over three safetensors files to reduce memory usage, not loading all tensors at once. +# - Makes reverse map from diffusers map to avoid loading all tensors. +# - Removes dependency on .json file for weights mapping. +# - Adds support for custom memory efficient load and save functions. +# - Supports saving with different precision. +# - Supports .safetensors file as input. + +# Copyright 2024 2kpr. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import os +from pathlib import Path +import safetensors +from safetensors.torch import safe_open +import torch +from tqdm import tqdm + +from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +NUM_DOUBLE_BLOCKS = 19 +NUM_SINGLE_BLOCKS = 38 + +BFL_TO_DIFFUSERS_MAP = { + "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"], + "time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"], + "time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"], + "time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"], + "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"], + "vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"], + "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"], + "vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"], + "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"], + "guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"], + "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"], + "guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"], + "txt_in.weight": ["context_embedder.weight"], + "txt_in.bias": ["context_embedder.bias"], + "img_in.weight": ["x_embedder.weight"], + "img_in.bias": ["x_embedder.bias"], + "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"], + "double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"], + "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"], + "double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"], + "double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"], + "double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"], + "double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"], + "double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"], + "double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"], + "double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"], + "double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"], + "double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"], + "double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"], + "double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"], + "double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"], + "double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"], + "double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"], + "double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"], + "double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"], + "double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"], + "double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"], + "double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"], + "double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"], + "double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"], + "single_blocks.().modulation.lin.weight": ["norm.linear.weight"], + "single_blocks.().modulation.lin.bias": ["norm.linear.bias"], + "single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"], + "single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"], + "single_blocks.().linear2.weight": ["proj_out.weight"], + "single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"], + "single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"], + "single_blocks.().linear2.weight": ["proj_out.weight"], + "single_blocks.().linear2.bias": ["proj_out.bias"], + "final_layer.linear.weight": ["proj_out.weight"], + "final_layer.linear.bias": ["proj_out.bias"], + "final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"], + "final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"], +} + + +def convert(args): + # if diffusers_path is folder, get safetensors file + diffusers_path = Path(args.diffusers_path) + if diffusers_path.is_dir(): + diffusers_path = Path.joinpath(diffusers_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors") + + flux_path = Path(args.save_to) + if not os.path.exists(flux_path.parent): + os.makedirs(flux_path.parent) + + if not diffusers_path.exists(): + logger.error(f"Error: Missing transformer safetensors file: {diffusers_path}") + return + + mem_eff_flag = args.mem_eff_load_save + save_dtype = str_to_dtype(args.save_precision) if args.save_precision is not None else None + + # make reverse map from diffusers map + diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key) + for b in range(NUM_DOUBLE_BLOCKS): + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if key.startswith("double_blocks."): + block_prefix = f"transformer_blocks.{b}." + for i, weight in enumerate(weights): + diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) + for b in range(NUM_SINGLE_BLOCKS): + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if key.startswith("single_blocks."): + block_prefix = f"single_transformer_blocks.{b}." + for i, weight in enumerate(weights): + diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")): + for i, weight in enumerate(weights): + diffusers_to_bfl_map[weight] = (i, key) + + # iterate over three safetensors files to reduce memory usage + flux_sd = {} + for i in range(3): + # replace 00001 with 0000i + current_diffusers_path = Path(str(diffusers_path).replace("00001", f"0000{i+1}")) + logger.info(f"Loading diffusers file: {current_diffusers_path}") + + open_func = MemoryEfficientSafeOpen if mem_eff_flag else (lambda x: safe_open(x, framework="pt")) + with open_func(current_diffusers_path) as f: + for diffusers_key in tqdm(f.keys()): + if diffusers_key in diffusers_to_bfl_map: + tensor = f.get_tensor(diffusers_key).to("cpu") + if save_dtype is not None: + tensor = tensor.to(save_dtype) + + index, bfl_key = diffusers_to_bfl_map[diffusers_key] + if bfl_key not in flux_sd: + flux_sd[bfl_key] = [] + flux_sd[bfl_key].append((index, tensor)) + else: + logger.error(f"Error: Key not found in diffusers_to_bfl_map: {diffusers_key}") + return + + # concat tensors if multiple tensors are mapped to a single key, sort by index + for key, values in flux_sd.items(): + if len(values) == 1: + flux_sd[key] = values[0][1] + else: + flux_sd[key] = torch.cat([value[1] for value in sorted(values, key=lambda x: x[0])]) + + # special case for final_layer.adaLN_modulation.1.weight and final_layer.adaLN_modulation.1.bias + def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + if "final_layer.adaLN_modulation.1.weight" in flux_sd: + flux_sd["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.weight"]) + if "final_layer.adaLN_modulation.1.bias" in flux_sd: + flux_sd["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.bias"]) + + # save flux_sd to safetensors file + logger.info(f"Saving Flux safetensors file: {flux_path}") + if mem_eff_flag: + mem_eff_save_file(flux_sd, flux_path) + else: + safetensors.torch.save_file(flux_sd, flux_path) + + logger.info("Conversion completed.") + + +def setup_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--diffusers_path", + default=None, + type=str, + required=True, + help="Path to the original Flux diffusers folder or *-00001-of-00003.safetensors file." + " / 元のFlux diffusersフォルダーまたは*-00001-of-00003.safetensorsファイルへのパス", + ) + parser.add_argument( + "--save_to", + default=None, + type=str, + required=True, + help="Output path for the Flux safetensors file. / Flux safetensorsファイルの出力先", + ) + parser.add_argument( + "--mem_eff_load_save", + action="store_true", + help="use custom memory efficient load and save functions for FLUX.1 model" + " / カスタムのメモリ効率の良い読み込みと保存関数をFLUX.1モデルに使用する", + ) + parser.add_argument( + "--save_precision", + type=str, + default=None, + help="precision in saving, default is same as loading precision" + "float32, fp16, bf16, fp8 (same as fp8_e4m3fn), fp8_e4m3fn, fp8_e4m3fnuz, fp8_e5m2, fp8_e5m2fnuz" + " / 保存時に精度を変更して保存する、デフォルトは読み込み時と同じ精度", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + args = parser.parse_args() + convert(args) From be078bdaca41084a20edb952b98a82f3e05d2dad Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 15 Sep 2024 13:59:17 +0900 Subject: [PATCH 140/348] fix typo --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7d6c336e6..f79fe21af 100644 --- a/README.md +++ b/README.md @@ -446,7 +446,7 @@ resolution = [512, 512] Script: `convert_diffusers_to_flux1.py` -Converts Diffusers models to FLUX.1 models. The script is experimental. See `--help` for options. schnell and dev models are supported. AE/CLIP/T5XXL are not supported. The diffusers folder is a parent folder of `transfomer` folder. +Converts Diffusers models to FLUX.1 models. The script is experimental. See `--help` for options. schnell and dev models are supported. AE/CLIP/T5XXL are not supported. The diffusers folder is a parent folder of `rmer` folder. ``` python tools/convert_diffusers_to_flux.py --diffusers_path path/to/diffusers_folder_or_00001_safetensors --save_to path/to/flux1.safetensors --mem_eff_load_save --save_precision bf16 From 96c677b4594ed6f28f3ef896f6deca7c3aced25d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 16 Sep 2024 10:42:09 +0900 Subject: [PATCH 141/348] fix to work lienar/cosine lr scheduler closes #1602 ref #1393 --- library/train_util.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 742d057e0..60afd4219 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4707,6 +4707,15 @@ def wrap_check_needless_num_warmup_steps(return_vals): **lr_scheduler_kwargs, ) + # these schedulers do not require `num_decay_steps` + if name == SchedulerType.LINEAR or name == SchedulerType.COSINE: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + **lr_scheduler_kwargs, + ) + # All other schedulers require `num_decay_steps` if num_decay_steps is None: raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.") @@ -5837,14 +5846,9 @@ def sample_image_inference( wandb_tracker = accelerator.get_tracker("wandb") import wandb + # not to commit images to avoid inconsistency between training and logging steps - wandb_tracker.log( - {f"sample_{i}": wandb.Image( - image, - caption=prompt # positive prompt as a caption - )}, - commit=False - ) + wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption # endregion From d8d15f1a7e09ca217930288b41bd239881126b93 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 16 Sep 2024 23:14:09 +0900 Subject: [PATCH 142/348] add support for specifying blocks in FLUX.1 LoRA training --- README.md | 24 ++++++++++++- networks/lora_flux.py | 82 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 103 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index f79fe21af..24217d8b7 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,10 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 16, 2024: + + Added `train_double_block_indices` and `train_double_block_indices` to the LoRA training script to specify the indices of the blocks to train. See [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training) for details. + Sep 15, 2024: Added a script `convert_diffusers_to_flux.py` to convert Diffusers format FLUX.1 models (checkpoints) to BFL format. See `--help` for usage. Only Flux models are supported. AE/CLIP/T5XXL are not supported. @@ -54,9 +58,12 @@ Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. ` - [FLUX.1 LoRA training](#flux1-lora-training) - [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training) - - [Inference for FLUX.1 LoRA model](#inference-for-flux1-lora-model) + - [Distribution of timesteps](#distribution-of-timesteps) - [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) + - [Specify rank for each layer in FLUX.1](#specify-rank-for-each-layer-in-flux1) + - [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training) - [FLUX.1 OFT training](#flux1-oft-training) +- [Inference for FLUX.1 with LoRA model](#inference-for-flux1-with-lora-model) - [FLUX.1 fine-tuning](#flux1-fine-tuning) - [Key Features for FLUX.1 fine-tuning](#key-features-for-flux1-fine-tuning) - [Extract LoRA from FLUX.1 Models](#extract-lora-from-flux1-models) @@ -239,6 +246,21 @@ Each number corresponds to `img_in`, `time_in`, `vector_in`, `guidance_in`, `txt If you specify `0`, LoRA will not be applied to that layer. For example, `[4,0,0,0,4]` applies LoRA only to `img_in` and `txt_in`. +#### Specify blocks to train in FLUX.1 LoRA training + +You can specify the blocks to train in FLUX.1 LoRA training by specifying `train_double_block_indices` and `train_single_block_indices` in network_args. The indices are 0-based. The default (when omitted) is to train all blocks. The indices are specified as a list of integers or a range of integers, like `0,1,5,8` or `0,1,4-5,7`. The number of double blocks is 19, and the number of single blocks is 38, so the valid range is 0-18 and 0-37, respectively. `all` is also available to train all blocks, `none` is also available to train no blocks. + +example: +``` +--network_args "train_double_block_indices=0,1,8-12,18" "train_single_block_indices=3,10,20-25,37" +``` + +``` +--network_args "train_double_block_indices=none" "train_single_block_indices=10-15" +``` + +If you specify one of `train_double_block_indices` or `train_single_block_indices`, the other will be trained as usual. + ### FLUX.1 OFT training You can train OFT with almost the same options as LoRA, such as `--timestamp_sampling`. The following points are different. diff --git a/networks/lora_flux.py b/networks/lora_flux.py index a34cde1a8..f549ac18f 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -24,6 +24,10 @@ logger = logging.getLogger(__name__) +NUM_DOUBLE_BLOCKS = 19 +NUM_SINGLE_BLOCKS = 38 + + class LoRAModule(torch.nn.Module): """ replaces forward method of the original Linear, instead of replacing the original Linear module. @@ -354,6 +358,50 @@ def create_network( in_dims = [int(d) for d in in_dims.split(",")] # is it better to use ast.literal_eval? assert len(in_dims) == 5, f"invalid in_dims: {in_dims}, must be 5 dimensions (img, time, vector, guidance, txt)" + # double/single train blocks + def parse_block_selection(selection: str, total_blocks: int) -> List[bool]: + """ + Parse a block selection string and return a list of booleans. + + Args: + selection (str): A string specifying which blocks to select. + total_blocks (int): The total number of blocks available. + + Returns: + List[bool]: A list of booleans indicating which blocks are selected. + """ + if selection == "all": + return [True] * total_blocks + if selection == "none" or selection == "": + return [False] * total_blocks + + selected = [False] * total_blocks + ranges = selection.split(",") + + for r in ranges: + if "-" in r: + start, end = map(str.strip, r.split("-")) + start = int(start) + end = int(end) + assert 0 <= start < total_blocks, f"invalid start index: {start}" + assert 0 <= end < total_blocks, f"invalid end index: {end}" + assert start <= end, f"invalid range: {start}-{end}" + for i in range(start, end + 1): + selected[i] = True + else: + index = int(r) + assert 0 <= index < total_blocks, f"invalid index: {index}" + selected[index] = True + + return selected + + train_double_block_indices = kwargs.get("train_double_block_indices", None) + train_single_block_indices = kwargs.get("train_single_block_indices", None) + if train_double_block_indices is not None: + train_double_block_indices = parse_block_selection(train_double_block_indices, NUM_DOUBLE_BLOCKS) + if train_single_block_indices is not None: + train_single_block_indices = parse_block_selection(train_single_block_indices, NUM_SINGLE_BLOCKS) + # rank/module dropout rank_dropout = kwargs.get("rank_dropout", None) if rank_dropout is not None: @@ -399,6 +447,8 @@ def create_network( train_t5xxl=train_t5xxl, type_dims=type_dims, in_dims=in_dims, + train_double_block_indices=train_double_block_indices, + train_single_block_indices=train_single_block_indices, verbose=verbose, ) @@ -509,6 +559,8 @@ def __init__( train_t5xxl: bool = False, type_dims: Optional[List[int]] = None, in_dims: Optional[List[int]] = None, + train_double_block_indices: Optional[List[bool]] = None, + train_single_block_indices: Optional[List[bool]] = None, verbose: Optional[bool] = False, ) -> None: super().__init__() @@ -527,6 +579,8 @@ def __init__( self.type_dims = type_dims self.in_dims = in_dims + self.train_double_block_indices = train_double_block_indices + self.train_single_block_indices = train_single_block_indices self.loraplus_lr_ratio = None self.loraplus_unet_lr_ratio = None @@ -600,7 +654,7 @@ def create_modules( dim = default_dim if default_dim is not None else self.lora_dim alpha = self.alpha - if type_dims is not None: + if is_flux and type_dims is not None: identifier = [ ("img_attn",), ("txt_attn",), @@ -613,9 +667,33 @@ def create_modules( ] for i, d in enumerate(type_dims): if d is not None and all([id in lora_name for id in identifier[i]]): - dim = d + dim = d # may be 0 for skip break + if ( + is_flux + and dim + and ( + self.train_double_block_indices is not None + or self.train_single_block_indices is not None + ) + and ("double" in lora_name or "single" in lora_name) + ): + # "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..." + block_index = int(lora_name.split("_")[4]) # bit dirty + if ( + "double" in lora_name + and self.train_double_block_indices is not None + and not self.train_double_block_indices[block_index] + ): + dim = 0 + elif ( + "single" in lora_name + and self.train_single_block_indices is not None + and not self.train_single_block_indices[block_index] + ): + dim = 0 + elif self.conv_lora_dim is not None: dim = self.conv_lora_dim alpha = self.conv_alpha From 0cbe95bcc7e88f518802f29fe2b99da806963267 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 17 Sep 2024 21:21:28 +0900 Subject: [PATCH 143/348] fix text_encoder_lr to work with int closes #1608 --- networks/lora_flux.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index f549ac18f..91e9cd77f 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -966,8 +966,8 @@ def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr # if float, use the same value for both text encoders if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0): text_encoder_lr = [default_lr, default_lr] - elif isinstance(text_encoder_lr, float): - text_encoder_lr = [text_encoder_lr, text_encoder_lr] + elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int): + text_encoder_lr = [float(text_encoder_lr), float(text_encoder_lr)] elif len(text_encoder_lr) == 1: text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]] From a2ad7e5644f08141fe053a2b63446d70d777bdcf Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 17 Sep 2024 21:42:14 +0900 Subject: [PATCH 144/348] blocks_to_swap=0 means no swap --- flux_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flux_train.py b/flux_train.py index 33481df8f..5d8326b1d 100644 --- a/flux_train.py +++ b/flux_train.py @@ -265,7 +265,7 @@ def train(args): flux.requires_grad_(True) - is_swapping_blocks = args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None + is_swapping_blocks = args.double_blocks_to_swap or args.single_blocks_to_swap if is_swapping_blocks: # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. # This idea is based on 2kpr's great work. Thank you! From bbd160b4ca9293881c222f9b9e1d832af69699db Mon Sep 17 00:00:00 2001 From: "Kohya S." <52813779+kohya-ss@users.noreply.github.com> Date: Wed, 18 Sep 2024 07:55:04 +0900 Subject: [PATCH 145/348] sd3 schedule free opt (#1605) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * New ScheduleFree support for Flux (#1600) * init * use no schedule * fix typo * update for eval() * fix typo * update * Update train_util.py * Update requirements.txt * update sfwrapper WIP * no need to check schedulefree optimizer * remove debug print * comment out schedulefree wrapper * update readme --------- Co-authored-by: 青龍聖者@bdsqlsz <865105819@qq.com> --- README.md | 8 +++ library/train_util.py | 152 ++++++++++++++++++++++++++++++++++++++++-- requirements.txt | 1 + 3 files changed, 154 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 24217d8b7..dc9862927 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,14 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 18, 2024: + +- Schedule-free optimizer is added. Thanks to sdbds! See PR [#1600](https://github.com/kohya-ss/sd-scripts/pull/1600) for details. + - `schedulefree` is added to the dependencies. Please update the library if necessary. + - AdamWScheduleFree or SGDScheduleFree can be used. Specify `adamwschedulefree` or `sgdschedulefree` in `--optimizer_type`. + - Wrapper classes are not available for now. + - These can be used not only for FLUX.1 training but also for other training scripts after merging to the dev/main branch. + Sep 16, 2024: Added `train_double_block_indices` and `train_double_block_indices` to the LoRA training script to specify the indices of the blocks to train. See [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training) for details. diff --git a/library/train_util.py b/library/train_util.py index 60afd4219..a54f23ff6 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3303,6 +3303,20 @@ def int_or_float(value): help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")', ) + # parser.add_argument( + # "--optimizer_schedulefree_wrapper", + # action="store_true", + # help="use schedulefree_wrapper any optimizer / 任意のオプティマイザにschedulefree_wrapperを使用", + # ) + + # parser.add_argument( + # "--schedulefree_wrapper_args", + # type=str, + # default=None, + # nargs="*", + # help='additional arguments for schedulefree_wrapper (like "momentum=0.9 weight_decay_at_y=0.1 ...") / オプティマイザの追加引数(例: "momentum=0.9 weight_decay_at_y=0.1 ...")', + # ) + parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ") parser.add_argument( "--lr_scheduler_args", @@ -4582,26 +4596,146 @@ def get_optimizer(args, trainable_params): optimizer_class = torch.optim.AdamW optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + elif optimizer_type.endswith("schedulefree".lower()): + try: + import schedulefree as sf + except ImportError: + raise ImportError("No schedulefree / schedulefreeがインストールされていないようです") + if optimizer_type == "AdamWScheduleFree".lower(): + optimizer_class = sf.AdamWScheduleFree + logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}") + elif optimizer_type == "SGDScheduleFree".lower(): + optimizer_class = sf.SGDScheduleFree + logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}") + else: + raise ValueError(f"Unknown optimizer type: {optimizer_type}") + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + # make optimizer as train mode: we don't need to call train again, because eval will not be called in training loop + optimizer.train() + if optimizer is None: # 任意のoptimizerを使う - optimizer_type = args.optimizer_type # lowerでないやつ(微妙) - logger.info(f"use {optimizer_type} | {optimizer_kwargs}") - if "." not in optimizer_type: + case_sensitive_optimizer_type = args.optimizer_type # not lower + logger.info(f"use {case_sensitive_optimizer_type} | {optimizer_kwargs}") + + if "." not in case_sensitive_optimizer_type: # from torch.optim optimizer_module = torch.optim - else: - values = optimizer_type.split(".") + else: # from other library + values = case_sensitive_optimizer_type.split(".") optimizer_module = importlib.import_module(".".join(values[:-1])) - optimizer_type = values[-1] + case_sensitive_optimizer_type = values[-1] - optimizer_class = getattr(optimizer_module, optimizer_type) + optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type) optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + """ + # wrap any of above optimizer with schedulefree, if optimizer is not schedulefree + if args.optimizer_schedulefree_wrapper and not optimizer_type.endswith("schedulefree".lower()): + try: + import schedulefree as sf + except ImportError: + raise ImportError("No schedulefree / schedulefreeがインストールされていないようです") + + schedulefree_wrapper_kwargs = {} + if args.schedulefree_wrapper_args is not None and len(args.schedulefree_wrapper_args) > 0: + for arg in args.schedulefree_wrapper_args: + key, value = arg.split("=") + value = ast.literal_eval(value) + schedulefree_wrapper_kwargs[key] = value + + sf_wrapper = sf.ScheduleFreeWrapper(optimizer, **schedulefree_wrapper_kwargs) + sf_wrapper.train() # make optimizer as train mode + + # we need to make optimizer as a subclass of torch.optim.Optimizer, we make another Proxy class over SFWrapper + class OptimizerProxy(torch.optim.Optimizer): + def __init__(self, sf_wrapper): + self._sf_wrapper = sf_wrapper + + def __getattr__(self, name): + return getattr(self._sf_wrapper, name) + + # override properties + @property + def state(self): + return self._sf_wrapper.state + + @state.setter + def state(self, state): + self._sf_wrapper.state = state + + @property + def param_groups(self): + return self._sf_wrapper.param_groups + + @param_groups.setter + def param_groups(self, param_groups): + self._sf_wrapper.param_groups = param_groups + + @property + def defaults(self): + return self._sf_wrapper.defaults + + @defaults.setter + def defaults(self, defaults): + self._sf_wrapper.defaults = defaults + + def add_param_group(self, param_group): + self._sf_wrapper.add_param_group(param_group) + + def load_state_dict(self, state_dict): + self._sf_wrapper.load_state_dict(state_dict) + + def state_dict(self): + return self._sf_wrapper.state_dict() + + def zero_grad(self): + self._sf_wrapper.zero_grad() + + def step(self, closure=None): + self._sf_wrapper.step(closure) + + def train(self): + self._sf_wrapper.train() + + def eval(self): + self._sf_wrapper.eval() + + # isinstance チェックをパスするためのメソッド + def __instancecheck__(self, instance): + return isinstance(instance, (type(self), Optimizer)) + + optimizer = OptimizerProxy(sf_wrapper) + + logger.info(f"wrap optimizer with ScheduleFreeWrapper | {schedulefree_wrapper_kwargs}") + """ + optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()]) return optimizer_name, optimizer_args, optimizer +def is_schedulefree_optimizer(optimizer: Optimizer, args: argparse.Namespace) -> bool: + return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper + + +def get_dummy_scheduler(optimizer: Optimizer) -> Any: + # dummy scheduler for schedulefree optimizer. supports only empty step(), get_last_lr() and optimizers. + # this scheduler is used for logging only. + # this isn't be wrapped by accelerator because of this class is not a subclass of torch.optim.lr_scheduler._LRScheduler + class DummyScheduler: + def __init__(self, optimizer: Optimizer): + self.optimizer = optimizer + + def step(self): + pass + + def get_last_lr(self): + return [group["lr"] for group in self.optimizer.param_groups] + + return DummyScheduler(optimizer) + + # Modified version of get_scheduler() function from diffusers.optimizer.get_scheduler # Add some checking and features to the original function. @@ -4610,6 +4744,10 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): """ Unified API to get any scheduler from its name. """ + # if schedulefree optimizer, return dummy scheduler + if is_schedulefree_optimizer(optimizer, args): + return get_dummy_scheduler(optimizer) + name = args.lr_scheduler num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps num_warmup_steps: Optional[int] = ( diff --git a/requirements.txt b/requirements.txt index 9a4fa0c15..bab53f20f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,7 @@ pytorch-lightning==1.9.0 bitsandbytes==0.43.3 prodigyopt==1.0 lion-pytorch==0.0.6 +schedulefree==1.2.7 tensorboard safetensors==0.4.4 # gradio==3.16.2 From e74502117bcf161ef5698fb0adba4f9fa0171b8d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 18 Sep 2024 08:04:32 +0900 Subject: [PATCH 146/348] update README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index dc9862927..034a260ff 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ The command to install PyTorch is as follows: Sep 18, 2024: - Schedule-free optimizer is added. Thanks to sdbds! See PR [#1600](https://github.com/kohya-ss/sd-scripts/pull/1600) for details. + - Details of the schedule-free optimizer can be found in [facebookresearch/schedule_free](https://github.com/facebookresearch/schedule_free). - `schedulefree` is added to the dependencies. Please update the library if necessary. - AdamWScheduleFree or SGDScheduleFree can be used. Specify `adamwschedulefree` or `sgdschedulefree` in `--optimizer_type`. - Wrapper classes are not available for now. From 1286e00bb0fc34c296f24b7057777f1c37cf8e11 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 18 Sep 2024 21:31:54 +0900 Subject: [PATCH 147/348] fix to call train/eval in schedulefree #1605 --- README.md | 3 +++ flux_train.py | 10 ++++++++++ library/train_util.py | 15 ++++++++++++++- train_network.py | 6 ++++++ 4 files changed, 33 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 034a260ff..843ae181b 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,9 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 18, 2024 (update 1): +Fixed an issue where train()/eval() was not called properly with the schedule-free optimizer. The schedule-free optimizer can be used in FLUX.1 LoRA training and fine-tuning for now. + Sep 18, 2024: - Schedule-free optimizer is added. Thanks to sdbds! See PR [#1600](https://github.com/kohya-ss/sd-scripts/pull/1600) for details. diff --git a/flux_train.py b/flux_train.py index 5d8326b1d..bc4e62793 100644 --- a/flux_train.py +++ b/flux_train.py @@ -347,8 +347,13 @@ def train(args): logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers") + if train_util.is_schedulefree_optimizer(optimizers[0], args): + raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers") + optimizer_train_fn = lambda: None # dummy function + optimizer_eval_fn = lambda: None # dummy function else: _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) # prepare dataloader # strategies are set here because they cannot be referenced in another process. Copy them with the dataset @@ -760,6 +765,7 @@ def optimizer_hook(parameter: torch.Tensor): progress_bar.update(1) global_step += 1 + optimizer_eval_fn() flux_train_utils.sample_images( accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs ) @@ -778,6 +784,7 @@ def optimizer_hook(parameter: torch.Tensor): global_step, accelerator.unwrap_model(flux), ) + optimizer_train_fn() current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず if len(accelerator.trackers) > 0: @@ -800,6 +807,7 @@ def optimizer_hook(parameter: torch.Tensor): accelerator.wait_for_everyone() + optimizer_eval_fn() if args.save_every_n_epochs is not None: if accelerator.is_main_process: flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( @@ -816,12 +824,14 @@ def optimizer_hook(parameter: torch.Tensor): flux_train_utils.sample_images( accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs ) + optimizer_train_fn() is_main_process = accelerator.is_main_process # if is_main_process: flux = accelerator.unwrap_model(flux) accelerator.end_training() + optimizer_eval_fn() if args.save_state or args.save_state_on_train_end: train_util.save_state_on_train_end(args, accelerator) diff --git a/library/train_util.py b/library/train_util.py index a54f23ff6..fe9deb940 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -13,6 +13,7 @@ import time from typing import ( Any, + Callable, Dict, List, NamedTuple, @@ -4715,8 +4716,20 @@ def __instancecheck__(self, instance): return optimizer_name, optimizer_args, optimizer +def get_optimizer_train_eval_fn(optimizer: Optimizer, args: argparse.Namespace) -> Tuple[Callable, Callable]: + if not is_schedulefree_optimizer(optimizer, args): + # return dummy func + return lambda: None, lambda: None + + # get train and eval functions from optimizer + train_fn = optimizer.train + eval_fn = optimizer.eval + + return train_fn, eval_fn + + def is_schedulefree_optimizer(optimizer: Optimizer, args: argparse.Namespace) -> bool: - return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper + return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper def get_dummy_scheduler(optimizer: Optimizer) -> Any: diff --git a/train_network.py b/train_network.py index 34385ae08..55faa143e 100644 --- a/train_network.py +++ b/train_network.py @@ -498,6 +498,7 @@ def train(self, args): # accelerator.print(f"trainable_params: {k} = {v}") optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) + optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) # prepare dataloader # strategies are set here because they cannot be referenced in another process. Copy them with the dataset @@ -1199,6 +1200,7 @@ def remove_model(old_ckpt_name): progress_bar.update(1) global_step += 1 + optimizer_eval_fn() self.sample_images( accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet ) @@ -1217,6 +1219,7 @@ def remove_model(old_ckpt_name): if remove_step_no is not None: remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) remove_model(remove_ckpt_name) + optimizer_train_fn() current_loss = loss.detach().item() loss_recorder.add(epoch=epoch, step=step, loss=current_loss) @@ -1243,6 +1246,7 @@ def remove_model(old_ckpt_name): accelerator.wait_for_everyone() # 指定エポックごとにモデルを保存 + optimizer_eval_fn() if args.save_every_n_epochs is not None: saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs if is_main_process and saving: @@ -1258,6 +1262,7 @@ def remove_model(old_ckpt_name): train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) + optimizer_train_fn() # end of epoch @@ -1268,6 +1273,7 @@ def remove_model(old_ckpt_name): network = accelerator.unwrap_model(network) accelerator.end_training() + optimizer_eval_fn() if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) From 3957372ded6fda20553acaf169993a422b829bdc Mon Sep 17 00:00:00 2001 From: Ed McManus Date: Thu, 19 Sep 2024 14:30:03 -0700 Subject: [PATCH 148/348] Retain alpha in `pil_resize` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Currently the alpha channel is dropped by `pil_resize()` when `--alpha_mask` is supplied and the image width does not exceed the bucket. This codepath is entered on the last line, here: ``` def trim_and_resize_if_required( random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int] ) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]: image_height, image_width = image.shape[0:2] original_size = (image_width, image_height) # size before resize if image_width != resized_size[0] or image_height != resized_size[1]: # リサイズする if image_width > resized_size[0] and image_height > resized_size[1]: image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ else: image = pil_resize(image, resized_size) ``` --- library/utils.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/library/utils.py b/library/utils.py index a0bb19650..2171c7190 100644 --- a/library/utils.py +++ b/library/utils.py @@ -305,13 +305,26 @@ def _convert_float8(byte_tensor, dtype_str, shape): raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)") def pil_resize(image, size, interpolation=Image.LANCZOS): - pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + # Check if the image has an alpha channel + has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False - # use Pillow resize + if has_alpha: + # Convert BGRA to RGBA + pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)) + else: + # Convert BGR to RGB + pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + + # Resize the image resized_pil = pil_image.resize(size, interpolation) - # return cv2 image - resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) + # Convert back to cv2 format + if has_alpha: + # Convert RGBA to BGRA + resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA) + else: + # Convert RGB to BGR + resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) return resized_cv2 From de4bb657b089cc28f4127e891b927895892e20b5 Mon Sep 17 00:00:00 2001 From: Ed McManus Date: Thu, 19 Sep 2024 14:38:32 -0700 Subject: [PATCH 149/348] Update utils.py Cleanup --- library/utils.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/library/utils.py b/library/utils.py index 2171c7190..8a0c782c0 100644 --- a/library/utils.py +++ b/library/utils.py @@ -305,25 +305,19 @@ def _convert_float8(byte_tensor, dtype_str, shape): raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)") def pil_resize(image, size, interpolation=Image.LANCZOS): - # Check if the image has an alpha channel has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False if has_alpha: - # Convert BGRA to RGBA pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)) else: - # Convert BGR to RGB pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - # Resize the image resized_pil = pil_image.resize(size, interpolation) # Convert back to cv2 format if has_alpha: - # Convert RGBA to BGRA resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA) else: - # Convert RGB to BGR resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) return resized_cv2 From 0535cd29b926530255d5400374813432ec52c3df Mon Sep 17 00:00:00 2001 From: Akegarasu Date: Fri, 20 Sep 2024 10:05:22 +0800 Subject: [PATCH 150/348] fix: backward compatibility for text_encoder_lr --- train_network.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 55faa143e..dfa51a9c8 100644 --- a/train_network.py +++ b/train_network.py @@ -471,7 +471,11 @@ def train(self, args): if support_multiple_lrs: text_encoder_lr = args.text_encoder_lr else: - text_encoder_lr = None if args.text_encoder_lr is None or len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0] + # toml backward compatibility + if args.text_encoder_lr is None or isinstance(args.text_encoder_lr, float): + text_encoder_lr = args.text_encoder_lr + else: + text_encoder_lr = None if len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0] try: if support_multiple_lrs: results = network.prepare_optimizer_params_with_multiple_te_lrs(text_encoder_lr, args.unet_lr, args.learning_rate) From 583d4a436c1cef57fce405d0167fb7ce575fc768 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 20 Sep 2024 22:22:24 +0900 Subject: [PATCH 151/348] add compatibility for int LR (D-Adaptation etc.) #1620 --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index dfa51a9c8..b24f89b1e 100644 --- a/train_network.py +++ b/train_network.py @@ -472,7 +472,7 @@ def train(self, args): text_encoder_lr = args.text_encoder_lr else: # toml backward compatibility - if args.text_encoder_lr is None or isinstance(args.text_encoder_lr, float): + if args.text_encoder_lr is None or isinstance(args.text_encoder_lr, float) or isinstance(args.text_encoder_lr, int): text_encoder_lr = args.text_encoder_lr else: text_encoder_lr = None if len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0] From 56a7bc171d48089fb50f8638537e42d07c579db3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 26 Sep 2024 08:26:31 +0900 Subject: [PATCH 152/348] new block swap for FLUX.1 fine tuning --- README.md | 47 ++++++-- flux_train.py | 251 ++++++++++++++++++++++++++--------------- library/flux_models.py | 168 +++++++++++++++------------ 3 files changed, 297 insertions(+), 169 deletions(-) diff --git a/README.md b/README.md index ef691e918..7d623f900 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,10 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 26, 2024: +The implementation of block swap during FLUX.1 fine-tuning has been changed to improve speed about 10% (depends on the environment). A new `--blocks_to_swap` option has been added, and `--double_blocks_to_swap` and `--single_blocks_to_swap` are deprecated. `--double_blocks_to_swap` and `--single_blocks_to_swap` are working as before, but they will be removed in the future. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. + + Sep 18, 2024 (update 1): Fixed an issue where train()/eval() was not called properly with the schedule-free optimizer. The schedule-free optimizer can be used in FLUX.1 LoRA training and fine-tuning for now. @@ -307,6 +311,8 @@ python flux_minimal_inference.py --ckpt flux1-dev.safetensors --clip_l sd3/clip_ The memory-efficient training with block swap is based on 2kpr's implementation. Thanks to 2kpr! +__`--double_blocks_to_swap` and `--single_blocks_to_swap` are deprecated. These options is still available, but they will be removed in the future. Please use `--blocks_to_swap` instead. These options are equivalent to specifying `double_blocks_to_swap + single_blocks_to_swap // 2` in `--blocks_to_swap`.__ + Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GPUs, and 64GB main memory is recommended. ``` @@ -319,39 +325,62 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 --timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 ---fused_backward_pass --double_blocks_to_swap 6 --cpu_offload_checkpointing --full_bf16 +--fused_backward_pass --blocks_to_swap 8 --full_bf16 ``` (The command is multi-line for readability. Please combine it into one line.) -Options are almost the same as LoRA training. The difference is `--full_bf16`, `--blockwise_fused_optimizers`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available. +Options are almost the same as LoRA training. The difference is `--full_bf16`, `--fused_backward_pass` and `--blocks_to_swap`. `--cpu_offload_checkpointing` is also available. `--full_bf16` enables the training with bf16 (weights and gradients). `--fused_backward_pass` enables the fusing of the optimizer step into the backward pass for each parameter. This reduces the memory usage during training. Only Adafactor optimizer is supported for now. Stochastic rounding is also enabled when `--fused_backward_pass` and `--full_bf16` are specified. -`--blockwise_fused_optimizers` enables the fusing of the optimizer step into the backward pass for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--blockwise_fused_optimizers` cannot be used with `--fused_backward_pass`. Stochastic rounding is not supported for now. +`--blockwise_fused_optimizers` enables the fusing of the optimizer step into the backward pass for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency and stochastic rounding. `--blockwise_fused_optimizers` cannot be used with `--fused_backward_pass`. Stochastic rounding is not supported for now. -`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. `--double_blocks_to_swap` can be specified with `--single_blocks_to_swap`. The recommended maximum number of blocks to swap is 9 for double blocks and 18 for single blocks. Please see the next chapter for details. +`--blocks_to_swap` is the number of blocks to swap. The default is None (no swap). These options must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. The recommended maximum value is 36. -`--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage. +`--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage. All these options are experimental and may change in the future. The increasing the number of blocks to swap may reduce the memory usage, but the training speed will be slower. `--cpu_offload_checkpointing` also slows down the training. -Swap 6 double blocks and use cpu offload checkpointing may be a good starting point. Please try different settings according to VRAM usage and training speed. +Swap 8 blocks without cpu offload checkpointing may be a good starting point for 24GB VRAM GPUs. Please try different settings according to VRAM usage and training speed. The learning rate and the number of epochs are not optimized yet. Please adjust them according to the training results. +#### How to use block swap + +There are two possible ways to use block swap. It is unknown which is better. + +1. Swap the minimum number of blocks that fit in VRAM with batch size 1 and shorten the training speed of one step. + + The above command example is for this usage. + +2. Swap many blocks to increase the batch size and shorten the training speed per data. + + For example, swapping 20 blocks seems to increase the batch size to about 6. In this case, the training speed per data will be relatively faster than 1. + +#### Training with <24GB VRAM GPUs + +Swap 28 blocks without cpu offload checkpointing may be working with 12GB VRAM GPUs. Please try different settings according to VRAM size of your GPU. + +T5XXL requires about 10GB of VRAM, so 10GB of VRAM will be minimum requirement for FLUX.1 fine-tuning. + #### Key Features for FLUX.1 fine-tuning -1. Technical details of double/single block swap: +1. Technical details of block swap: - Reduce memory usage by transferring double and single blocks of FLUX.1 from GPU to CPU when they are not needed. - During forward pass, the weights of the blocks that have finished calculation are transferred to CPU, and the weights of the blocks to be calculated are transferred to GPU. - The same is true for the backward pass, but the order is reversed. The gradients remain on the GPU. - Since the transfer between CPU and GPU takes time, the training will be slower. - - `--double_blocks_to_swap` and `--single_blocks_to_swap` specify the number of blocks to swap. For example, `--double_blocks_to_swap 6` swaps 6 blocks at each step of training, but the remaining 13 blocks are always on the GPU. - - About 640MB of memory can be saved per double block, and about 320MB of memory can be saved per single block. + - `--blocks_to_swap` specify the number of blocks to swap. + - About 640MB of memory can be saved per block. + - Since the memory usage of one double block and two single blocks is almost the same, the transfer of single blocks is done in units of two. For example, consider the case of `--blocks_to_swap 6`. + - Before the forward pass, all double blocks and 26 (=38-12) single blocks are on the GPU. The last 12 single blocks are on the CPU. + - In the forward pass, the 6 double blocks that have finished calculation (the first 6 blocks) are transferred to the CPU, and the 12 single blocks to be calculated (the last 12 blocks) are transferred to the GPU. + - The same is true for the backward pass, but in reverse order. The 12 single blocks that have finished calculation are transferred to the CPU, and the 6 double blocks to be calculated are transferred to the GPU. + - After the backward pass, the blocks are back to their original locations. 2. Sample Image Generation: - Sample image generation during training is now supported. diff --git a/flux_train.py b/flux_train.py index bc4e62793..bf34208f1 100644 --- a/flux_train.py +++ b/flux_train.py @@ -11,10 +11,12 @@ # - Per-block fused optimizer instances import argparse +from concurrent.futures import ThreadPoolExecutor import copy import math import os from multiprocessing import Value +import time from typing import List import toml @@ -265,14 +267,30 @@ def train(args): flux.requires_grad_(True) - is_swapping_blocks = args.double_blocks_to_swap or args.single_blocks_to_swap + # block swap + + # backward compatibility + if args.blocks_to_swap is None: + blocks_to_swap = args.double_blocks_to_swap or 0 + if args.single_blocks_to_swap is not None: + blocks_to_swap += args.single_blocks_to_swap // 2 + if blocks_to_swap > 0: + logger.warning( + "double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead." + " / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。" + ) + logger.info( + f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}." + ) + args.blocks_to_swap = blocks_to_swap + del blocks_to_swap + + is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 if is_swapping_blocks: # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. # This idea is based on 2kpr's great work. Thank you! - logger.info( - f"enable block swap: double_blocks_to_swap={args.double_blocks_to_swap}, single_blocks_to_swap={args.single_blocks_to_swap}" - ) - flux.enable_block_swap(args.double_blocks_to_swap, args.single_blocks_to_swap) + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + flux.enable_block_swap(args.blocks_to_swap) if not cache_latents: # load VAE here if not cached @@ -443,82 +461,120 @@ def train(args): # resumeする train_util.resume_from_local_or_hf_if_specified(accelerator, args) + # memory efficient block swapping + + def get_block_unit(dbl_blocks, sgl_blocks, index: int): + if index < len(dbl_blocks): + return (dbl_blocks[index],) + else: + index -= len(dbl_blocks) + index *= 2 + return (sgl_blocks[index], sgl_blocks[index + 1]) + + def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, dbl_blocks, sgl_blocks, device): + def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda, dvc): + # print(f"Backward: Move block {bidx_to_cpu} to CPU") + for block in blocks_to_cpu: + block = block.to("cpu", non_blocking=True) + torch.cuda.empty_cache() + + # print(f"Backward: Move block {bidx_to_cuda} to CUDA") + for block in blocks_to_cuda: + block = block.to(dvc, non_blocking=True) + + torch.cuda.synchronize() + # print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda}") + return bidx_to_cpu, bidx_to_cuda + + blocks_to_cpu = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cpu) + blocks_to_cuda = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cuda) + + futures[block_idx_to_cuda] = thread_pool.submit( + move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda, device + ) + + def wait_blocks_move(block_idx, futures): + if block_idx not in futures: + return + # print(f"Backward: Wait for block {block_idx}") + # start_time = time.perf_counter() + future = futures.pop(block_idx) + future.result() + # print(f"Backward: Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s") + # torch.cuda.synchronize() + # print(f"Backward: Synchronized: {time.perf_counter()-start_time:.2f}s") + if args.fused_backward_pass: # use fused optimizer for backward pass: other optimizers will be supported in the future import library.adafactor_fused library.adafactor_fused.patch_adafactor_fused(optimizer) - double_blocks_to_swap = args.double_blocks_to_swap - single_blocks_to_swap = args.single_blocks_to_swap + blocks_to_swap = args.blocks_to_swap num_double_blocks = 19 # len(flux.double_blocks) num_single_blocks = 38 # len(flux.single_blocks) - handled_double_block_indices = set() - handled_single_block_indices = set() + num_block_units = num_double_blocks + num_single_blocks // 2 + handled_unit_indices = set() + + n = 1 # only asyncronous purpose, no need to increase this number + # n = 2 + # n = max(1, os.cpu_count() // 2) + thread_pool = ThreadPoolExecutor(max_workers=n) + futures = {} for param_group, param_name_group in zip(optimizer.param_groups, param_names): for parameter, param_name in zip(param_group["params"], param_name_group): if parameter.requires_grad: grad_hook = None - if double_blocks_to_swap: - if param_name.startswith("double_blocks"): - block_idx = int(param_name.split(".")[1]) - if ( - block_idx not in handled_double_block_indices - and block_idx >= (num_double_blocks - double_blocks_to_swap) - 1 - and block_idx < num_double_blocks - 1 - ): - # swap next (already backpropagated) block - handled_double_block_indices.add(block_idx) - block_idx_cpu = block_idx + 1 - block_idx_cuda = double_blocks_to_swap - (num_double_blocks - block_idx_cpu) - - # create swap hook - def create_double_swap_grad_hook(bidx, bidx_cuda): - def __grad_hook(tensor: torch.Tensor): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None - - # swap blocks if necessary - flux.double_blocks[bidx].to("cpu") - flux.double_blocks[bidx_cuda].to(accelerator.device) - # print(f"Move double block {bidx} to cpu and {bidx_cuda} to device") - - return __grad_hook - - grad_hook = create_double_swap_grad_hook(block_idx_cpu, block_idx_cuda) - if single_blocks_to_swap: - if param_name.startswith("single_blocks"): + if blocks_to_swap: + is_double = param_name.startswith("double_blocks") + is_single = param_name.startswith("single_blocks") + if is_double or is_single: block_idx = int(param_name.split(".")[1]) - if ( - block_idx not in handled_single_block_indices - and block_idx >= (num_single_blocks - single_blocks_to_swap) - 1 - and block_idx < num_single_blocks - 1 - ): - handled_single_block_indices.add(block_idx) - block_idx_cpu = block_idx + 1 - block_idx_cuda = single_blocks_to_swap - (num_single_blocks - block_idx_cpu) - # print(param_name, block_idx_cpu, block_idx_cuda) - - # create swap hook - def create_single_swap_grad_hook(bidx, bidx_cuda): - def __grad_hook(tensor: torch.Tensor): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None - - # swap blocks if necessary - flux.single_blocks[bidx].to("cpu") - flux.single_blocks[bidx_cuda].to(accelerator.device) - # print(f"Move single block {bidx} to cpu and {bidx_cuda} to device") - - return __grad_hook - - grad_hook = create_single_swap_grad_hook(block_idx_cpu, block_idx_cuda) + unit_idx = block_idx if is_double else num_double_blocks + block_idx // 2 + if unit_idx not in handled_unit_indices: + # swap following (already backpropagated) block + handled_unit_indices.add(unit_idx) + + # if n blocks were already backpropagated + num_blocks_propagated = num_block_units - unit_idx - 1 + swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap + waiting = unit_idx > 0 and unit_idx <= blocks_to_swap + if swapping or waiting: + block_idx_to_cpu = num_block_units - num_blocks_propagated + block_idx_to_cuda = blocks_to_swap - num_blocks_propagated + block_idx_to_wait = unit_idx - 1 + + # create swap hook + def create_swap_grad_hook( + bidx_to_cpu, bidx_to_cuda, bidx_to_wait, uidx: int, swpng: bool, wtng: bool + ): + def __grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + # print(f"Backward: {uidx}, {swpng}, {wtng}") + if swpng: + submit_move_blocks( + futures, + thread_pool, + bidx_to_cpu, + bidx_to_cuda, + flux.double_blocks, + flux.single_blocks, + accelerator.device, + ) + if wtng: + wait_blocks_move(bidx_to_wait, futures) + + return __grad_hook + + grad_hook = create_swap_grad_hook( + block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, unit_idx, swapping, waiting + ) if grad_hook is None: @@ -547,10 +603,15 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): num_parameters_per_group = [0] * len(optimizers) parameter_optimizer_map = {} - double_blocks_to_swap = args.double_blocks_to_swap - single_blocks_to_swap = args.single_blocks_to_swap + blocks_to_swap = args.blocks_to_swap num_double_blocks = 19 # len(flux.double_blocks) num_single_blocks = 38 # len(flux.single_blocks) + num_block_units = num_double_blocks + num_single_blocks // 2 + + n = 1 # only asyncronous purpose, no need to increase this number + # n = max(1, os.cpu_count() // 2) + thread_pool = ThreadPoolExecutor(max_workers=n) + futures = {} for opt_idx, optimizer in enumerate(optimizers): for param_group in optimizer.param_groups: @@ -571,18 +632,30 @@ def optimizer_hook(parameter: torch.Tensor): optimizers[i].zero_grad(set_to_none=True) # swap blocks if necessary - if btype == "double" and double_blocks_to_swap: - if bidx >= num_double_blocks - double_blocks_to_swap: - bidx_cuda = double_blocks_to_swap - (num_double_blocks - bidx) - flux.double_blocks[bidx].to("cpu") - flux.double_blocks[bidx_cuda].to(accelerator.device) - # print(f"Move double block {bidx} to cpu and {bidx_cuda} to device") - elif btype == "single" and single_blocks_to_swap: - if bidx >= num_single_blocks - single_blocks_to_swap: - bidx_cuda = single_blocks_to_swap - (num_single_blocks - bidx) - flux.single_blocks[bidx].to("cpu") - flux.single_blocks[bidx_cuda].to(accelerator.device) - # print(f"Move single block {bidx} to cpu and {bidx_cuda} to device") + if blocks_to_swap and (btype == "double" or (btype == "single" and bidx % 2 == 0)): + unit_idx = bidx if btype == "double" else num_double_blocks + bidx // 2 + num_blocks_propagated = num_block_units - unit_idx + + swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap + waiting = unit_idx > 0 and unit_idx <= blocks_to_swap + + if swapping: + block_idx_to_cpu = num_block_units - num_blocks_propagated + block_idx_to_cuda = blocks_to_swap - num_blocks_propagated + # print(f"Backward: Swap blocks {block_idx_to_cpu} and {block_idx_to_cuda}") + submit_move_blocks( + futures, + thread_pool, + block_idx_to_cpu, + block_idx_to_cuda, + flux.double_blocks, + flux.single_blocks, + accelerator.device, + ) + + if waiting: + block_idx_to_wait = unit_idx - 1 + wait_blocks_move(block_idx_to_wait, futures) return optimizer_hook @@ -881,24 +954,26 @@ def setup_parser() -> argparse.ArgumentParser: help="skip latents validity check / latentsの正当性チェックをスキップする", ) parser.add_argument( - "--double_blocks_to_swap", + "--blocks_to_swap", type=int, default=None, help="[EXPERIMENTAL] " - "Sets the number of 'double_blocks' (~640MB) to swap during the forward and backward passes." + "Sets the number of blocks (~640MB) to swap during the forward and backward passes." "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." - " / 順伝播および逆伝播中にスワップする'変換ブロック'(約640MB)の数を設定します。" + " / 順伝播および逆伝播中にスワップするブロック(約640MB)の数を設定します。" "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", ) + parser.add_argument( + "--double_blocks_to_swap", + type=int, + default=None, + help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください", + ) parser.add_argument( "--single_blocks_to_swap", type=int, default=None, - help="[EXPERIMENTAL] " - "Sets the number of 'single_blocks' (~320MB) to swap during the forward and backward passes." - "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." - " / 順伝播および逆伝播中にスワップする'変換ブロック'(約320MB)の数を設定します。" - "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", + help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください", ) parser.add_argument( "--cpu_offload_checkpointing", diff --git a/library/flux_models.py b/library/flux_models.py index b5726c298..a35dbc106 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -2,9 +2,12 @@ # license: Apache-2.0 License +from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass import math -from typing import Optional +import os +import time +from typing import Dict, List, Optional from library.device_utils import init_ipex, clean_memory_on_device @@ -917,8 +920,10 @@ def __init__(self, params: FluxParams): self.gradient_checkpointing = False self.cpu_offload_checkpointing = False - self.double_blocks_to_swap = None - self.single_blocks_to_swap = None + self.blocks_to_swap = None + + self.thread_pool: Optional[ThreadPoolExecutor] = None + self.num_block_units = len(self.double_blocks) + len(self.single_blocks) // 2 @property def device(self): @@ -956,38 +961,52 @@ def disable_gradient_checkpointing(self): print("FLUX: Gradient checkpointing disabled.") - def enable_block_swap(self, double_blocks: Optional[int], single_blocks: Optional[int]): - self.double_blocks_to_swap = double_blocks - self.single_blocks_to_swap = single_blocks + def enable_block_swap(self, num_blocks: int): + self.blocks_to_swap = num_blocks + + n = 1 # async block swap. 1 is enough + # n = 2 + # n = max(1, os.cpu_count() // 2) + self.thread_pool = ThreadPoolExecutor(max_workers=n) def move_to_device_except_swap_blocks(self, device: torch.device): # assume model is on cpu - if self.double_blocks_to_swap: + if self.blocks_to_swap: save_double_blocks = self.double_blocks - self.double_blocks = None - if self.single_blocks_to_swap: save_single_blocks = self.single_blocks + self.double_blocks = None self.single_blocks = None self.to(device) - if self.double_blocks_to_swap: + if self.blocks_to_swap: self.double_blocks = save_double_blocks - if self.single_blocks_to_swap: self.single_blocks = save_single_blocks + def get_block_unit(self, index: int): + if index < len(self.double_blocks): + return (self.double_blocks[index],) + else: + index -= len(self.double_blocks) + index *= 2 + return self.single_blocks[index], self.single_blocks[index + 1] + + def get_unit_index(self, is_double: bool, index: int): + if is_double: + return index + else: + return len(self.double_blocks) + index // 2 + def prepare_block_swap_before_forward(self): - # move last n blocks to cpu: they are on cuda - if self.double_blocks_to_swap: - for i in range(len(self.double_blocks) - self.double_blocks_to_swap): - self.double_blocks[i].to(self.device) - for i in range(len(self.double_blocks) - self.double_blocks_to_swap, len(self.double_blocks)): - self.double_blocks[i].to("cpu") # , non_blocking=True) - if self.single_blocks_to_swap: - for i in range(len(self.single_blocks) - self.single_blocks_to_swap): - self.single_blocks[i].to(self.device) - for i in range(len(self.single_blocks) - self.single_blocks_to_swap, len(self.single_blocks)): - self.single_blocks[i].to("cpu") # , non_blocking=True) + # make: first n blocks are on cuda, and last n blocks are on cpu + if self.blocks_to_swap is None: + raise ValueError("Block swap is not enabled.") + for i in range(self.num_block_units - self.blocks_to_swap): + for b in self.get_block_unit(i): + b.to(self.device) + for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units): + for b in self.get_block_unit(i): + b.to("cpu") clean_memory_on_device(self.device) def forward( @@ -1017,69 +1036,73 @@ def forward( ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) - if not self.double_blocks_to_swap: + if not self.blocks_to_swap: for block in self.double_blocks: img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) else: - # make sure first n blocks are on cuda, and last n blocks are on cpu at beginning - for block_idx in range(self.double_blocks_to_swap): - block = self.double_blocks[len(self.double_blocks) - self.double_blocks_to_swap + block_idx] - if block.parameters().__next__().device.type != "cpu": - block.to("cpu") # , non_blocking=True) - # print(f"Moved double block {len(self.double_blocks) - self.double_blocks_to_swap + block_idx} to cpu.") - - block = self.double_blocks[block_idx] - if block.parameters().__next__().device.type == "cpu": - block.to(self.device) - # print(f"Moved double block {block_idx} to cuda.") - - to_cpu_block_index = 0 + futures = {} + + def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda): + def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda): + # print(f"Moving {bidx_to_cpu} to cpu.") + for block in blocks_to_cpu: + block.to("cpu", non_blocking=True) + torch.cuda.empty_cache() + + # print(f"Moving {bidx_to_cuda} to cuda.") + for block in blocks_to_cuda: + block.to(self.device, non_blocking=True) + + torch.cuda.synchronize() + # print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.") + return block_idx_to_cpu, block_idx_to_cuda + + blocks_to_cpu = self.get_block_unit(block_idx_to_cpu) + blocks_to_cuda = self.get_block_unit(block_idx_to_cuda) + # print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.") + return self.thread_pool.submit(move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda) + + def wait_for_blocks_move(block_idx, ftrs): + if block_idx not in ftrs: + return + # print(f"Waiting for move blocks: {block_idx}") + # start_time = time.perf_counter() + ftr = ftrs.pop(block_idx) + ftr.result() + # torch.cuda.synchronize() + # print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds") + for block_idx, block in enumerate(self.double_blocks): - # move last n blocks to cuda: they are on cpu, and move first n blocks to cpu: they are on cuda - moving = block_idx >= len(self.double_blocks) - self.double_blocks_to_swap - if moving: - block.to(self.device) # move to cuda - # print(f"Moved double block {block_idx} to cuda.") + # print(f"Double block {block_idx}") + unit_idx = self.get_unit_index(is_double=True, index=block_idx) + wait_for_blocks_move(unit_idx, futures) img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if moving: - self.double_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True) - # print(f"Moved double block {to_cpu_block_index} to cpu.") - to_cpu_block_index += 1 + if unit_idx < self.blocks_to_swap: + block_idx_to_cpu = unit_idx + block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx + future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda) + futures[block_idx_to_cuda] = future - img = torch.cat((txt, img), 1) + img = torch.cat((txt, img), 1) - if not self.single_blocks_to_swap: - for block in self.single_blocks: - img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - else: - # make sure first n blocks are on cuda, and last n blocks are on cpu at beginning - for block_idx in range(self.single_blocks_to_swap): - block = self.single_blocks[len(self.single_blocks) - self.single_blocks_to_swap + block_idx] - if block.parameters().__next__().device.type != "cpu": - block.to("cpu") # , non_blocking=True) - # print(f"Moved single block {len(self.single_blocks) - self.single_blocks_to_swap + block_idx} to cpu.") - - block = self.single_blocks[block_idx] - if block.parameters().__next__().device.type == "cpu": - block.to(self.device) - # print(f"Moved single block {block_idx} to cuda.") - - to_cpu_block_index = 0 for block_idx, block in enumerate(self.single_blocks): - # move last n blocks to cuda: they are on cpu, and move first n blocks to cpu: they are on cuda - moving = block_idx >= len(self.single_blocks) - self.single_blocks_to_swap - if moving: - block.to(self.device) # move to cuda - # print(f"Moved single block {block_idx} to cuda.") + # print(f"Single block {block_idx}") + unit_idx = self.get_unit_index(is_double=False, index=block_idx) + if block_idx % 2 == 0: + wait_for_blocks_move(unit_idx, futures) img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if moving: - self.single_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True) - # print(f"Moved single block {to_cpu_block_index} to cpu.") - to_cpu_block_index += 1 + if block_idx % 2 == 1 and unit_idx < self.blocks_to_swap: + block_idx_to_cpu = unit_idx + block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx + future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda) + futures[block_idx_to_cuda] = future img = img[:, txt.shape[1] :, ...] @@ -1088,6 +1111,7 @@ def forward( vec = vec.to(self.device) img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + return img From da94fd934eb4951d1cb132abc9d2a355e44d7abf Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 26 Sep 2024 08:27:48 +0900 Subject: [PATCH 153/348] fix typos --- flux_train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flux_train.py b/flux_train.py index bf34208f1..022467ea7 100644 --- a/flux_train.py +++ b/flux_train.py @@ -516,7 +516,7 @@ def wait_blocks_move(block_idx, futures): num_block_units = num_double_blocks + num_single_blocks // 2 handled_unit_indices = set() - n = 1 # only asyncronous purpose, no need to increase this number + n = 1 # only asynchronous purpose, no need to increase this number # n = 2 # n = max(1, os.cpu_count() // 2) thread_pool = ThreadPoolExecutor(max_workers=n) @@ -608,7 +608,7 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): num_single_blocks = 38 # len(flux.single_blocks) num_block_units = num_double_blocks + num_single_blocks // 2 - n = 1 # only asyncronous purpose, no need to increase this number + n = 1 # only asynchronous purpose, no need to increase this number # n = max(1, os.cpu_count() // 2) thread_pool = ThreadPoolExecutor(max_workers=n) futures = {} From 392e8dedd84e469b125e2935e3ecf02e6270a5b2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 26 Sep 2024 21:14:11 +0900 Subject: [PATCH 154/348] fix flip_aug, alpha_mask, random_crop issue in caching in caching strategy --- library/train_util.py | 41 ++++++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 319337a47..17dd447eb 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -993,9 +993,26 @@ def new_cache_latents(self, model: Any, is_main_process: bool): # sort by resolution image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1]) - # split by resolution - batches = [] - batch = [] + # split by resolution and some conditions + class Condition: + def __init__(self, reso, flip_aug, alpha_mask, random_crop): + self.reso = reso + self.flip_aug = flip_aug + self.alpha_mask = alpha_mask + self.random_crop = random_crop + + def __eq__(self, other): + return ( + self.reso == other.reso + and self.flip_aug == other.flip_aug + and self.alpha_mask == other.alpha_mask + and self.random_crop == other.random_crop + ) + + batches: List[Tuple[Condition, List[ImageInfo]]] = [] + batch: List[ImageInfo] = [] + current_condition = None + logger.info("checking cache validity...") for info in tqdm(image_infos): subset = self.image_to_subset[info.image_key] @@ -1016,20 +1033,23 @@ def new_cache_latents(self, model: Any, is_main_process: bool): if cache_available: # do not add to batch continue - # if last member of batch has different resolution, flush the batch - if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso: - batches.append(batch) + # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty + condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) + if len(batch) > 0 and current_condition != condition: + batches.append((current_condition, batch)) batch = [] batch.append(info) + current_condition = condition # if number of data in batch is enough, flush the batch if len(batch) >= caching_strategy.batch_size: - batches.append(batch) + batches.append((current_condition, batch)) batch = [] + current_condition = None if len(batch) > 0: - batches.append(batch) + batches.append((current_condition, batch)) # if cache to disk, don't cache latents in non-main process, set to info only if caching_strategy.cache_to_disk and not is_main_process: @@ -1041,9 +1061,8 @@ def new_cache_latents(self, model: Any, is_main_process: bool): # iterate batches: batch doesn't have image here. image will be loaded in cache_batch_latents and discarded logger.info("caching latents...") - for batch in tqdm(batches, smoothing=1, total=len(batches)): - # cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) - caching_strategy.cache_batch_latents(model, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + for condition, batch in tqdm(batches, smoothing=1, total=len(batches)): + caching_strategy.cache_batch_latents(model, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop) def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと From 9249d00311002c84b189c2f6792cbe7aa344a1d5 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Thu, 26 Sep 2024 22:19:56 +0900 Subject: [PATCH 155/348] experimental support for multi-gpus latents caching --- library/train_util.py | 27 ++++++++++++++++----------- train_network.py | 2 +- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 3768b6051..2ca662dcb 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -981,7 +981,7 @@ def is_text_encoder_output_cacheable(self): ] ) - def new_cache_latents(self, model: Any, is_main_process: bool): + def new_cache_latents(self, model: Any, accelerator: Accelerator): r""" a brand new method to cache latents. This method caches latents with caching strategy. normal cache_latents method is used by default, but this method is used when caching strategy is specified. @@ -1013,8 +1013,12 @@ def __eq__(self, other): batch: List[ImageInfo] = [] current_condition = None + # support multiple-gpus + num_processes = accelerator.num_processes + process_index = accelerator.process_index + logger.info("checking cache validity...") - for info in tqdm(image_infos): + for i, info in enumerate(tqdm(image_infos)): subset = self.image_to_subset[info.image_key] if info.latents_npz is not None: # fine tuning dataset @@ -1024,9 +1028,14 @@ def __eq__(self, other): if caching_strategy.cache_to_disk: # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size) - if not is_main_process: # prepare for multi-gpu, only store to info + + # if the modulo of num_processes is not equal to process_index, skip caching + # this makes each process cache different latents + if i % num_processes != process_index: continue + print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}") + cache_available = caching_strategy.is_disk_cached_latents_expected( info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask ) @@ -1051,10 +1060,6 @@ def __eq__(self, other): if len(batch) > 0: batches.append((current_condition, batch)) - # if cache to disk, don't cache latents in non-main process, set to info only - if caching_strategy.cache_to_disk and not is_main_process: - return - if len(batches) == 0: logger.info("no latents to cache") return @@ -2258,8 +2263,8 @@ def make_buckets(self): def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) - def new_cache_latents(self, model: Any, is_main_process: bool): - return self.dreambooth_dataset_delegate.new_cache_latents(model, is_main_process) + def new_cache_latents(self, model: Any, accelerator: Accelerator): + return self.dreambooth_dataset_delegate.new_cache_latents(model, accelerator) def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): return self.dreambooth_dataset_delegate.new_cache_text_encoder_outputs(models, is_main_process) @@ -2363,10 +2368,10 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc logger.info(f"[Dataset {i}]") dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix) - def new_cache_latents(self, model: Any, is_main_process: bool): + def new_cache_latents(self, model: Any, accelerator: Accelerator): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") - dataset.new_cache_latents(model, is_main_process) + dataset.new_cache_latents(model, accelerator) def cache_text_encoder_outputs( self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True diff --git a/train_network.py b/train_network.py index b24f89b1e..7eb7aa49c 100644 --- a/train_network.py +++ b/train_network.py @@ -384,7 +384,7 @@ def train(self, args): vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + train_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device) From 24b1fdb66485af70b3c79feaf8ff1a348b66668e Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Thu, 26 Sep 2024 22:22:06 +0900 Subject: [PATCH 156/348] remove debug print --- library/train_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 2ca662dcb..8d6164b1b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1031,10 +1031,10 @@ def __eq__(self, other): # if the modulo of num_processes is not equal to process_index, skip caching # this makes each process cache different latents - if i % num_processes != process_index: + if i % num_processes != process_index: continue - print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}") + # print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}") cache_available = caching_strategy.is_disk_cached_latents_expected( info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask From a9aa52658a0d9ba7910a1d1983b650bc9de7153e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 28 Sep 2024 17:12:56 +0900 Subject: [PATCH 157/348] fix sample generation is not working in FLUX1 fine tuning #1647 --- library/flux_models.py | 5 +++-- library/flux_train_utils.py | 4 +++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/library/flux_models.py b/library/flux_models.py index a35dbc106..0bc1c02b9 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -999,8 +999,9 @@ def get_unit_index(self, is_double: bool, index: int): def prepare_block_swap_before_forward(self): # make: first n blocks are on cuda, and last n blocks are on cpu - if self.blocks_to_swap is None: - raise ValueError("Block swap is not enabled.") + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + # raise ValueError("Block swap is not enabled.") + return for i in range(self.num_block_units - self.blocks_to_swap): for b in self.get_block_unit(i): b.to(self.device) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index f77d4b585..1d1eb9d24 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -313,6 +313,7 @@ def denoise( guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + model.prepare_block_swap_before_forward() pred = model( img=img, img_ids=img_ids, @@ -325,7 +326,8 @@ def denoise( ) img = img + (t_prev - t_curr) * pred - + + model.prepare_block_swap_before_forward() return img From 822fe578591e44ac949830e03a8841e222483052 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 28 Sep 2024 20:57:27 +0900 Subject: [PATCH 158/348] add workaround for 'Some tensors share memory' error #1614 --- networks/convert_flux_lora.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/networks/convert_flux_lora.py b/networks/convert_flux_lora.py index bd4c1cf78..fe6466ebc 100644 --- a/networks/convert_flux_lora.py +++ b/networks/convert_flux_lora.py @@ -412,6 +412,10 @@ def main(args): state_dict = convert_ai_toolkit_to_sd_scripts(state_dict) elif args.src == "sd-scripts" and args.dst == "ai-toolkit": state_dict = convert_sd_scripts_to_ai_toolkit(state_dict) + + # eliminate 'shared tensors' + for k in list(state_dict.keys()): + state_dict[k] = state_dict[k].detach().clone() else: raise NotImplementedError(f"Conversion from {args.src} to {args.dst} is not supported") From 1a0f5b0c389f4e9fab5edb06b36f203e8894d581 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 29 Sep 2024 00:35:29 +0900 Subject: [PATCH 159/348] re-fix sample generation is not working in FLUX1 split mode #1647 --- flux_train_network.py | 3 +++ library/flux_train_utils.py | 1 - 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/flux_train_network.py b/flux_train_network.py index a6e57eede..65b121e7c 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -300,6 +300,9 @@ def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.Fl self.flux_lower = flux_lower self.target_device = device + def prepare_block_swap_before_forward(self): + pass + def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None): self.flux_lower.to("cpu") clean_memory_on_device(self.target_device) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 1d1eb9d24..b3c9184f2 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -196,7 +196,6 @@ def sample_image_inference( tokens_and_masks = tokenize_strategy.tokenize(prompt) # strategy has apply_t5_attn_mask option encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) - print([x.shape if x is not None else None for x in encoded_text_encoder_conds]) # if text_encoder_conds is not cached, use encoded_text_encoder_conds if len(text_encoder_conds) == 0: From e0c3630203776dc568c32d67806a0a9d443f5721 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= <865105819@qq.com> Date: Sun, 29 Sep 2024 09:11:15 +0800 Subject: [PATCH 160/348] Support Sdxl Controlnet (#1648) * Create sdxl_train_controlnet.py * add fuse_background_pass * Update sdxl_train_controlnet.py * add fuse and fix error * update * Update sdxl_train_controlnet.py * Update sdxl_train_controlnet.py * Update sdxl_train_controlnet.py * update * Update sdxl_train_controlnet.py --- library/train_util.py | 2 +- sdxl_train_controlnet.py | 752 +++++++++++++++++++++++++++++++++++++++ train_controlnet.py | 33 +- 3 files changed, 779 insertions(+), 8 deletions(-) create mode 100644 sdxl_train_controlnet.py diff --git a/library/train_util.py b/library/train_util.py index e023f63a2..293fc05ad 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3581,7 +3581,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: # available backends: # https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5 # https://pytorch.org/docs/stable/torch.compiler.html - choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"], + choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt", "tensort", "ipex", "tvm"], help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)", ) parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う") diff --git a/sdxl_train_controlnet.py b/sdxl_train_controlnet.py new file mode 100644 index 000000000..00026d2cc --- /dev/null +++ b/sdxl_train_controlnet.py @@ -0,0 +1,752 @@ +import argparse +import math +import os +import random +from multiprocessing import Value +import toml + +from tqdm import tqdm + +import torch +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from torch.nn.parallel import DistributedDataParallel as DDP +from accelerate.utils import set_seed +from diffusers import DDPMScheduler, ControlNetModel +from diffusers.utils.torch_utils import is_compiled_module +from safetensors.torch import load_file +from library import ( + deepspeed_utils, + sai_model_spec, + sdxl_model_util, + sdxl_original_unet, + sdxl_train_util, +) + +import library.model_util as model_util +import library.train_util as train_util +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.huggingface_util as huggingface_util +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import ( + add_v_prediction_like_loss, + apply_snr_weight, + prepare_scheduler_for_custom_training, + scale_v_prediction_loss_like_noise_prediction, + apply_debiased_estimation, +) +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +# TODO 他のスクリプトと共通化する +def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): + logs = { + "loss/current": current_loss, + "loss/average": avr_loss, + "lr": lr_scheduler.get_last_lr()[0], + } + + if args.optimizer_type.lower().startswith("DAdapt".lower()): + logs["lr/d*lr"] = ( + lr_scheduler.optimizers[-1].param_groups[0]["d"] + * lr_scheduler.optimizers[-1].param_groups[0]["lr"] + ) + + return logs + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + sdxl_train_util.verify_sdxl_training_args(args) + setup_logging(args, reset=True) + + cache_latents = args.cache_latents + use_user_config = args.dataset_config is not None + + if args.seed is None: + args.seed = random.randint(0, 2**32) + set_seed(args.seed) + + tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + + # データセットを準備する + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) + if use_user_config: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "conditioning_data_dir"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + user_config = { + "datasets": [ + { + "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( + args.train_data_dir, + args.conditioning_data_dir, + args.caption_extension, + ) + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = (train_dataset_group if args.max_data_loader_n_workers == 0 else None) + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(32) + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + else: + logger.warning( + "WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません" + ) + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + is_main_process = accelerator.is_main_process + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = torch.float32 if args.no_half_vae else weight_dtype + + # モデルを読み込む + ( + load_stable_diffusion_format, + text_encoder1, + text_encoder2, + vae, + unet, + logit_scale, + ckpt_info, + ) = sdxl_train_util.load_target_model( + args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype + ) + + # convert U-Net + with torch.no_grad(): + du_unet_sd = sdxl_model_util.convert_sdxl_unet_state_dict_to_diffusers(unet.state_dict()) + unet.to("cpu") + clean_memory_on_device(accelerator.device) + del unet + unet = sdxl_model_util.UNet2DConditionModel(**sdxl_model_util.DIFFUSERS_SDXL_UNET_CONFIG) + unet.load_state_dict(du_unet_sd) + + controlnet = ControlNetModel.from_unet(unet) + + if args.controlnet_model_name_or_path: + filename = args.controlnet_model_name_or_path + if os.path.isfile(filename): + if os.path.splitext(filename)[1] == ".safetensors": + state_dict = load_file(filename) + else: + state_dict = torch.load(filename) + state_dict = model_util.convert_controlnet_state_dict_to_diffusers(state_dict) + controlnet.load_state_dict(state_dict) + elif os.path.isdir(filename): + controlnet = ControlNetModel.from_pretrained(filename) + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=vae_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents( + vae, + args.vae_batch_size, + args.cache_latents_to_disk, + accelerator.is_main_process, + ) + vae.to("cpu") + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # TextEncoderの出力をキャッシュする + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad + with torch.no_grad(): + train_dataset_group.cache_text_encoder_outputs( + (tokenizer1, tokenizer2), + (text_encoder1, text_encoder2), + accelerator.device, + None, + args.cache_text_encoder_outputs_to_disk, + accelerator.is_main_process, + ) + accelerator.wait_for_everyone() + + # モデルに xformers とか memory efficient attention を組み込む + # train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + if args.xformers: + unet.enable_xformers_memory_efficient_attention() + controlnet.enable_xformers_memory_efficient_attention() + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + controlnet.enable_gradient_checkpointing() + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + trainable_params = list(filter(lambda p: p.requires_grad, controlnet.parameters())) + logger.info(f"trainable params count: {len(trainable_params)}") + logger.info( + f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}" + ) + + _, _, optimizer = train_util.get_optimizer(args, trainable_params) + + # dataloaderを準備する + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader)/ accelerator.num_processes/ args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix( + args, optimizer, accelerator.num_processes + ) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + controlnet.to(weight_dtype) + unet.to(weight_dtype) + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + controlnet.to(weight_dtype) + unet.to(weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい + controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + controlnet, optimizer, train_dataloader, lr_scheduler + ) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + parameter.register_post_accumulate_grad_hook(__grad_hook) + + unet.requires_grad_(False) + text_encoder1.requires_grad_(False) + text_encoder2.requires_grad_(False) + unet.to(accelerator.device, dtype=weight_dtype) + + # transform DDP after prepare + controlnet = controlnet.module if isinstance(controlnet, DDP) else controlnet + + controlnet.train() + + # TextEncoderの出力をキャッシュするときにはCPUへ移動する + if args.cache_text_encoder_outputs: + # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 + text_encoder1.to("cpu", dtype=torch.float32) + text_encoder2.to("cpu", dtype=torch.float32) + clean_memory_on_device(accelerator.device) + else: + # make sure Text Encoders are on GPU + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=vae_dtype) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # TODO: find a way to handle total batch size when there are multiple datasets + accelerator.print("running training / 学習開始") + accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm( + range(args.max_train_steps), + smoothing=0, + disable=not accelerator.is_local_main_process, + desc="steps", + ) + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, + clip_sample=False, + ) + prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr( + noise_scheduler + ) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + ( + "controlnet_train" + if args.log_tracker_name is None + else args.log_tracker_name + ), + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + loss_recorder = train_util.LossRecorder() + del train_dataset_group + + # function for saving/removing + def save_model(ckpt_name, model, force_sync_upload=False): + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") + sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False) + sai_metadata["modelspec.architecture"] = ( + sai_model_spec.ARCH_SD_XL_V1_BASE + "/controlnet" + ) + state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict()) + + if save_dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + if os.path.splitext(ckpt_file)[1] == ".safetensors": + from safetensors.torch import save_file + + save_file(state_dict, ckpt_file, sai_metadata) + else: + torch.save(state_dict, ckpt_file) + + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) + + def remove_model(old_ckpt_name): + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + accelerator.print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + # For --sample_at_first + sdxl_train_util.sample_images( + accelerator, + args, + 0, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2], + unet, + controlnet=controlnet, + ) + + # training loop + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with accelerator.accumulate(controlnet): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = ( + batch["latents"] + .to(accelerator.device) + .to(dtype=weight_dtype) + ) + else: + # latentに変換 + latents = ( + vae.encode(batch["images"].to(dtype=vae_dtype)) + .latent_dist.sample() + .to(dtype=weight_dtype) + ) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print( + "NaN found in latents, replacing with zeros" + ) + latents = torch.nan_to_num(latents, 0, out=latents) + latents = latents * sdxl_model_util.VAE_SCALE_FACTOR + + if ( + "text_encoder_outputs1_list" not in batch + or batch["text_encoder_outputs1_list"] is None + ): + input_ids1 = batch["input_ids"] + input_ids2 = batch["input_ids2"] + with torch.no_grad(): + # Get the text embedding for conditioning + input_ids1 = input_ids1.to(accelerator.device) + input_ids2 = input_ids2.to(accelerator.device) + encoder_hidden_states1, encoder_hidden_states2, pool2 = ( + train_util.get_hidden_states_sdxl( + args.max_token_length, + input_ids1, + input_ids2, + tokenizer1, + tokenizer2, + text_encoder1, + text_encoder2, + None if not args.full_fp16 else weight_dtype, + ) + ) + else: + encoder_hidden_states1 = ( + batch["text_encoder_outputs1_list"] + .to(accelerator.device) + .to(weight_dtype) + ) + encoder_hidden_states2 = ( + batch["text_encoder_outputs2_list"] + .to(accelerator.device) + .to(weight_dtype) + ) + pool2 = ( + batch["text_encoder_pool2_list"] + .to(accelerator.device) + .to(weight_dtype) + ) + + # get size embeddings + orig_size = batch["original_sizes_hw"] + crop_size = batch["crop_top_lefts"] + target_size = batch["target_sizes_hw"] + # embs = sdxl_train_util.get_size_embeddings( + # orig_size, crop_size, target_size, accelerator.device + # ).to(weight_dtype) + + embs = torch.cat([orig_size, crop_size, target_size]).to(accelerator.device).to(weight_dtype) #B,6 + # concat embeddings + #vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + vector_embedding_dict = { + "text_embeds": pool2, + "time_ids": embs + } + text_embedding = torch.cat( + [encoder_hidden_states1, encoder_hidden_states2], dim=2 + ).to(weight_dtype) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps, huber_c = ( + train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) + ) + + controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) + + + with accelerator.autocast(): + down_block_res_samples, mid_block_res_sample = controlnet( + noisy_latents, + timesteps, + encoder_hidden_states=text_embedding, + added_cond_kwargs=vector_embedding_dict, + controlnet_cond=controlnet_image, + return_dict=False, + ) + + # Predict the noise residual + noise_pred = unet( + noisy_latents, + timesteps, + encoder_hidden_states=text_embedding, + added_cond_kwargs=vector_embedding_dict, + down_block_additional_residuals=[ + sample.to(dtype=weight_dtype) for sample in down_block_res_samples + ], + mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), + return_dict=False, + )[0] + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = train_util.conditional_loss( + noise_pred.float(),target.float(),reduction="none",loss_type=args.loss_type,huber_c=huber_c, + ) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss,timesteps,noise_scheduler,args.min_snr_gamma,args.v_parameterization) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if not args.fused_backward_pass: + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = controlnet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + sdxl_train_util.sample_images( + accelerator, + args, + None, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2], + unet, + controlnet=controlnet, + ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) + save_model(ckpt_name,unwrap_model(controlnet)) + + if args.save_state: + train_util.save_and_remove_state_stepwise(args, accelerator, global_step) + + remove_step_no = train_util.get_remove_step_no(args, global_step) + if remove_step_no is not None: + remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) + remove_model(remove_ckpt_name) + + current_loss = loss.detach().item() + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if args.logging_dir is not None: + logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + # 指定エポックごとにモデルを保存 + if args.save_every_n_epochs is not None: + saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs + if is_main_process and saving: + ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) + save_model(ckpt_name,unwrap_model(controlnet)) + + remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) + if remove_epoch_no is not None: + remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) + remove_model(remove_ckpt_name) + + if args.save_state: + train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) + + sdxl_train_util.sample_images( + accelerator, + args, + epoch + 1, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2], + unet, + controlnet=controlnet, + ) + + # end of epoch + + if is_main_process: + controlnet = unwrap_model(controlnet) + + accelerator.end_training() + + if is_main_process and (args.save_state or args.save_state_on_train_end): + train_util.save_state_on_train_end(args, accelerator) + + if is_main_process: + ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) + save_model( + ckpt_name, controlnet, force_sync_upload=True + ) + + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, False, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + sdxl_train_util.add_sdxl_training_arguments(parser) + + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="controlnet model name or path / controlnetのモデル名またはパス", + ) + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + ) + + return parser + + +if __name__ == "__main__": + # sdxl_original_unet.USE_REENTRANT = False + + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/train_controlnet.py b/train_controlnet.py index c2945b083..8c7882c8f 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -254,6 +254,7 @@ def __contains__(self, name): accelerator.wait_for_everyone() if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() controlnet.enable_gradient_checkpointing() # 学習に必要なクラスを準備する @@ -304,6 +305,20 @@ def __contains__(self, name): controlnet, optimizer, train_dataloader, lr_scheduler ) + if args.fused_backward_pass: + import library.adafactor_fused + library.adafactor_fused.patch_adafactor_fused(optimizer) + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + parameter.register_post_accumulate_grad_hook(__grad_hook) + unet.requires_grad_(False) text_encoder.requires_grad_(False) unet.to(accelerator.device) @@ -497,13 +512,17 @@ def remove_model(old_ckpt_name): loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = controlnet.parameters() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) + if not args.fused_backward_pass: + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = controlnet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: From 8919b31145d38a2a790fae6e8e1c34c205c6794e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 29 Sep 2024 23:07:34 +0900 Subject: [PATCH 161/348] use original ControlNet instead of Diffusers --- gen_img.py | 89 +++- library/sdxl_model_util.py | 2 +- library/sdxl_original_control_net.py | 272 ++++++++++++ library/sdxl_original_unet.py | 14 +- ...controlnet.py => sdxl_train_control_net.py | 390 ++++++++---------- 5 files changed, 528 insertions(+), 239 deletions(-) create mode 100644 library/sdxl_original_control_net.py rename sdxl_train_controlnet.py => sdxl_train_control_net.py (69%) diff --git a/gen_img.py b/gen_img.py index 59bcd5b09..70b3c81ff 100644 --- a/gen_img.py +++ b/gen_img.py @@ -43,8 +43,8 @@ ) from einops import rearrange from tqdm import tqdm -from torchvision import transforms from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor +from accelerate import init_empty_weights import PIL from PIL import Image from PIL.PngImagePlugin import PngInfo @@ -58,6 +58,7 @@ from tools.original_control_net import ControlNetInfo from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel from library.sdxl_original_unet import InferSdxlUNet2DConditionModel +from library.sdxl_original_control_net import SdxlControlNet from library.original_unet import FlashAttentionFunction from networks.control_net_lllite import ControlNetLLLite from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL @@ -352,8 +353,8 @@ def __init__( self.token_replacements_list.append({}) # ControlNet - self.control_nets: List[ControlNetInfo] = [] # only for SD 1.5 - self.control_net_lllites: List[ControlNetLLLite] = [] + self.control_nets: List[Union[ControlNetInfo, Tuple[SdxlControlNet, float]]] = [] + self.control_net_lllites: List[Tuple[ControlNetLLLite, float]] = [] self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない self.gradual_latent: GradualLatent = None @@ -542,7 +543,7 @@ def __call__( else: text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) - if self.control_net_lllites: + if self.control_net_lllites or (self.control_nets and self.is_sdxl): # ControlNetのhintにguide imageを流用する。ControlNetの場合はControlNet側で行う if isinstance(clip_guide_images, PIL.Image.Image): clip_guide_images = [clip_guide_images] @@ -731,7 +732,12 @@ def __call__( num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 if self.control_nets: - guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + if not self.is_sdxl: + guided_hints = original_control_net.get_guided_hints( + self.control_nets, num_latent_input, batch_size, clip_guide_images + ) + else: + clip_guide_images = clip_guide_images * 0.5 + 0.5 # [-1, 1] => [0, 1] each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets) if self.control_net_lllites: @@ -793,7 +799,7 @@ def __call__( latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # disable ControlNet-LLLite if ratio is set. ControlNet is disabled in ControlNetInfo + # disable ControlNet-LLLite or SDXL ControlNet if ratio is set. ControlNet is disabled in ControlNetInfo if self.control_net_lllites: for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_net_lllites, each_control_net_enabled)): if not enabled or ratio >= 1.0: @@ -802,9 +808,16 @@ def __call__( logger.info(f"ControlNetLLLite {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") control_net.set_cond_image(None) each_control_net_enabled[j] = False + if self.control_nets and self.is_sdxl: + for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_nets, each_control_net_enabled)): + if not enabled or ratio >= 1.0: + continue + if ratio < i / len(timesteps): + logger.info(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") + each_control_net_enabled[j] = False # predict the noise residual - if self.control_nets and self.control_net_enabled: + if self.control_nets and self.control_net_enabled and not self.is_sdxl: if regional_network: num_sub_and_neg_prompts = len(text_embeddings) // batch_size text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt @@ -823,6 +836,31 @@ def __call__( text_embeddings, text_emb_last, ).sample + elif self.control_nets: + input_resi_add_list = [] + mid_add_list = [] + for (control_net, _), enbld in zip(self.control_nets, each_control_net_enabled): + if not enbld: + continue + input_resi_add, mid_add = control_net( + latent_model_input, t, text_embeddings, vector_embeddings, clip_guide_images + ) + input_resi_add_list.append(input_resi_add) + mid_add_list.append(mid_add) + if len(input_resi_add_list) == 0: + noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) + else: + if len(input_resi_add_list) > 1: + # get mean of input_resi_add_list and mid_add_list + input_resi_add_mean = [] + for i in range(len(input_resi_add_list[0])): + input_resi_add_mean.append( + torch.mean(torch.stack([input_resi_add_list[j][i] for j in range(len(input_resi_add_list))], dim=0)) + ) + input_resi_add = input_resi_add_mean + mid_add = torch.mean(torch.stack(mid_add_list), dim=0) + + noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings, input_resi_add, mid_add) elif self.is_sdxl: noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) else: @@ -1827,16 +1865,37 @@ def __getattr__(self, item): upscaler.to(dtype).to(device) # ControlNetの処理 - control_nets: List[ControlNetInfo] = [] + control_nets: List[Union[ControlNetInfo, Tuple[SdxlControlNet, float]]] = [] if args.control_net_models: - for i, model in enumerate(args.control_net_models): - prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] - weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] - ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + if not is_sdxl: + for i, model in enumerate(args.control_net_models): + prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] + weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] + ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model) + prep = original_control_net.load_preprocess(prep_type) + control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + else: + for i, model_file in enumerate(args.control_net_models): + multiplier = ( + 1.0 + if not args.control_net_multipliers or len(args.control_net_multipliers) <= i + else args.control_net_multipliers[i] + ) + ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + logger.info(f"loading SDXL ControlNet: {model_file}") + from safetensors.torch import load_file + + state_dict = load_file(model_file) - ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model) - prep = original_control_net.load_preprocess(prep_type) - control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + logger.info(f"Initalizing SDXL ControlNet with multiplier: {multiplier}") + with init_empty_weights(): + control_net = SdxlControlNet(multiplier=multiplier) + control_net.load_state_dict(state_dict) + control_net.to(dtype).to(device) + control_nets.append((control_net, ratio)) control_net_lllites: List[Tuple[ControlNetLLLite, float]] = [] if args.control_net_lllite_models: diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index 4fad78a1c..0466c1fa5 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -8,7 +8,7 @@ from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel from library import model_util from library import sdxl_original_unet -from .utils import setup_logging +from library.utils import setup_logging setup_logging() import logging diff --git a/library/sdxl_original_control_net.py b/library/sdxl_original_control_net.py new file mode 100644 index 000000000..3af45f4db --- /dev/null +++ b/library/sdxl_original_control_net.py @@ -0,0 +1,272 @@ +# some parts are modified from Diffusers library (Apache License 2.0) + +import math +from types import SimpleNamespace +from typing import Any, Optional +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import functional as F +from einops import rearrange +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from library import sdxl_original_unet +from library.sdxl_model_util import convert_sdxl_unet_state_dict_to_diffusers, convert_diffusers_unet_state_dict_to_sdxl + + +class ControlNetConditioningEmbedding(nn.Module): + def __init__(self): + super().__init__() + + dims = [16, 32, 96, 256] + + self.conv_in = nn.Conv2d(3, dims[0], kernel_size=3, padding=1) + self.blocks = nn.ModuleList([]) + + for i in range(len(dims) - 1): + channel_in = dims[i] + channel_out = dims[i + 1] + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = nn.Conv2d(dims[-1], 320, kernel_size=3, padding=1) + nn.init.zeros_(self.conv_out.weight) # zero module weight + nn.init.zeros_(self.conv_out.bias) # zero module bias + + def forward(self, x): + x = self.conv_in(x) + x = F.silu(x) + for block in self.blocks: + x = block(x) + x = F.silu(x) + x = self.conv_out(x) + return x + + +class SdxlControlNet(sdxl_original_unet.SdxlUNet2DConditionModel): + def __init__(self, multiplier: Optional[float] = None, **kwargs): + super().__init__(**kwargs) + self.multiplier = multiplier + + # remove unet layers + self.output_blocks = nn.ModuleList([]) + del self.out + + self.controlnet_cond_embedding = ControlNetConditioningEmbedding() + + dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280] + self.controlnet_down_blocks = nn.ModuleList([]) + for dim in dims: + self.controlnet_down_blocks.append(nn.Conv2d(dim, dim, kernel_size=1)) + nn.init.zeros_(self.controlnet_down_blocks[-1].weight) # zero module weight + nn.init.zeros_(self.controlnet_down_blocks[-1].bias) # zero module bias + + self.controlnet_mid_block = nn.Conv2d(1280, 1280, kernel_size=1) + nn.init.zeros_(self.controlnet_mid_block.weight) # zero module weight + nn.init.zeros_(self.controlnet_mid_block.bias) # zero module bias + + def init_from_unet(self, unet: sdxl_original_unet.SdxlUNet2DConditionModel): + unet_sd = unet.state_dict() + unet_sd = {k: v for k, v in unet_sd.items() if not k.startswith("out")} + sd = super().state_dict() + sd.update(unet_sd) + info = super().load_state_dict(sd, strict=True, assign=True) + return info + + def load_state_dict(self, state_dict: dict, strict: bool = True, assign: bool = True) -> Any: + # convert state_dict to SAI format + unet_sd = {} + for k in list(state_dict.keys()): + if not k.startswith("controlnet_"): + unet_sd[k] = state_dict.pop(k) + unet_sd = convert_diffusers_unet_state_dict_to_sdxl(unet_sd) + state_dict.update(unet_sd) + super().load_state_dict(state_dict, strict=strict, assign=assign) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + # convert state_dict to Diffusers format + state_dict = super().state_dict(destination, prefix, keep_vars) + control_net_sd = {} + for k in list(state_dict.keys()): + if k.startswith("controlnet_"): + control_net_sd[k] = state_dict.pop(k) + state_dict = convert_sdxl_unet_state_dict_to_diffusers(state_dict) + state_dict.update(control_net_sd) + return state_dict + + def forward( + self, + x: torch.Tensor, + timesteps: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + cond_image: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + # broadcast timesteps to batch dimension + timesteps = timesteps.expand(x.shape[0]) + + t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) + t_emb = t_emb.to(x.dtype) + emb = self.time_embed(t_emb) + + assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}" + assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}" + emb = emb + self.label_emb(y) + + def call_module(module, h, emb, context): + x = h + for layer in module: + if isinstance(layer, sdxl_original_unet.ResnetBlock2D): + x = layer(x, emb) + elif isinstance(layer, sdxl_original_unet.Transformer2DModel): + x = layer(x, context) + else: + x = layer(x) + return x + + h = x + multiplier = self.multiplier if self.multiplier is not None else 1.0 + hs = [] + for i, module in enumerate(self.input_blocks): + h = call_module(module, h, emb, context) + if i == 0: + h = self.controlnet_cond_embedding(cond_image) + h + hs.append(self.controlnet_down_blocks[i](h) * multiplier) + + h = call_module(self.middle_block, h, emb, context) + h = self.controlnet_mid_block(h) * multiplier + + return hs, h + + +class SdxlControlledUNet(sdxl_original_unet.SdxlUNet2DConditionModel): + """ + This class is for training purpose only. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs): + # broadcast timesteps to batch dimension + timesteps = timesteps.expand(x.shape[0]) + + hs = [] + t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) + t_emb = t_emb.to(x.dtype) + emb = self.time_embed(t_emb) + + assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}" + assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}" + emb = emb + self.label_emb(y) + + def call_module(module, h, emb, context): + x = h + for layer in module: + if isinstance(layer, sdxl_original_unet.ResnetBlock2D): + x = layer(x, emb) + elif isinstance(layer, sdxl_original_unet.Transformer2DModel): + x = layer(x, context) + else: + x = layer(x) + return x + + h = x + for module in self.input_blocks: + h = call_module(module, h, emb, context) + hs.append(h) + + h = call_module(self.middle_block, h, emb, context) + h = h + mid_add + + for module in self.output_blocks: + resi = hs.pop() + input_resi_add.pop() + h = torch.cat([h, resi], dim=1) + h = call_module(module, h, emb, context) + + h = h.type(x.dtype) + h = call_module(self.out, h, emb, context) + + return h + + +if __name__ == "__main__": + import time + + logger.info("create unet") + unet = SdxlControlledUNet() + unet.to("cuda", torch.bfloat16) + unet.set_use_sdpa(True) + unet.set_gradient_checkpointing(True) + unet.train() + + logger.info("create control_net") + control_net = SdxlControlNet() + control_net.to("cuda") + control_net.set_use_sdpa(True) + control_net.set_gradient_checkpointing(True) + control_net.train() + + logger.info("Initialize control_net from unet") + control_net.init_from_unet(unet) + + unet.requires_grad_(False) + control_net.requires_grad_(True) + + # 使用メモリ量確認用の疑似学習ループ + logger.info("preparing optimizer") + + # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working + + import bitsandbytes + + optimizer = bitsandbytes.adam.Adam8bit(control_net.parameters(), lr=1e-3) # not working + # optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2 + # optimizer=bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2 + + # import transformers + # optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True) # working at 22.2GB with torch2 + + scaler = torch.cuda.amp.GradScaler(enabled=True) + + logger.info("start training") + steps = 10 + batch_size = 1 + + for step in range(steps): + logger.info(f"step {step}") + if step == 1: + time_start = time.perf_counter() + + x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024 + t = torch.randint(low=0, high=1000, size=(batch_size,), device="cuda") + txt = torch.randn(batch_size, 77, 2048).cuda() + vector = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda() + cond_img = torch.rand(batch_size, 3, 1024, 1024).cuda() + + with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16): + input_resi_add, mid_add = control_net(x, t, txt, vector, cond_img) + output = unet(x, t, txt, vector, input_resi_add, mid_add) + target = torch.randn_like(output) + loss = torch.nn.functional.mse_loss(output, target) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad(set_to_none=True) + + time_end = time.perf_counter() + logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps") + + logger.info("finish training") + sd = control_net.state_dict() + + from safetensors.torch import save_file + + save_file(sd, r"E:\Work\SD\Tmp\sdxl\ctrl\control_net.safetensors") diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py index 17c345a89..0aa07d0d6 100644 --- a/library/sdxl_original_unet.py +++ b/library/sdxl_original_unet.py @@ -30,7 +30,7 @@ from torch import nn from torch.nn import functional as F from einops import rearrange -from .utils import setup_logging +from library.utils import setup_logging setup_logging() import logging @@ -1156,9 +1156,9 @@ def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_ti self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000 self.ds_ratio = ds_ratio - def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs): r""" - current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink. + current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink and ControlNet. """ _self = self.delegate @@ -1209,6 +1209,8 @@ def call_module(module, h, emb, context): hs.append(h) h = call_module(_self.middle_block, h, emb, context) + if mid_add is not None: + h = h + mid_add for module in _self.output_blocks: # Deep Shrink @@ -1217,7 +1219,11 @@ def call_module(module, h, emb, context): # print("upsample", h.shape, hs[-1].shape) h = resize_like(h, hs[-1]) - h = torch.cat([h, hs.pop()], dim=1) + resi = hs.pop() + if input_resi_add is not None: + resi = resi + input_resi_add.pop() + + h = torch.cat([h, resi], dim=1) h = call_module(module, h, emb, context) # Deep Shrink: in case of depth 0 diff --git a/sdxl_train_controlnet.py b/sdxl_train_control_net.py similarity index 69% rename from sdxl_train_controlnet.py rename to sdxl_train_control_net.py index 00026d2cc..74dcff2af 100644 --- a/sdxl_train_controlnet.py +++ b/sdxl_train_control_net.py @@ -14,6 +14,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed +from accelerate import init_empty_weights from diffusers import DDPMScheduler, ControlNetModel from diffusers.utils.torch_utils import is_compiled_module from safetensors.torch import load_file @@ -23,6 +24,9 @@ sdxl_model_util, sdxl_original_unet, sdxl_train_util, + strategy_base, + strategy_sd, + strategy_sdxl, ) import library.model_util as model_util @@ -41,6 +45,7 @@ scale_v_prediction_loss_like_noise_prediction, apply_debiased_estimation, ) +from library.sdxl_original_control_net import SdxlControlNet, SdxlControlledUNet from library.utils import setup_logging, add_logging_arguments setup_logging() @@ -58,10 +63,7 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche } if args.optimizer_type.lower().startswith("DAdapt".lower()): - logs["lr/d*lr"] = ( - lr_scheduler.optimizers[-1].param_groups[0]["d"] - * lr_scheduler.optimizers[-1].param_groups[0]["lr"] - ) + logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] return logs @@ -79,7 +81,14 @@ def train(args): args.seed = random.randint(0, 2**32) set_seed(args.seed) - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, False + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) @@ -106,17 +115,18 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) current_step = Value("i", 0) - ds_for_collator = (train_dataset_group if args.max_data_loader_n_workers == 0 else None) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) train_dataset_group.verify_bucket_reso_steps(32) if args.debug_dataset: + train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: @@ -162,86 +172,99 @@ def unwrap_model(model): unet, logit_scale, ckpt_info, - ) = sdxl_train_util.load_target_model( - args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype - ) + ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype) + + unet.to(accelerator.device) # reduce main memory usage + + # convert U-Net to Controlled U-Net + logger.info("convert U-Net to Controlled U-Net") + unet_sd = unet.state_dict() + with init_empty_weights(): + unet = SdxlControlledUNet() + unet.load_state_dict(unet_sd, strict=True, assign=True) + del unet_sd + + # make control net + logger.info("make ControlNet") + if args.controlnet_model_path: + with init_empty_weights(): + control_net = SdxlControlNet() + + logger.info(f"load ControlNet from {args.controlnet_model_path}") + filename = args.controlnet_model_path + if os.path.splitext(filename)[1] == ".safetensors": + state_dict = load_file(filename) + else: + state_dict = torch.load(filename) + info = control_net.load_state_dict(state_dict, strict=True, assign=True) + logger.info(f"ControlNet loaded from {filename}: {info}") + else: + control_net = SdxlControlNet() - # convert U-Net - with torch.no_grad(): - du_unet_sd = sdxl_model_util.convert_sdxl_unet_state_dict_to_diffusers(unet.state_dict()) - unet.to("cpu") - clean_memory_on_device(accelerator.device) - del unet - unet = sdxl_model_util.UNet2DConditionModel(**sdxl_model_util.DIFFUSERS_SDXL_UNET_CONFIG) - unet.load_state_dict(du_unet_sd) - - controlnet = ControlNetModel.from_unet(unet) - - if args.controlnet_model_name_or_path: - filename = args.controlnet_model_name_or_path - if os.path.isfile(filename): - if os.path.splitext(filename)[1] == ".safetensors": - state_dict = load_file(filename) - else: - state_dict = torch.load(filename) - state_dict = model_util.convert_controlnet_state_dict_to_diffusers(state_dict) - controlnet.load_state_dict(state_dict) - elif os.path.isdir(filename): - controlnet = ControlNetModel.from_pretrained(filename) + logger.info("initialize ControlNet from U-Net") + info = control_net.init_from_unet(unet) + logger.info(f"ControlNet initialized from U-Net: {info}") # 学習を準備する if cache_latents: vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents( - vae, - args.vae_batch_size, - args.cache_latents_to_disk, - accelerator.is_main_process, - ) + + train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + vae.to("cpu") clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() + text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy() + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + # TextEncoderの出力をキャッシュする if args.cache_text_encoder_outputs: # Text Encodes are eval and no grad - with torch.no_grad(): - train_dataset_group.cache_text_encoder_outputs( - (tokenizer1, tokenizer2), - (text_encoder1, text_encoder2), - accelerator.device, - None, - args.cache_text_encoder_outputs_to_disk, - accelerator.is_main_process, - ) + text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, False + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy) + + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) + accelerator.wait_for_everyone() # モデルに xformers とか memory efficient attention を組み込む # train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) if args.xformers: - unet.enable_xformers_memory_efficient_attention() - controlnet.enable_xformers_memory_efficient_attention() + unet.set_use_memory_efficient_attention(True, False) + control_net.set_use_memory_efficient_attention(True, False) + elif args.sdpa: + unet.set_use_sdpa(True) + control_net.set_use_sdpa(True) if args.gradient_checkpointing: unet.enable_gradient_checkpointing() - controlnet.enable_gradient_checkpointing() + control_net.enable_gradient_checkpointing() # 学習に必要なクラスを準備する accelerator.print("prepare optimizer, data loader etc.") - trainable_params = list(filter(lambda p: p.requires_grad, controlnet.parameters())) + trainable_params = list(control_net.parameters()) + # for p in trainable_params: + # p.requires_grad = True logger.info(f"trainable params count: {len(trainable_params)}") - logger.info( - f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}" - ) + logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") _, _, optimizer = train_util.get_optimizer(args, trainable_params) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers @@ -257,7 +280,7 @@ def unwrap_model(model): # 学習ステップ数を計算する if args.max_train_epochs is not None: args.max_train_steps = args.max_train_epochs * math.ceil( - len(train_dataloader)/ accelerator.num_processes/ args.gradient_accumulation_steps + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) accelerator.print( f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" @@ -267,9 +290,7 @@ def unwrap_model(model): train_dataset_group.set_max_train_steps(args.max_train_steps) # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix( - args, optimizer, accelerator.num_processes - ) + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする if args.full_fp16: @@ -277,19 +298,17 @@ def unwrap_model(model): args.mixed_precision == "fp16" ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" accelerator.print("enable full fp16 training.") - controlnet.to(weight_dtype) - unet.to(weight_dtype) + control_net.to(weight_dtype) elif args.full_bf16: assert ( args.mixed_precision == "bf16" ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" accelerator.print("enable full bf16 training.") - controlnet.to(weight_dtype) - unet.to(weight_dtype) + control_net.to(weight_dtype) # acceleratorがなんかよろしくやってくれるらしい - controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - controlnet, optimizer, train_dataloader, lr_scheduler + control_net, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + control_net, optimizer, train_dataloader, lr_scheduler ) if args.fused_backward_pass: @@ -314,10 +333,8 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): text_encoder2.requires_grad_(False) unet.to(accelerator.device, dtype=weight_dtype) - # transform DDP after prepare - controlnet = controlnet.module if isinstance(controlnet, DDP) else controlnet - - controlnet.train() + unet.eval() + control_net.train() # TextEncoderの出力をキャッシュするときにはCPUへ移動する if args.cache_text_encoder_outputs: @@ -362,26 +379,15 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - progress_bar = tqdm( - range(args.max_train_steps), - smoothing=0, - disable=not accelerator.is_local_main_process, - desc="steps", - ) + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 noise_scheduler = DDPMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - num_train_timesteps=1000, - clip_sample=False, + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False ) prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) if args.zero_terminal_snr: - custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr( - noise_scheduler - ) + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) if accelerator.is_main_process: init_kwargs = {} @@ -390,11 +396,7 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - ( - "controlnet_train" - if args.log_tracker_name is None - else args.log_tracker_name - ), + ("sdxl_control_net_train" if args.log_tracker_name is None else args.log_tracker_name), config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs, ) @@ -409,10 +411,8 @@ def save_model(ckpt_name, model, force_sync_upload=False): accelerator.print(f"\nsaving checkpoint: {ckpt_file}") sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False) - sai_metadata["modelspec.architecture"] = ( - sai_model_spec.ARCH_SD_XL_V1_BASE + "/controlnet" - ) - state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict()) + sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/controlnet" + state_dict = model.state_dict() if save_dtype is not None: for key in list(state_dict.keys()): @@ -436,19 +436,19 @@ def remove_model(old_ckpt_name): accelerator.print(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) - # For --sample_at_first - sdxl_train_util.sample_images( - accelerator, - args, - 0, - global_step, - accelerator.device, - vae, - [tokenizer1, tokenizer2], - [text_encoder1, text_encoder2], - unet, - controlnet=controlnet, - ) + # # For --sample_at_first + # sdxl_train_util.sample_images( + # accelerator, + # args, + # 0, + # global_step, + # accelerator.device, + # vae, + # [tokenizer1, tokenizer2], + # [text_encoder1, text_encoder2], + # unet, + # controlnet=control_net, + # ) # training loop for epoch in range(num_train_epochs): @@ -457,121 +457,63 @@ def remove_model(old_ckpt_name): for step, batch in enumerate(train_dataloader): current_step.value = global_step - with accelerator.accumulate(controlnet): + with accelerator.accumulate(control_net): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: - latents = ( - batch["latents"] - .to(accelerator.device) - .to(dtype=weight_dtype) - ) + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) else: # latentに変換 - latents = ( - vae.encode(batch["images"].to(dtype=vae_dtype)) - .latent_dist.sample() - .to(dtype=weight_dtype) - ) + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): - accelerator.print( - "NaN found in latents, replacing with zeros" - ) + accelerator.print("NaN found in latents, replacing with zeros") latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR - if ( - "text_encoder_outputs1_list" not in batch - or batch["text_encoder_outputs1_list"] is None - ): - input_ids1 = batch["input_ids"] - input_ids2 = batch["input_ids2"] + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + # Text Encoder outputs are cached + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list + encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype) + pool2 = pool2.to(accelerator.device, dtype=weight_dtype) + else: + input_ids1, input_ids2 = batch["input_ids_list"] with torch.no_grad(): - # Get the text embedding for conditioning input_ids1 = input_ids1.to(accelerator.device) input_ids2 = input_ids2.to(accelerator.device) - encoder_hidden_states1, encoder_hidden_states2, pool2 = ( - train_util.get_hidden_states_sdxl( - args.max_token_length, - input_ids1, - input_ids2, - tokenizer1, - tokenizer2, - text_encoder1, - text_encoder2, - None if not args.full_fp16 else weight_dtype, - ) + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2] ) - else: - encoder_hidden_states1 = ( - batch["text_encoder_outputs1_list"] - .to(accelerator.device) - .to(weight_dtype) - ) - encoder_hidden_states2 = ( - batch["text_encoder_outputs2_list"] - .to(accelerator.device) - .to(weight_dtype) - ) - pool2 = ( - batch["text_encoder_pool2_list"] - .to(accelerator.device) - .to(weight_dtype) - ) + if args.full_fp16: + encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype) + pool2 = pool2.to(weight_dtype) # get size embeddings orig_size = batch["original_sizes_hw"] crop_size = batch["crop_top_lefts"] target_size = batch["target_sizes_hw"] - # embs = sdxl_train_util.get_size_embeddings( - # orig_size, crop_size, target_size, accelerator.device - # ).to(weight_dtype) + embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) - embs = torch.cat([orig_size, crop_size, target_size]).to(accelerator.device).to(weight_dtype) #B,6 # concat embeddings - #vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) - vector_embedding_dict = { - "text_embeds": pool2, - "time_ids": embs - } - text_embedding = torch.cat( - [encoder_hidden_states1, encoder_hidden_states2], dim=2 - ).to(weight_dtype) + vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = ( - train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents ) controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) - with accelerator.autocast(): - down_block_res_samples, mid_block_res_sample = controlnet( - noisy_latents, - timesteps, - encoder_hidden_states=text_embedding, - added_cond_kwargs=vector_embedding_dict, - controlnet_cond=controlnet_image, - return_dict=False, + input_resi_add, mid_add = control_net( + noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image ) - - # Predict the noise residual - noise_pred = unet( - noisy_latents, - timesteps, - encoder_hidden_states=text_embedding, - added_cond_kwargs=vector_embedding_dict, - down_block_additional_residuals=[ - sample.to(dtype=weight_dtype) for sample in down_block_res_samples - ], - mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), - return_dict=False, - )[0] + noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding, input_resi_add, mid_add) if args.v_parameterization: # v-parameterization training @@ -580,7 +522,7 @@ def remove_model(old_ckpt_name): target = noise loss = train_util.conditional_loss( - noise_pred.float(),target.float(),reduction="none",loss_type=args.loss_type,huber_c=huber_c, + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c ) loss = loss.mean([1, 2, 3]) @@ -588,7 +530,7 @@ def remove_model(old_ckpt_name): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss,timesteps,noise_scheduler,args.min_snr_gamma,args.v_parameterization) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: @@ -601,7 +543,7 @@ def remove_model(old_ckpt_name): accelerator.backward(loss) if not args.fused_backward_pass: if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = controlnet.parameters() + params_to_clip = control_net.parameters() accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() @@ -616,25 +558,25 @@ def remove_model(old_ckpt_name): progress_bar.update(1) global_step += 1 - sdxl_train_util.sample_images( - accelerator, - args, - None, - global_step, - accelerator.device, - vae, - [tokenizer1, tokenizer2], - [text_encoder1, text_encoder2], - unet, - controlnet=controlnet, - ) + # sdxl_train_util.sample_images( + # accelerator, + # args, + # None, + # global_step, + # accelerator.device, + # vae, + # [tokenizer1, tokenizer2], + # [text_encoder1, text_encoder2], + # unet, + # controlnet=control_net, + # ) # 指定ステップごとにモデルを保存 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: accelerator.wait_for_everyone() if accelerator.is_main_process: ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) - save_model(ckpt_name,unwrap_model(controlnet)) + save_model(ckpt_name, unwrap_model(control_net)) if args.save_state: train_util.save_and_remove_state_stepwise(args, accelerator, global_step) @@ -650,14 +592,14 @@ def remove_model(old_ckpt_name): logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) @@ -668,7 +610,7 @@ def remove_model(old_ckpt_name): saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs if is_main_process and saving: ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) - save_model(ckpt_name,unwrap_model(controlnet)) + save_model(ckpt_name, unwrap_model(control_net)) remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) if remove_epoch_no is not None: @@ -688,13 +630,13 @@ def remove_model(old_ckpt_name): [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet, - controlnet=controlnet, + controlnet=control_net, ) # end of epoch if is_main_process: - controlnet = unwrap_model(controlnet) + control_net = unwrap_model(control_net) accelerator.end_training() @@ -703,9 +645,7 @@ def remove_model(old_ckpt_name): if is_main_process: ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) - save_model( - ckpt_name, controlnet, force_sync_upload=True - ) + save_model(ckpt_name, control_net, force_sync_upload=True) logger.info("model saved.") @@ -717,26 +657,38 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) - train_util.add_masked_loss_arguments(parser) + # train_util.add_masked_loss_arguments(parser) deepspeed_utils.add_deepspeed_arguments(parser) - train_util.add_sd_saving_arguments(parser) + # train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) sdxl_train_util.add_sdxl_training_arguments(parser) parser.add_argument( - "--controlnet_model_name_or_path", + "--controlnet_model_path", type=str, default=None, help="controlnet model name or path / controlnetのモデル名またはパス", ) + parser.add_argument( + "--conditioning_data_dir", + type=str, + default=None, + help="conditioning data directory / 条件付けデータのディレクトリ", + ) + parser.add_argument( + "--save_model_as", + type=str, + default="safetensors", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", + ) parser.add_argument( "--no_half_vae", action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) - return parser From 0243c65877a7700ffab1e782690f26080a0deadc Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 29 Sep 2024 23:09:56 +0900 Subject: [PATCH 162/348] fix typo --- gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gen_img.py b/gen_img.py index 70b3c81ff..421d5c0b9 100644 --- a/gen_img.py +++ b/gen_img.py @@ -1890,7 +1890,7 @@ def __getattr__(self, item): state_dict = load_file(model_file) - logger.info(f"Initalizing SDXL ControlNet with multiplier: {multiplier}") + logger.info(f"Initializing SDXL ControlNet with multiplier: {multiplier}") with init_empty_weights(): control_net = SdxlControlNet(multiplier=multiplier) control_net.load_state_dict(state_dict) From 793999d116638548fc16579b712f44456ee3034e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 30 Sep 2024 23:39:32 +0900 Subject: [PATCH 163/348] sample generation in SDXL ControlNet training --- library/sdxl_lpw_stable_diffusion.py | 168 +++++++---------------- library/strategy_base.py | 192 ++++++++++++++++++++++++++- library/strategy_sdxl.py | 39 +++++- library/train_util.py | 35 +++-- sdxl_train_control_net.py | 55 ++++---- 5 files changed, 323 insertions(+), 166 deletions(-) diff --git a/library/sdxl_lpw_stable_diffusion.py b/library/sdxl_lpw_stable_diffusion.py index 03b182566..9196eb0f2 100644 --- a/library/sdxl_lpw_stable_diffusion.py +++ b/library/sdxl_lpw_stable_diffusion.py @@ -13,12 +13,20 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from diffusers import SchedulerMixin, StableDiffusionPipeline -from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker +from diffusers.models import AutoencoderKL +from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from diffusers.utils import logging from PIL import Image -from library import sdxl_model_util, sdxl_train_util, train_util +from library import ( + sdxl_model_util, + sdxl_train_util, + strategy_base, + strategy_sdxl, + train_util, + sdxl_original_unet, + sdxl_original_control_net, +) try: @@ -537,7 +545,7 @@ def __init__( vae: AutoencoderKL, text_encoder: List[CLIPTextModel], tokenizer: List[CLIPTokenizer], - unet: UNet2DConditionModel, + unet: Union[sdxl_original_unet.SdxlUNet2DConditionModel, sdxl_original_control_net.SdxlControlledUNet], scheduler: SchedulerMixin, # clip_skip: int, safety_checker: StableDiffusionSafetyChecker, @@ -594,74 +602,6 @@ def _execution_device(self): return torch.device(module._hf_hook.execution_device) return self.device - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - max_embeddings_multiples, - is_sdxl_text_encoder2, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `list(int)`): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - if negative_prompt is None: - negative_prompt = [""] * batch_size - elif isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] * batch_size - if batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - - text_embeddings, text_pool, uncond_embeddings, uncond_pool = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - is_sdxl_text_encoder2=is_sdxl_text_encoder2, - ) - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) # ?? - text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) - if text_pool is not None: - text_pool = text_pool.repeat(1, num_images_per_prompt) - text_pool = text_pool.view(bs_embed * num_images_per_prompt, -1) - - if do_classifier_free_guidance: - bs_embed, seq_len, _ = uncond_embeddings.shape - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) - if uncond_pool is not None: - uncond_pool = uncond_pool.repeat(1, num_images_per_prompt) - uncond_pool = uncond_pool.view(bs_embed * num_images_per_prompt, -1) - - return text_embeddings, text_pool, uncond_embeddings, uncond_pool - - return text_embeddings, text_pool, None, None - def check_inputs(self, prompt, height, width, strength, callback_steps): if not isinstance(prompt, str) and not isinstance(prompt, list): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") @@ -792,7 +732,7 @@ def __call__( max_embeddings_multiples: Optional[int] = 3, output_type: Optional[str] = "pil", return_dict: bool = True, - controlnet=None, + controlnet: sdxl_original_control_net.SdxlControlNet = None, controlnet_image=None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, is_cancelled_callback: Optional[Callable[[], bool]] = None, @@ -896,32 +836,24 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt - # 実装を簡単にするためにtokenzer/text encoderを切り替えて二回呼び出す - # To simplify the implementation, switch the tokenzer/text encoder and call it twice - text_embeddings_list = [] - text_pool = None - uncond_embeddings_list = [] - uncond_pool = None - for i in range(len(self.tokenizers)): - self.tokenizer = self.tokenizers[i] - self.text_encoder = self.text_encoders[i] - - text_embeddings, tp1, uncond_embeddings, up1 = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - max_embeddings_multiples, - is_sdxl_text_encoder2=i == 1, - ) - text_embeddings_list.append(text_embeddings) - uncond_embeddings_list.append(uncond_embeddings) + tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + encoding_strategy: strategy_sdxl.SdxlTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() - if tp1 is not None: - text_pool = tp1 - if up1 is not None: - uncond_pool = up1 + text_input_ids, text_weights = tokenize_strategy.tokenize_with_weights(prompt) + hidden_states_1, hidden_states_2, text_pool = encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, self.text_encoders, text_input_ids, text_weights + ) + text_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1) + + if do_classifier_free_guidance: + input_ids, weights = tokenize_strategy.tokenize_with_weights(negative_prompt or "") + hidden_states_1, hidden_states_2, uncond_pool = encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, self.text_encoders, input_ids, weights + ) + uncond_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1) + else: + uncond_embeddings = None + uncond_pool = None unet_dtype = self.unet.dtype dtype = unet_dtype @@ -970,23 +902,23 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # create size embs and concat embeddings for SDXL - orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(dtype) + orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(device, dtype) crop_size = torch.zeros_like(orig_size) target_size = orig_size - embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(dtype) + embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(device, dtype) # make conditionings + text_pool = text_pool.to(device, dtype) if do_classifier_free_guidance: - text_embeddings = torch.cat(text_embeddings_list, dim=2) - uncond_embeddings = torch.cat(uncond_embeddings_list, dim=2) - text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(dtype) + text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(device, dtype) - cond_vector = torch.cat([text_pool, embs], dim=1) - uncond_vector = torch.cat([uncond_pool, embs], dim=1) - vector_embedding = torch.cat([uncond_vector, cond_vector]).to(dtype) + uncond_pool = uncond_pool.to(device, dtype) + cond_vector = torch.cat([text_pool, embs], dim=1).to(dtype) + uncond_vector = torch.cat([uncond_pool, embs], dim=1).to(dtype) + vector_embedding = torch.cat([uncond_vector, cond_vector]) else: - text_embedding = torch.cat(text_embeddings_list, dim=2).to(dtype) - vector_embedding = torch.cat([text_pool, embs], dim=1).to(dtype) + text_embedding = text_embeddings.to(device, dtype) + vector_embedding = torch.cat([text_pool, embs], dim=1) # 8. Denoising loop for i, t in enumerate(self.progress_bar(timesteps)): @@ -994,22 +926,14 @@ def __call__( latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - unet_additional_args = {} - if controlnet is not None: - down_block_res_samples, mid_block_res_sample = controlnet( - latent_model_input, - t, - encoder_hidden_states=text_embeddings, - controlnet_cond=controlnet_image, - conditioning_scale=1.0, - guess_mode=False, - return_dict=False, - ) - unet_additional_args["down_block_additional_residuals"] = down_block_res_samples - unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample + # FIXME SD1 ControlNet is not working # predict the noise residual - noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding) + if controlnet is not None: + input_resi_add, mid_add = controlnet(latent_model_input, t, text_embedding, vector_embedding, controlnet_image) + noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding, input_resi_add, mid_add) + else: + noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding) noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training # perform guidance diff --git a/library/strategy_base.py b/library/strategy_base.py index e7d3a97ef..10820afa1 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -1,6 +1,7 @@ # base class for platform strategies. this file defines the interface for strategies import os +import re from typing import Any, List, Optional, Tuple, Union import numpy as np @@ -22,6 +23,24 @@ class TokenizeStrategy: _strategy = None # strategy instance: actual strategy class + _re_attention = re.compile( + r"""\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, + ) + @classmethod def set_strategy(cls, strategy): if cls._strategy is not None: @@ -54,7 +73,151 @@ def _load_tokenizer( def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: raise NotImplementedError - def _get_input_ids(self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None) -> torch.Tensor: + def tokenize_with_weights(self, text: Union[str, List[str]]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + raise NotImplementedError + + def _get_weighted_input_ids( + self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + max_length includes starting and ending tokens. + """ + + def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + for m in TokenizeStrategy._re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + def get_prompts_with_weights(text: str, max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. max_length does not include starting and ending token. + + No padding, starting or ending token is included. + """ + truncated = False + + texts_and_weights = parse_prompt_attention(text) + tokens = [] + weights = [] + for word, weight in texts_and_weights: + # tokenize and discard the starting and the ending token + token = tokenizer(word).input_ids[1:-1] + tokens += token + # copy the weight by length of token + weights += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(tokens) > max_length: + truncated = True + break + # truncate + if len(tokens) > max_length: + truncated = True + tokens = tokens[:max_length] + weights = weights[:max_length] + if truncated: + logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights + + def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + tokens = [bos] + tokens + [eos] + [pad] * (max_length - 2 - len(tokens)) + weights = [1.0] + weights + [1.0] * (max_length - 1 - len(weights)) + return tokens, weights + + if max_length is None: + max_length = tokenizer.model_max_length + + tokens, weights = get_prompts_with_weights(text, max_length - 2) + tokens, weights = pad_tokens_and_weights( + tokens, weights, max_length, tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id + ) + return torch.tensor(tokens).unsqueeze(0), torch.tensor(weights).unsqueeze(0) + + def _get_input_ids( + self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None, weighted: bool = False + ) -> torch.Tensor: """ for SD1.5/2.0/SDXL TODO support batch input @@ -62,7 +225,10 @@ def _get_input_ids(self, tokenizer: CLIPTokenizer, text: str, max_length: Option if max_length is None: max_length = tokenizer.model_max_length - 2 - input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids + if weighted: + input_ids, weights = self._get_weighted_input_ids(tokenizer, text, max_length) + else: + input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids if max_length > tokenizer.model_max_length: input_ids = input_ids.squeeze(0) @@ -101,6 +267,17 @@ def _get_input_ids(self, tokenizer: CLIPTokenizer, text: str, max_length: Option iids_list.append(ids_chunk) input_ids = torch.stack(iids_list) # 3,77 + + if weighted: + weights = weights.squeeze(0) + new_weights = torch.ones(input_ids.shape) + for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): + b = i // (tokenizer.model_max_length - 2) + new_weights[b, 1 : 1 + tokenizer.model_max_length - 2] = weights[i : i + tokenizer.model_max_length - 2] + weights = new_weights + + if weighted: + return input_ids, weights return input_ids @@ -126,6 +303,17 @@ def encode_tokens( :return: list of output embeddings for each architecture """ raise NotImplementedError + + def encode_tokens_with_weights( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor] + ) -> List[torch.Tensor]: + """ + Encode tokens into embeddings and outputs. + :param tokens: list of token tensors for each TextModel + :param weights: list of weight tensors for each TextModel + :return: list of output embeddings for each architecture + """ + raise NotImplementedError class TextEncoderOutputsCachingStrategy: diff --git a/library/strategy_sdxl.py b/library/strategy_sdxl.py index 3eb0ab6f6..b48e6d55a 100644 --- a/library/strategy_sdxl.py +++ b/library/strategy_sdxl.py @@ -37,6 +37,22 @@ def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: torch.stack([self._get_input_ids(self.tokenizer2, t, self.max_length) for t in text], dim=0), ) + def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]: + text = [text] if isinstance(text, str) else text + tokens1_list, tokens2_list = [], [] + weights1_list, weights2_list = [], [] + for t in text: + tokens1, weights1 = self._get_weighted_input_ids(self.tokenizer1, t, self.max_length) + tokens2, weights2 = self._get_weighted_input_ids(self.tokenizer2, t, self.max_length) + tokens1_list.append(tokens1) + tokens2_list.append(tokens2) + weights1_list.append(weights1) + weights2_list.append(weights2) + return (torch.stack(tokens1_list, dim=0), torch.stack(tokens2_list, dim=0)), ( + torch.stack(weights1_list, dim=0), + torch.stack(weights2_list, dim=0), + ) + class SdxlTextEncodingStrategy(TextEncodingStrategy): def __init__(self) -> None: @@ -98,7 +114,10 @@ def _get_hidden_states_sdxl( ): # input_ids: b,n,77 -> b*n, 77 b_size = input_ids1.size()[0] - max_token_length = input_ids1.size()[1] * input_ids1.size()[2] + if input_ids1.size()[1] == 1: + max_token_length = None + else: + max_token_length = input_ids1.size()[1] * input_ids1.size()[2] input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77 input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77 input_ids1 = input_ids1.to(text_encoder1.device) @@ -172,6 +191,24 @@ def encode_tokens( ) return [hidden_states1, hidden_states2, pool2] + def encode_tokens_with_weights( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor] + ) -> List[torch.Tensor]: + hidden_states1, hidden_states2, pool2 = self.encode_tokens(tokenize_strategy, models, tokens) + + # apply weights + if weights[0].shape[1] == 1: # no max_token_length + # weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768) + hidden_states1 = hidden_states1 * weights[0].squeeze(1).unsqueeze(2) + hidden_states2 = hidden_states2 * weights[1].squeeze(1).unsqueeze(2) + else: + # weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768) + for weight, hidden_states in zip(weights, [hidden_states1, hidden_states2]): + for i in range(weight.shape[1]): + hidden_states[:, i * 75 + 1 : i * 75 + 76] = hidden_states[:, i * 75 + 1 : i * 75 + 76] * weight[:, i, 1:-1] + + return [hidden_states1, hidden_states2, pool2] + class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz" diff --git a/library/train_util.py b/library/train_util.py index 293fc05ad..b559616f2 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -74,6 +74,7 @@ import cv2 import safetensors.torch from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline +from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline import library.model_util as model_util import library.huggingface_util as huggingface_util import library.sai_model_spec as sai_model_spec @@ -3581,7 +3582,20 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: # available backends: # https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5 # https://pytorch.org/docs/stable/torch.compiler.html - choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt", "tensort", "ipex", "tvm"], + choices=[ + "eager", + "aot_eager", + "inductor", + "aot_ts_nvfuser", + "nvprims_nvfuser", + "cudagraphs", + "ofi", + "fx2trt", + "onnxrt", + "tensort", + "ipex", + "tvm", + ], help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)", ) parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う") @@ -5850,8 +5864,8 @@ def sample_images_common( pipe_class, accelerator: Accelerator, args: argparse.Namespace, - epoch, - steps, + epoch: int, + steps: int, device, vae, tokenizer, @@ -5910,11 +5924,7 @@ def sample_images_common( with open(args.sample_prompts, "r", encoding="utf-8") as f: prompts = json.load(f) - # schedulers: dict = {} cannot find where this is used - default_scheduler = get_my_scheduler( - sample_sampler=args.sample_sampler, - v_parameterization=args.v_parameterization, - ) + default_scheduler = get_my_scheduler(sample_sampler=args.sample_sampler, v_parameterization=args.v_parameterization) pipeline = pipe_class( text_encoder=text_encoder, @@ -5975,21 +5985,18 @@ def sample_images_common( # clear pipeline and cache to reduce vram usage del pipeline - # I'm not sure which of these is the correct way to clear the memory, but accelerator's device is used in the pipeline, so I'm using it here. - # with torch.cuda.device(torch.cuda.current_device()): - # torch.cuda.empty_cache() - clean_memory_on_device(accelerator.device) - torch.set_rng_state(rng_state) if torch.cuda.is_available() and cuda_rng_state is not None: torch.cuda.set_rng_state(cuda_rng_state) vae.to(org_vae_device) + clean_memory_on_device(accelerator.device) + def sample_image_inference( accelerator: Accelerator, args: argparse.Namespace, - pipeline, + pipeline: Union[StableDiffusionLongPromptWeightingPipeline, SdxlStableDiffusionLongPromptWeightingPipeline], save_dir, prompt_dict, epoch, diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index 74dcff2af..583a27dcc 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -83,6 +83,7 @@ def train(args): tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + tokenizer1, tokenizer2 = tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2 # this is used for sampling images # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( @@ -436,19 +437,19 @@ def remove_model(old_ckpt_name): accelerator.print(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) - # # For --sample_at_first - # sdxl_train_util.sample_images( - # accelerator, - # args, - # 0, - # global_step, - # accelerator.device, - # vae, - # [tokenizer1, tokenizer2], - # [text_encoder1, text_encoder2], - # unet, - # controlnet=control_net, - # ) + # For --sample_at_first + sdxl_train_util.sample_images( + accelerator, + args, + 0, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], + unet, + controlnet=control_net, + ) # training loop for epoch in range(num_train_epochs): @@ -484,7 +485,7 @@ def remove_model(old_ckpt_name): input_ids1 = input_ids1.to(accelerator.device) input_ids2 = input_ids2.to(accelerator.device) encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens( - tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2] + tokenize_strategy, [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], [input_ids1, input_ids2] ) if args.full_fp16: encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype) @@ -558,18 +559,18 @@ def remove_model(old_ckpt_name): progress_bar.update(1) global_step += 1 - # sdxl_train_util.sample_images( - # accelerator, - # args, - # None, - # global_step, - # accelerator.device, - # vae, - # [tokenizer1, tokenizer2], - # [text_encoder1, text_encoder2], - # unet, - # controlnet=control_net, - # ) + sdxl_train_util.sample_images( + accelerator, + args, + None, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], + unet, + controlnet=control_net, + ) # 指定ステップごとにモデルを保存 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: @@ -628,7 +629,7 @@ def remove_model(old_ckpt_name): accelerator.device, vae, [tokenizer1, tokenizer2], - [text_encoder1, text_encoder2], + [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], unet, controlnet=control_net, ) From c2440f9e53239e7e5dee426f611800d3e38a7f0e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 3 Oct 2024 21:32:21 +0900 Subject: [PATCH 164/348] fix cond image normlization, add independent LR for control --- library/sdxl_train_util.py | 3 ++- library/train_util.py | 20 +++++++++++++++++++- sdxl_train_control_net.py | 30 +++++++++++++++++++++++++----- 3 files changed, 46 insertions(+), 7 deletions(-) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index f009b5779..aaf77b8dd 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -12,7 +12,6 @@ from tqdm import tqdm from transformers import CLIPTokenizer from library import model_util, sdxl_model_util, train_util, sdxl_original_unet -from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline from .utils import setup_logging setup_logging() @@ -378,4 +377,6 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin def sample_images(*args, **kwargs): + from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline + return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs) diff --git a/library/train_util.py b/library/train_util.py index b559616f2..07c253a0e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -31,6 +31,7 @@ import subprocess from io import BytesIO import toml +# from concurrent.futures import ThreadPoolExecutor, as_completed from tqdm import tqdm @@ -912,6 +913,23 @@ def make_buckets(self): if info.image_size is None: info.image_size = self.get_image_size(info.absolute_path) + # # run in parallel + # max_workers = min(os.cpu_count(), len(self.image_data)) # TODO consider multi-gpu (processes) + # with ThreadPoolExecutor(max_workers) as executor: + # futures = [] + # for info in tqdm(self.image_data.values(), desc="loading image sizes"): + # if info.image_size is None: + # def get_and_set_image_size(info): + # info.image_size = self.get_image_size(info.absolute_path) + # futures.append(executor.submit(get_and_set_image_size, info)) + # # consume futures to reduce memory usage and prevent Ctrl-C hang + # if len(futures) >= max_workers: + # for future in futures: + # future.result() + # futures = [] + # for future in futures: + # future.result() + if self.enable_bucket: logger.info("make buckets") else: @@ -1826,7 +1844,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset): # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う captions = [] missing_captions = [] - for img_path in img_paths: + for img_path in tqdm(img_paths, desc="read caption"): cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard) if cap_for_img is None and subset.class_tokens is None: logger.warning( diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index 583a27dcc..b902cda69 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -253,11 +253,20 @@ def unwrap_model(model): # 学習に必要なクラスを準備する accelerator.print("prepare optimizer, data loader etc.") - trainable_params = list(control_net.parameters()) - # for p in trainable_params: - # p.requires_grad = True - logger.info(f"trainable params count: {len(trainable_params)}") - logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") + trainable_params = [] + ctrlnet_params = [] + unet_params = [] + for name, param in control_net.named_parameters(): + if name.startswith("controlnet_"): + ctrlnet_params.append(param) + else: + unet_params.append(param) + trainable_params.append({"params": ctrlnet_params, "lr": args.control_net_lr}) + trainable_params.append({"params": unet_params, "lr": args.learning_rate}) + all_params = ctrlnet_params + unet_params + + logger.info(f"trainable params count: {len(all_params)}") + logger.info(f"number of trainable parameters: {sum(p.numel() for p in all_params)}") _, _, optimizer = train_util.get_optimizer(args, trainable_params) @@ -456,6 +465,8 @@ def remove_model(old_ckpt_name): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 + control_net.train() + for step, batch in enumerate(train_dataloader): current_step.value = global_step with accelerator.accumulate(control_net): @@ -510,6 +521,9 @@ def remove_model(old_ckpt_name): controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) + # '-1 to +1' to '0 to 1' + controlnet_image = (controlnet_image + 1) / 2 + with accelerator.autocast(): input_resi_add, mid_add = control_net( noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image @@ -690,6 +704,12 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) + parser.add_argument( + "--control_net_lr", + type=float, + default=1e-4, + help="learning rate for controlnet / controlnetの学習率", + ) return parser From 3028027e074c891f33d45fff27068b490a408329 Mon Sep 17 00:00:00 2001 From: gesen2egee Date: Fri, 4 Oct 2024 16:41:41 +0800 Subject: [PATCH 165/348] Update train_network.py --- train_network.py | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/train_network.py b/train_network.py index e10c17c0c..c0239a6da 100644 --- a/train_network.py +++ b/train_network.py @@ -1034,26 +1034,26 @@ def remove_model(old_ckpt_name): logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) accelerator.log(logs, step=global_step) - if len(val_dataloader) > 0: - if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: - accelerator.print("Validating バリデーション処理...") - total_loss = 0.0 - with torch.no_grad(): - validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - total_loss += loss.detach().item() - current_loss = total_loss / validation_steps - val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) - - if args.logging_dir is not None: - logs = {"loss/current_val_loss": current_loss} - accelerator.log(logs, step=global_step) - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/average_val_loss": avr_loss} - accelerator.log(logs, step=global_step) - + if len(val_dataloader) > 0: + if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: + accelerator.print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + total_loss += loss.detach().item() + current_loss = total_loss / validation_steps + val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) + + if args.logging_dir is not None: + logs = {"loss/current_val_loss": current_loss} + accelerator.log(logs, step=global_step) + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/average_val_loss": avr_loss} + accelerator.log(logs, step=global_step) + if global_step >= args.max_train_steps: break From dece2c388f1c39e7baca201b4bf4e61d9f67a219 Mon Sep 17 00:00:00 2001 From: gesen2egee Date: Fri, 4 Oct 2024 16:43:07 +0800 Subject: [PATCH 166/348] Update train_db.py --- train_db.py | 164 ++++++++++++++++++++++++++-------------------------- 1 file changed, 82 insertions(+), 82 deletions(-) diff --git a/train_db.py b/train_db.py index 800a157bf..2c17e521f 100644 --- a/train_db.py +++ b/train_db.py @@ -46,67 +46,67 @@ # perlin_noise, def process_val_batch(*training_models, batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args): - total_loss = 0.0 - timesteps_list = [10, 350, 500, 650, 990] - - with accelerator.accumulate(*training_models): - with torch.no_grad(): - # latentに変換 - if cache_latents: - latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) - else: - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] - - with torch.set_grad_enabled(False), accelerator.autocast(): - if args.weighted_captions: - encoder_hidden_states = get_weighted_text_embeddings( - tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) - else: - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype - ) - - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - - for fixed_timesteps in timesteps_list: - with torch.set_grad_enabled(False), accelerator.autocast(): - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - noise = torch.randn_like(latents, device=latents.device) - b_size = latents.shape[0] - timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - # Predict the noise residual - with accelerator.autocast(): - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - if args.masked_loss: - loss = apply_masked_loss(loss, batch) - loss = loss.mean([1, 2, 3]) - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - total_loss += loss - - average_loss = total_loss / len(timesteps_list) - return average_loss + total_loss = 0.0 + timesteps_list = [10, 350, 500, 650, 990] + + with accelerator.accumulate(*training_models): + with torch.no_grad(): + # latentに変換 + if cache_latents: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] + + with torch.set_grad_enabled(False), accelerator.autocast(): + if args.weighted_captions: + encoder_hidden_states = get_weighted_text_embeddings( + tokenizer, + text_encoder, + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, + ) + else: + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states( + args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype + ) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + + for fixed_timesteps in timesteps_list: + with torch.set_grad_enabled(False), accelerator.autocast(): + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise = torch.randn_like(latents, device=latents.device) + b_size = latents.shape[0] + timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Predict the noise residual + with accelerator.autocast(): + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + if args.masked_loss: + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + total_loss += loss + + average_loss = total_loss / len(timesteps_list) + return average_loss def train(args): train_util.verify_training_args(args) @@ -210,8 +210,8 @@ def train(args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) if val_dataset_group is not None: - print("Cache validation latents...") - val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + print("Cache validation latents...") + val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -503,25 +503,25 @@ def train(args): avr_loss: float = loss_recorder.moving_average logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - if len(val_dataloader) > 0: - if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: - accelerator.print("Validating バリデーション処理...") - total_loss = 0.0 - with torch.no_grad(): - validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - total_loss += loss.detach().item() - current_loss = total_loss / validation_steps - val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) - - if args.logging_dir is not None: - logs = {"loss/current_val_loss": current_loss} - accelerator.log(logs, step=global_step) - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/average_val_loss": avr_loss} - accelerator.log(logs, step=global_step) + if len(val_dataloader) > 0: + if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: + accelerator.print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + total_loss += loss.detach().item() + current_loss = total_loss / validation_steps + val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) + + if args.logging_dir is not None: + logs = {"loss/current_val_loss": current_loss} + accelerator.log(logs, step=global_step) + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/average_val_loss": avr_loss} + accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break From ba08a898940c80a6551111fdd77b53c6d3a019ac Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 4 Oct 2024 20:35:16 +0900 Subject: [PATCH 167/348] call optimizer eval/train for sample_at_first, also set train after resuming closes #1667 --- flux_train.py | 2 ++ train_network.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/flux_train.py b/flux_train.py index 022467ea7..81c13e4cc 100644 --- a/flux_train.py +++ b/flux_train.py @@ -706,7 +706,9 @@ def optimizer_hook(parameter: torch.Tensor): accelerator.unwrap_model(flux).prepare_block_swap_before_forward() # For --sample_at_first + optimizer_eval_fn() flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) + optimizer_train_fn() if len(accelerator.trackers) > 0: # log empty object to commit the sample images to wandb accelerator.log({}, step=0) diff --git a/train_network.py b/train_network.py index 7b2b76a1b..f0d397b9e 100644 --- a/train_network.py +++ b/train_network.py @@ -1042,7 +1042,9 @@ def remove_model(old_ckpt_name): text_encoder = None # For --sample_at_first + optimizer_eval_fn() self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) + optimizer_train_fn() if len(accelerator.trackers) > 0: # log empty object to commit the sample images to wandb accelerator.log({}, step=0) From 83e3048cb089bf6726751609da26da751b8383ae Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 6 Oct 2024 21:32:21 +0900 Subject: [PATCH 168/348] load Diffusers format, check schnell/dev --- README.md | 4 + flux_minimal_inference.py | 15 +-- flux_train.py | 15 ++- flux_train_network.py | 17 ++- library/flux_utils.py | 178 +++++++++++++++++++++++++++-- tools/convert_diffusers_to_flux.py | 78 +------------ 6 files changed, 196 insertions(+), 111 deletions(-) diff --git a/README.md b/README.md index 789fe514a..c567758a5 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,10 @@ The command to install PyTorch is as follows: ### Recent Updates +Oct 6, 2024: +- In FLUX.1 LoRA training and fine-tuning, the specified weight file (*.safetensors) is automatically determined to be dev or schnell. This allows schnell models to be loaded correctly. Note that LoRA training with schnell models and fine-tuning with schnell models are unverified. +- FLUX.1 LoRA training and fine-tuning can now load weights in Diffusers format in addition to BFL format (a single *.safetensors file). Please specify the parent directory of `transformer` or `diffusion_pytorch_model-00001-of-00003.safetensors` with the full path. However, Diffusers format CLIP/T5XXL is not supported. Saving is supported only in BFL format. + Sep 26, 2024: The implementation of block swap during FLUX.1 fine-tuning has been changed to improve speed about 10% (depends on the environment). A new `--blocks_to_swap` option has been added, and `--double_blocks_to_swap` and `--single_blocks_to_swap` are deprecated. `--double_blocks_to_swap` and `--single_blocks_to_swap` are working as before, but they will be removed in the future. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index 2f1b9a377..7ab224f1b 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -419,9 +419,6 @@ def encode(prpt: str): steps = args.steps guidance_scale = args.guidance - name = "schnell" if "schnell" in args.ckpt_path else "dev" # TODO change this to a more robust way - is_schnell = name == "schnell" - def is_fp8(dt): return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz] @@ -455,12 +452,8 @@ def is_fp8(dt): # if is_fp8(t5xxl_dtype): # t5xxl = accelerator.prepare(t5xxl) - t5xxl_max_length = 256 if is_schnell else 512 - tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length) - encoding_strategy = strategy_flux.FluxTextEncodingStrategy() - # DiT - model = flux_utils.load_flow_model(name, args.ckpt_path, None, loading_device) + is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device) model.eval() logger.info(f"Casting model to {flux_dtype}") model.to(flux_dtype) # make sure model is dtype @@ -469,8 +462,12 @@ def is_fp8(dt): # if args.offload: # model = model.to("cpu") + t5xxl_max_length = 256 if is_schnell else 512 + tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length) + encoding_strategy = strategy_flux.FluxTextEncodingStrategy() + # AE - ae = flux_utils.load_ae(name, args.ae, ae_dtype, loading_device) + ae = flux_utils.load_ae(args.ae, ae_dtype, loading_device) ae.eval() # if is_fp8(ae_dtype): # ae = accelerator.prepare(ae) diff --git a/flux_train.py b/flux_train.py index 81c13e4cc..ecc87c0a8 100644 --- a/flux_train.py +++ b/flux_train.py @@ -137,6 +137,7 @@ def train(args): train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認 + _, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path) if args.debug_dataset: if args.cache_text_encoder_outputs: strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( @@ -144,9 +145,8 @@ def train(args): args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False ) ) - name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" t5xxl_max_token_length = ( - args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if name == "schnell" else 512) + args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512) ) strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length)) @@ -177,12 +177,11 @@ def train(args): weight_dtype, save_dtype = train_util.prepare_dtype(args) # モデルを読み込む - name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" # load VAE for caching latents ae = None if cache_latents: - ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + ae = flux_utils.load_ae( args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) ae.to(accelerator.device, dtype=weight_dtype) ae.requires_grad_(False) ae.eval() @@ -196,7 +195,7 @@ def train(args): # prepare tokenize strategy if args.t5xxl_max_token_length is None: - if name == "schnell": + if is_schnell: t5xxl_max_token_length = 256 else: t5xxl_max_token_length = 512 @@ -258,8 +257,8 @@ def train(args): clean_memory_on_device(accelerator.device) # load FLUX - flux = flux_utils.load_flow_model( - name, args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors + _, flux = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors ) if args.gradient_checkpointing: @@ -294,7 +293,7 @@ def train(args): if not cache_latents: # load VAE here if not cached - ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu") + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu") ae.requires_grad_(False) ae.eval() ae.to(accelerator.device, dtype=weight_dtype) diff --git a/flux_train_network.py b/flux_train_network.py index 65b121e7c..5d14bd28e 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -2,7 +2,7 @@ import copy import math import random -from typing import Any +from typing import Any, Optional import torch from accelerate import Accelerator @@ -24,6 +24,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): def __init__(self): super().__init__() self.sample_prompts_te_outputs = None + self.is_schnell: Optional[bool] = None def assert_extra_args(self, args, train_dataset_group): super().assert_extra_args(args, train_dataset_group) @@ -57,19 +58,15 @@ def assert_extra_args(self, args, train_dataset_group): train_dataset_group.verify_bucket_reso_steps(32) # TODO check this - def get_flux_model_name(self, args): - return "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" - def load_target_model(self, args, weight_dtype, accelerator): # currently offload to cpu for some models - name = self.get_flux_model_name(args) # if the file is fp8 and we are using fp8_base, we can load it as is (fp8) loading_dtype = None if args.fp8_base else weight_dtype # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future - model = flux_utils.load_flow_model( - name, args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors + self.is_schnell, model = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors ) if args.fp8_base: # check dtype of model @@ -100,7 +97,7 @@ def load_target_model(self, args, weight_dtype, accelerator): elif t5xxl.dtype == torch.float8_e4m3fn: logger.info("Loaded fp8 T5XXL model") - ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model @@ -142,10 +139,10 @@ def prepare_split_model(self, model, weight_dtype, accelerator): return flux_lower def get_tokenize_strategy(self, args): - name = self.get_flux_model_name(args) + _, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path) if args.t5xxl_max_token_length is None: - if name == "schnell": + if is_schnell: t5xxl_max_token_length = 256 else: t5xxl_max_token_length = 512 diff --git a/library/flux_utils.py b/library/flux_utils.py index 7b0a41a8a..713814e28 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -1,9 +1,11 @@ import json -from typing import Optional, Union +import os +from typing import List, Optional, Tuple, Union import einops import torch from safetensors.torch import load_file +from safetensors import safe_open from accelerate import init_empty_weights from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config @@ -17,6 +19,8 @@ logger = logging.getLogger(__name__) MODEL_VERSION_FLUX_V1 = "flux1" +MODEL_NAME_DEV = "dev" +MODEL_NAME_SCHNELL = "schnell" # temporary copy from sd3_utils TODO refactor @@ -39,10 +43,35 @@ def load_safetensors( return load_file(path) # prevent device invalid Error +def check_flux_state_dict_diffusers_schnell(ckpt_path: str) -> Tuple[bool, bool, List[str]]: + # check the state dict: Diffusers or BFL, dev or schnell + logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell") + + if os.path.isdir(ckpt_path): # if ckpt_path is a directory, it is Diffusers + ckpt_path = os.path.join(ckpt_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors") + if "00001-of-00003" in ckpt_path: + ckpt_paths = [ckpt_path.replace("00001-of-00003", f"0000{i}-of-00003") for i in range(1, 4)] + else: + ckpt_paths = [ckpt_path] + + keys = [] + for ckpt_path in ckpt_paths: + with safe_open(ckpt_path, framework="pt") as f: + keys.extend(f.keys()) + + is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys + is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys) + return is_diffusers, is_schnell, ckpt_paths + + def load_flow_model( - name: str, ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False -) -> flux_models.Flux: - logger.info(f"Building Flux model {name}") + ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False +) -> Tuple[bool, flux_models.Flux]: + is_diffusers, is_schnell, ckpt_paths = check_flux_state_dict_diffusers_schnell(ckpt_path) + name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL + + # build model + logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint") with torch.device("meta"): model = flux_models.Flux(flux_models.configs[name].params) if dtype is not None: @@ -50,18 +79,28 @@ def load_flow_model( # load_sft doesn't support torch.device logger.info(f"Loading state dict from {ckpt_path}") - sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + sd = {} + for ckpt_path in ckpt_paths: + sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)) + + # convert Diffusers to BFL + if is_diffusers: + logger.info("Converting Diffusers to BFL") + sd = convert_diffusers_sd_to_bfl(sd) + logger.info("Converted Diffusers to BFL") + info = model.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded Flux: {info}") - return model + return is_schnell, model def load_ae( - name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False + ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False ) -> flux_models.AutoEncoder: logger.info("Building AutoEncoder") with torch.device("meta"): - ae = flux_models.AutoEncoder(flux_models.configs[name].ae_params).to(dtype) + # dev and schnell have the same AE params + ae = flux_models.AutoEncoder(flux_models.configs[MODEL_NAME_DEV].ae_params).to(dtype) logger.info(f"Loading state dict from {ckpt_path}") sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) @@ -246,3 +285,126 @@ def pack_latents(x: torch.Tensor) -> torch.Tensor: """ x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) return x + + +# region Diffusers + +NUM_DOUBLE_BLOCKS = 19 +NUM_SINGLE_BLOCKS = 38 + +BFL_TO_DIFFUSERS_MAP = { + "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"], + "time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"], + "time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"], + "time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"], + "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"], + "vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"], + "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"], + "vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"], + "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"], + "guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"], + "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"], + "guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"], + "txt_in.weight": ["context_embedder.weight"], + "txt_in.bias": ["context_embedder.bias"], + "img_in.weight": ["x_embedder.weight"], + "img_in.bias": ["x_embedder.bias"], + "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"], + "double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"], + "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"], + "double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"], + "double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"], + "double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"], + "double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"], + "double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"], + "double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"], + "double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"], + "double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"], + "double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"], + "double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"], + "double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"], + "double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"], + "double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"], + "double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"], + "double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"], + "double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"], + "double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"], + "double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"], + "double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"], + "double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"], + "double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"], + "single_blocks.().modulation.lin.weight": ["norm.linear.weight"], + "single_blocks.().modulation.lin.bias": ["norm.linear.bias"], + "single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"], + "single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"], + "single_blocks.().linear2.weight": ["proj_out.weight"], + "single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"], + "single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"], + "single_blocks.().linear2.weight": ["proj_out.weight"], + "single_blocks.().linear2.bias": ["proj_out.bias"], + "final_layer.linear.weight": ["proj_out.weight"], + "final_layer.linear.bias": ["proj_out.bias"], + "final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"], + "final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"], +} + + +def make_diffusers_to_bfl_map() -> dict[str, tuple[int, str]]: + # make reverse map from diffusers map + diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key) + for b in range(NUM_DOUBLE_BLOCKS): + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if key.startswith("double_blocks."): + block_prefix = f"transformer_blocks.{b}." + for i, weight in enumerate(weights): + diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) + for b in range(NUM_SINGLE_BLOCKS): + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if key.startswith("single_blocks."): + block_prefix = f"single_transformer_blocks.{b}." + for i, weight in enumerate(weights): + diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")): + for i, weight in enumerate(weights): + diffusers_to_bfl_map[weight] = (i, key) + return diffusers_to_bfl_map + + +def convert_diffusers_sd_to_bfl(diffusers_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + diffusers_to_bfl_map = make_diffusers_to_bfl_map() + + # iterate over three safetensors files to reduce memory usage + flux_sd = {} + for diffusers_key, tensor in diffusers_sd.items(): + if diffusers_key in diffusers_to_bfl_map: + index, bfl_key = diffusers_to_bfl_map[diffusers_key] + if bfl_key not in flux_sd: + flux_sd[bfl_key] = [] + flux_sd[bfl_key].append((index, tensor)) + else: + logger.error(f"Error: Key not found in diffusers_to_bfl_map: {diffusers_key}") + raise KeyError(f"Key not found in diffusers_to_bfl_map: {diffusers_key}") + + # concat tensors if multiple tensors are mapped to a single key, sort by index + for key, values in flux_sd.items(): + if len(values) == 1: + flux_sd[key] = values[0][1] + else: + flux_sd[key] = torch.cat([value[1] for value in sorted(values, key=lambda x: x[0])]) + + # special case for final_layer.adaLN_modulation.1.weight and final_layer.adaLN_modulation.1.bias + def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + if "final_layer.adaLN_modulation.1.weight" in flux_sd: + flux_sd["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.weight"]) + if "final_layer.adaLN_modulation.1.bias" in flux_sd: + flux_sd["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.bias"]) + + return flux_sd + + +# endregion diff --git a/tools/convert_diffusers_to_flux.py b/tools/convert_diffusers_to_flux.py index 9d8f7c74b..65ba7321a 100644 --- a/tools/convert_diffusers_to_flux.py +++ b/tools/convert_diffusers_to_flux.py @@ -29,6 +29,7 @@ import torch from tqdm import tqdm +from library import flux_utils from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file setup_logging() @@ -36,65 +37,6 @@ logger = logging.getLogger(__name__) -NUM_DOUBLE_BLOCKS = 19 -NUM_SINGLE_BLOCKS = 38 - -BFL_TO_DIFFUSERS_MAP = { - "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"], - "time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"], - "time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"], - "time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"], - "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"], - "vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"], - "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"], - "vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"], - "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"], - "guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"], - "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"], - "guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"], - "txt_in.weight": ["context_embedder.weight"], - "txt_in.bias": ["context_embedder.bias"], - "img_in.weight": ["x_embedder.weight"], - "img_in.bias": ["x_embedder.bias"], - "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"], - "double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"], - "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"], - "double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"], - "double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"], - "double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"], - "double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"], - "double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"], - "double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"], - "double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"], - "double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"], - "double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"], - "double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"], - "double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"], - "double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"], - "double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"], - "double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"], - "double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"], - "double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"], - "double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"], - "double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"], - "double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"], - "double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"], - "double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"], - "single_blocks.().modulation.lin.weight": ["norm.linear.weight"], - "single_blocks.().modulation.lin.bias": ["norm.linear.bias"], - "single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"], - "single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"], - "single_blocks.().linear2.weight": ["proj_out.weight"], - "single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"], - "single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"], - "single_blocks.().linear2.weight": ["proj_out.weight"], - "single_blocks.().linear2.bias": ["proj_out.bias"], - "final_layer.linear.weight": ["proj_out.weight"], - "final_layer.linear.bias": ["proj_out.bias"], - "final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"], - "final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"], -} - def convert(args): # if diffusers_path is folder, get safetensors file @@ -114,23 +56,7 @@ def convert(args): save_dtype = str_to_dtype(args.save_precision) if args.save_precision is not None else None # make reverse map from diffusers map - diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key) - for b in range(NUM_DOUBLE_BLOCKS): - for key, weights in BFL_TO_DIFFUSERS_MAP.items(): - if key.startswith("double_blocks."): - block_prefix = f"transformer_blocks.{b}." - for i, weight in enumerate(weights): - diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) - for b in range(NUM_SINGLE_BLOCKS): - for key, weights in BFL_TO_DIFFUSERS_MAP.items(): - if key.startswith("single_blocks."): - block_prefix = f"single_transformer_blocks.{b}." - for i, weight in enumerate(weights): - diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) - for key, weights in BFL_TO_DIFFUSERS_MAP.items(): - if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")): - for i, weight in enumerate(weights): - diffusers_to_bfl_map[weight] = (i, key) + diffusers_to_bfl_map = flux_utils.make_diffusers_to_bfl_map() # iterate over three safetensors files to reduce memory usage flux_sd = {} From 886f75345c95cddec8752ffdd4e60a471ee75403 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 10 Oct 2024 08:27:15 +0900 Subject: [PATCH 169/348] support weighted captions for sdxl LoRA and fine tuning --- library/strategy_base.py | 5 ++++- library/strategy_sdxl.py | 3 ++- sdxl_train.py | 38 ++++++++++++++++++++------------------ sdxl_train_control_net.py | 7 ++----- train_network.py | 27 +++++++++++++++++---------- 5 files changed, 45 insertions(+), 35 deletions(-) diff --git a/library/strategy_base.py b/library/strategy_base.py index 10820afa1..7981bd0b9 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -74,6 +74,9 @@ def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: raise NotImplementedError def tokenize_with_weights(self, text: Union[str, List[str]]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + returns: [tokens1, tokens2, ...], [weights1, weights2, ...] + """ raise NotImplementedError def _get_weighted_input_ids( @@ -303,7 +306,7 @@ def encode_tokens( :return: list of output embeddings for each architecture """ raise NotImplementedError - + def encode_tokens_with_weights( self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor] ) -> List[torch.Tensor]: diff --git a/library/strategy_sdxl.py b/library/strategy_sdxl.py index b48e6d55a..6650e2b43 100644 --- a/library/strategy_sdxl.py +++ b/library/strategy_sdxl.py @@ -174,7 +174,8 @@ def encode_tokens( """ Args: tokenize_strategy: TokenizeStrategy - models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)] + models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)]. + If text_encoder2 is wrapped by accelerate, unwrapped_text_encoder2 is required tokens: List of tokens, for text_encoder1 and text_encoder2 """ if len(models) == 2: diff --git a/sdxl_train.py b/sdxl_train.py index 7291ddd2f..320169d77 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -104,8 +104,8 @@ def train(args): setup_logging(args, reset=True) assert ( - not args.weighted_captions - ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + not args.weighted_captions or not args.cache_text_encoder_outputs + ), "weighted_captions is not supported when caching text encoder outputs / cache_text_encoder_outputsを使うときはweighted_captionsはサポートされていません" assert ( not args.train_text_encoder or not args.cache_text_encoder_outputs ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" @@ -660,22 +660,24 @@ def optimizer_hook(parameter: torch.Tensor): input_ids1, input_ids2 = batch["input_ids_list"] with torch.set_grad_enabled(args.train_text_encoder): # Get the text embedding for conditioning - # TODO support weighted captions - # if args.weighted_captions: - # encoder_hidden_states = get_weighted_text_embeddings( - # tokenizer, - # text_encoder, - # batch["captions"], - # accelerator.device, - # args.max_token_length // 75 if args.max_token_length else 1, - # clip_skip=args.clip_skip, - # ) - # else: - input_ids1 = input_ids1.to(accelerator.device) - input_ids2 = input_ids2.to(accelerator.device) - encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens( - tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2] - ) + if args.weighted_captions: + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + encoder_hidden_states1, encoder_hidden_states2, pool2 = ( + text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, + [text_encoder1, text_encoder2, accelerator.unwrap_model(text_encoder2)], + input_ids_list, + weights_list, + ) + ) + else: + input_ids1 = input_ids1.to(accelerator.device) + input_ids2 = input_ids2.to(accelerator.device) + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens( + tokenize_strategy, + [text_encoder1, text_encoder2, accelerator.unwrap_model(text_encoder2)], + [input_ids1, input_ids2], + ) if args.full_fp16: encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype) encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype) diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index b902cda69..f6cc5a4f9 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -12,24 +12,21 @@ init_ipex() -from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed from accelerate import init_empty_weights -from diffusers import DDPMScheduler, ControlNetModel +from diffusers import DDPMScheduler from diffusers.utils.torch_utils import is_compiled_module from safetensors.torch import load_file from library import ( deepspeed_utils, sai_model_spec, sdxl_model_util, - sdxl_original_unet, sdxl_train_util, strategy_base, strategy_sd, strategy_sdxl, ) -import library.model_util as model_util import library.train_util as train_util import library.config_util as config_util from library.config_util import ( @@ -264,7 +261,7 @@ def unwrap_model(model): trainable_params.append({"params": ctrlnet_params, "lr": args.control_net_lr}) trainable_params.append({"params": unet_params, "lr": args.learning_rate}) all_params = ctrlnet_params + unet_params - + logger.info(f"trainable params count: {len(all_params)}") logger.info(f"number of trainable parameters: {sum(p.numel() for p in all_params)}") diff --git a/train_network.py b/train_network.py index f0d397b9e..e48e6a070 100644 --- a/train_network.py +++ b/train_network.py @@ -1123,14 +1123,21 @@ def remove_model(old_ckpt_name): with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: - # SD only - encoded_text_encoder_conds = get_weighted_text_embeddings( - tokenizers[0], - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, + # # SD only + # encoded_text_encoder_conds = get_weighted_text_embeddings( + # tokenizers[0], + # text_encoder, + # batch["captions"], + # accelerator.device, + # args.max_token_length // 75 if args.max_token_length else 1, + # clip_skip=args.clip_skip, + # ) + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, + self.get_models_for_text_encoding(args, accelerator, text_encoders), + input_ids_list, + weights_list, ) else: input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] @@ -1139,8 +1146,8 @@ def remove_model(old_ckpt_name): self.get_models_for_text_encoding(args, accelerator, text_encoders), input_ids, ) - if args.full_fp16: - encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds] + if args.full_fp16: + encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds] # if text_encoder_conds is not cached, use encoded_text_encoder_conds if len(text_encoder_conds) == 0: From 3de42b6edb151b172f483aec99fe380b1406a84a Mon Sep 17 00:00:00 2001 From: Akegarasu Date: Thu, 10 Oct 2024 14:03:59 +0800 Subject: [PATCH 170/348] fix: distributed training in windows --- library/train_util.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index e023f63a2..3dabf9e26 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5045,17 +5045,18 @@ def prepare_accelerator(args: argparse.Namespace): if args.torch_compile: dynamo_backend = args.dynamo_backend - kwargs_handlers = ( - InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None, - ( - DistributedDataParallelKwargs( - gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph - ) - if args.ddp_gradient_as_bucket_view or args.ddp_static_graph - else None - ), - ) - kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers)) + kwargs_handlers = [ + InitProcessGroupKwargs( + backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", + init_method="env://?use_libuv=False" if os.name == "nt" else None, + timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None + ) if torch.cuda.device_count() > 1 else None, + DistributedDataParallelKwargs( + gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, + static_graph=args.ddp_static_graph + ) if args.ddp_gradient_as_bucket_view or args.ddp_static_graph else None + ] + kwargs_handlers = [i for i in kwargs_handlers if i is not None] deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args) accelerator = Accelerator( From 9f4dac5731fe2299c75b7671c6132febd57a4117 Mon Sep 17 00:00:00 2001 From: Akegarasu Date: Thu, 10 Oct 2024 14:08:55 +0800 Subject: [PATCH 171/348] torch 2.4 --- library/train_util.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 3dabf9e26..2c20a9244 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -33,6 +33,7 @@ import toml from tqdm import tqdm +from packaging.version import Version import torch from library.device_utils import init_ipex, clean_memory_on_device @@ -5048,7 +5049,7 @@ def prepare_accelerator(args: argparse.Namespace): kwargs_handlers = [ InitProcessGroupKwargs( backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", - init_method="env://?use_libuv=False" if os.name == "nt" else None, + init_method="env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None, timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None ) if torch.cuda.device_count() > 1 else None, DistributedDataParallelKwargs( From f2bc8201330d1370c182c57047a5c23e9c6bee71 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 11 Oct 2024 08:48:55 +0900 Subject: [PATCH 172/348] support weighted captions for SD/SDXL --- fine_tune.py | 17 ++++-------- library/sdxl_train_util.py | 6 ++-- library/strategy_base.py | 12 +++++++- library/strategy_sd.py | 36 ++++++++++++++++++++++++ library/strategy_sdxl.py | 57 ++++++++++++++++++++++++++------------ sdxl_train.py | 2 +- sdxl_train_network.py | 4 ++- train_db.py | 16 ++++------- 8 files changed, 105 insertions(+), 45 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 62a545a13..fd63385b3 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -366,22 +366,17 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): with torch.set_grad_enabled(args.train_text_encoder): # Get the text embedding for conditioning if args.weighted_captions: - # TODO move to strategy_sd.py - encoder_hidden_states = get_weighted_text_embeddings( - tokenize_strategy.tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + encoder_hidden_states = text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, [text_encoder], input_ids_list, weights_list + )[0] else: input_ids = batch["input_ids_list"][0].to(accelerator.device) encoder_hidden_states = text_encoding_strategy.encode_tokens( tokenize_strategy, [text_encoder], [input_ids] )[0] - if args.full_fp16: - encoder_hidden_states = encoder_hidden_states.to(weight_dtype) + if args.full_fp16: + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index aaf77b8dd..dc3887c34 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -363,9 +363,9 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin # ) # logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") - assert ( - not hasattr(args, "weighted_captions") or not args.weighted_captions - ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" + # assert ( + # not hasattr(args, "weighted_captions") or not args.weighted_captions + # ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" if supportTextEncoderCaching: if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: diff --git a/library/strategy_base.py b/library/strategy_base.py index 7981bd0b9..2bff4178a 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -323,12 +323,18 @@ class TextEncoderOutputsCachingStrategy: _strategy = None # strategy instance: actual strategy class def __init__( - self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False + self, + cache_to_disk: bool, + batch_size: int, + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + is_weighted: bool = False, ) -> None: self._cache_to_disk = cache_to_disk self._batch_size = batch_size self.skip_disk_cache_validity_check = skip_disk_cache_validity_check self._is_partial = is_partial + self._is_weighted = is_weighted @classmethod def set_strategy(cls, strategy): @@ -352,6 +358,10 @@ def batch_size(self): def is_partial(self): return self._is_partial + @property + def is_weighted(self): + return self._is_weighted + def get_outputs_npz_path(self, image_abs_path: str) -> str: raise NotImplementedError diff --git a/library/strategy_sd.py b/library/strategy_sd.py index 83ffaa31b..4e7931fdb 100644 --- a/library/strategy_sd.py +++ b/library/strategy_sd.py @@ -40,6 +40,16 @@ def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: text = [text] if isinstance(text, str) else text return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)] + def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]: + text = [text] if isinstance(text, str) else text + tokens_list = [] + weights_list = [] + for t in text: + tokens, weights = self._get_input_ids(self.tokenizer, t, self.max_length, weighted=True) + tokens_list.append(tokens) + weights_list.append(weights) + return [torch.stack(tokens_list, dim=0)], [torch.stack(weights_list, dim=0)] + class SdTextEncodingStrategy(TextEncodingStrategy): def __init__(self, clip_skip: Optional[int] = None) -> None: @@ -58,6 +68,8 @@ def encode_tokens( model_max_length = sd_tokenize_strategy.tokenizer.model_max_length tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77 + tokens = tokens.to(text_encoder.device) + if self.clip_skip is None: encoder_hidden_states = text_encoder(tokens)[0] else: @@ -93,6 +105,30 @@ def encode_tokens( return [encoder_hidden_states] + def encode_tokens_with_weights( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens_list: List[torch.Tensor], + weights_list: List[torch.Tensor], + ) -> List[torch.Tensor]: + encoder_hidden_states = self.encode_tokens(tokenize_strategy, models, tokens_list)[0] + + weights = weights_list[0].to(encoder_hidden_states.device) + + # apply weights + if weights.shape[1] == 1: # no max_token_length + # weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768) + encoder_hidden_states = encoder_hidden_states * weights.squeeze(1).unsqueeze(2) + else: + # weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768) + for i in range(weights.shape[1]): + encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] = encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] * weights[ + :, i, 1:-1 + ].unsqueeze(-1) + + return [encoder_hidden_states] + class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): # sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix. diff --git a/library/strategy_sdxl.py b/library/strategy_sdxl.py index 6650e2b43..6b3e2afa6 100644 --- a/library/strategy_sdxl.py +++ b/library/strategy_sdxl.py @@ -42,16 +42,16 @@ def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tenso tokens1_list, tokens2_list = [], [] weights1_list, weights2_list = [], [] for t in text: - tokens1, weights1 = self._get_weighted_input_ids(self.tokenizer1, t, self.max_length) - tokens2, weights2 = self._get_weighted_input_ids(self.tokenizer2, t, self.max_length) + tokens1, weights1 = self._get_input_ids(self.tokenizer1, t, self.max_length, weighted=True) + tokens2, weights2 = self._get_input_ids(self.tokenizer2, t, self.max_length, weighted=True) tokens1_list.append(tokens1) tokens2_list.append(tokens2) weights1_list.append(weights1) weights2_list.append(weights2) - return (torch.stack(tokens1_list, dim=0), torch.stack(tokens2_list, dim=0)), ( + return [torch.stack(tokens1_list, dim=0), torch.stack(tokens2_list, dim=0)], [ torch.stack(weights1_list, dim=0), torch.stack(weights2_list, dim=0), - ) + ] class SdxlTextEncodingStrategy(TextEncodingStrategy): @@ -193,20 +193,28 @@ def encode_tokens( return [hidden_states1, hidden_states2, pool2] def encode_tokens_with_weights( - self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor] + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens_list: List[torch.Tensor], + weights_list: List[torch.Tensor], ) -> List[torch.Tensor]: - hidden_states1, hidden_states2, pool2 = self.encode_tokens(tokenize_strategy, models, tokens) + hidden_states1, hidden_states2, pool2 = self.encode_tokens(tokenize_strategy, models, tokens_list) + + weights_list = [weights.to(hidden_states1.device) for weights in weights_list] # apply weights - if weights[0].shape[1] == 1: # no max_token_length + if weights_list[0].shape[1] == 1: # no max_token_length # weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768) - hidden_states1 = hidden_states1 * weights[0].squeeze(1).unsqueeze(2) - hidden_states2 = hidden_states2 * weights[1].squeeze(1).unsqueeze(2) + hidden_states1 = hidden_states1 * weights_list[0].squeeze(1).unsqueeze(2) + hidden_states2 = hidden_states2 * weights_list[1].squeeze(1).unsqueeze(2) else: # weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768) - for weight, hidden_states in zip(weights, [hidden_states1, hidden_states2]): + for weight, hidden_states in zip(weights_list, [hidden_states1, hidden_states2]): for i in range(weight.shape[1]): - hidden_states[:, i * 75 + 1 : i * 75 + 76] = hidden_states[:, i * 75 + 1 : i * 75 + 76] * weight[:, i, 1:-1] + hidden_states[:, i * 75 + 1 : i * 75 + 76] = hidden_states[:, i * 75 + 1 : i * 75 + 76] * weight[ + :, i, 1:-1 + ].unsqueeze(-1) return [hidden_states1, hidden_states2, pool2] @@ -215,9 +223,14 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz" def __init__( - self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False + self, + cache_to_disk: bool, + batch_size: int, + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + is_weighted: bool = False, ) -> None: - super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial, is_weighted) def get_outputs_npz_path(self, image_abs_path: str) -> str: return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX @@ -253,11 +266,19 @@ def cache_batch_outputs( sdxl_text_encoding_strategy = text_encoding_strategy # type: SdxlTextEncodingStrategy captions = [info.caption for info in infos] - tokens1, tokens2 = tokenize_strategy.tokenize(captions) - with torch.no_grad(): - hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens( - tokenize_strategy, models, [tokens1, tokens2] - ) + if self.is_weighted: + tokens_list, weights_list = tokenize_strategy.tokenize_with_weights(captions) + with torch.no_grad(): + hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, models, tokens_list, weights_list + ) + else: + tokens1, tokens2 = tokenize_strategy.tokenize(captions) + with torch.no_grad(): + hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens( + tokenize_strategy, models, [tokens1, tokens2] + ) + if hidden_state1.dtype == torch.bfloat16: hidden_state1 = hidden_state1.float() if hidden_state2.dtype == torch.bfloat16: diff --git a/sdxl_train.py b/sdxl_train.py index 320169d77..aeff9c469 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -321,7 +321,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): if args.cache_text_encoder_outputs: # Text Encodes are eval and no grad text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( - args.cache_text_encoder_outputs_to_disk, None, False + args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions ) strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy) diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 4d6e3f184..20e32155c 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -79,7 +79,9 @@ def get_models_for_text_encoding(self, args, accelerator, text_encoders): def get_text_encoder_outputs_caching_strategy(self, args): if args.cache_text_encoder_outputs: - return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(args.cache_text_encoder_outputs_to_disk, None, False) + return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions + ) else: return None diff --git a/train_db.py b/train_db.py index a5d520b12..e49a7e70f 100644 --- a/train_db.py +++ b/train_db.py @@ -356,21 +356,17 @@ def train(args): # Get the text embedding for conditioning with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): if args.weighted_captions: - encoder_hidden_states = get_weighted_text_embeddings( - tokenize_strategy.tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + encoder_hidden_states = text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, [text_encoder], input_ids_list, weights_list + )[0] else: input_ids = batch["input_ids_list"][0].to(accelerator.device) encoder_hidden_states = text_encoding_strategy.encode_tokens( tokenize_strategy, [text_encoder], [input_ids] )[0] - if args.full_fp16: - encoder_hidden_states = encoder_hidden_states.to(weight_dtype) + if args.full_fp16: + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified From 035c4a8552bf6214ad4d39657d3eb1204cdecdfd Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 11 Oct 2024 22:23:15 +0900 Subject: [PATCH 173/348] update docs and help text --- README.md | 10 ++++++++++ docs/train_lllite_README.md | 2 +- sdxl_train_control_net.py | 2 +- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index c567758a5..d3f49c994 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,16 @@ The command to install PyTorch is as follows: ### Recent Updates +Oct 11, 2024: +- ControlNet training for SDXL has been implemented in this branch. Please use `sdxl_train_control_net.py`. + - For details on defining the dataset, see [here](docs/train_lllite_README.md#creating-a-dataset-configuration-file). + - The learning rate for the copy part of the U-Net is specified by `--learning_rate`. The learning rate for the added modules in ControlNet is specified by `--control_net_lr`. The optimal value is still unknown, but try around U-Net `1e-5` and ControlNet `1e-4`. + - If you want to generate sample images, specify the control image as `--cn path/to/control/image`. + - The trained weights are automatically converted and saved in Diffusers format. It should be available in ComfyUI. +- Weighting of prompts (captions) during training in SDXL is now supported (e.g., `(some text)`, `[some text]`, `(some text:1.4)`, etc.). The function is enabled by specifying `--weighted_captions`. + - The default is `False`. It is same as before, and the parentheses are used as normal text. + - If `--weighted_captions` is specified, please use `\` to escape the parentheses in the prompt. For example, `\(some text:1.4\)`. + Oct 6, 2024: - In FLUX.1 LoRA training and fine-tuning, the specified weight file (*.safetensors) is automatically determined to be dev or schnell. This allows schnell models to be loaded correctly. Note that LoRA training with schnell models and fine-tuning with schnell models are unverified. - FLUX.1 LoRA training and fine-tuning can now load weights in Diffusers format in addition to BFL format (a single *.safetensors file). Please specify the parent directory of `transformer` or `diffusion_pytorch_model-00001-of-00003.safetensors` with the full path. However, Diffusers format CLIP/T5XXL is not supported. Saving is supported only in BFL format. diff --git a/docs/train_lllite_README.md b/docs/train_lllite_README.md index a05f87f5f..1bd8e4ae1 100644 --- a/docs/train_lllite_README.md +++ b/docs/train_lllite_README.md @@ -185,7 +185,7 @@ for img_file in img_files: ### Creating a dataset configuration file -You can use the command line arguments of `sdxl_train_control_net_lllite.py` to specify the conditioning image directory. However, if you want to use a `.toml` file, specify the conditioning image directory in `conditioning_data_dir`. +You can use the command line argument `--conditioning_data_dir` of `sdxl_train_control_net_lllite.py` to specify the conditioning image directory. However, if you want to use a `.toml` file, specify the conditioning image directory in `conditioning_data_dir`. ```toml [general] diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index f6cc5a4f9..67c8d52c8 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -705,7 +705,7 @@ def setup_parser() -> argparse.ArgumentParser: "--control_net_lr", type=float, default=1e-4, - help="learning rate for controlnet / controlnetの学習率", + help="learning rate for controlnet modules / controlnetモジュールの学習率", ) return parser From 0d3058b65ab7cd827e44f16f84c68a4bb73f701e Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sat, 12 Oct 2024 14:46:35 +0900 Subject: [PATCH 174/348] update README --- README.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/README.md b/README.md index d3f49c994..37fc911f6 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,17 @@ The command to install PyTorch is as follows: ### Recent Updates +Oct 12, 2024: + +- Multi-GPU training now works on Windows. Thanks to Akegarasu for PR [#1686](https://github.com/kohya-ss/sd-scripts/pull/1686)! + - It should work with all training scripts, but it is unverified. + - Set up multi-GPU training with `accelerate config`. + - Specify `--rdzv_backend=c10d` when launching `accelerate launch`. You can also edit `config.yaml` directly. + ``` + accelerate launch --rdzv_backend=c10d sdxl_train_network.py ... + ``` + - In multi-GPU training, the memory of multiple GPUs is not integrated. In other words, even if you have two 12GB VRAM GPUs, you cannot train the model that requires 24GB VRAM. Training that can be done with 12GB VRAM is executed at (up to) twice the speed. + Oct 11, 2024: - ControlNet training for SDXL has been implemented in this branch. Please use `sdxl_train_control_net.py`. - For details on defining the dataset, see [here](docs/train_lllite_README.md#creating-a-dataset-configuration-file). From c80c304779775f4d00fd8f4856bfc8e6599e2de0 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sat, 12 Oct 2024 20:18:41 +0900 Subject: [PATCH 175/348] Refactor caching in train scripts --- README.md | 10 +++++ fine_tune.py | 2 +- flux_train.py | 14 ++++--- flux_train_network.py | 6 +-- library/train_util.py | 64 +++++++++++++++++++++++--------- sd3_train.py | 17 +++++++-- sdxl_train.py | 4 +- sdxl_train_control_net.py | 4 +- sdxl_train_control_net_lllite.py | 5 +-- sdxl_train_network.py | 8 ++-- sdxl_train_textual_inversion.py | 2 +- train_db.py | 2 +- train_network.py | 2 +- train_textual_inversion.py | 2 +- 14 files changed, 95 insertions(+), 47 deletions(-) diff --git a/README.md b/README.md index 37fc911f6..2b2562831 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,16 @@ The command to install PyTorch is as follows: ### Recent Updates +Oct 12, 2024 (update 1): + +- During multi-GPU training, caching of latents and Text Encoder outputs is now done in multi-GPU. +- `--text_encoder_batch_size` option is enabled for FLUX.1 LoRA training and fine tuning. This option specifies the batch size for caching Text Encoder outputs (not for training). The default is same as the dataset batch size. If you have enough VRAM, you can increase the batch size to speed up the caching. +- `--skip_cache_check` option is added to each training script. + - When specified, the consistency check of the cache file `*.npz` contents (e.g., image size and flip for latents, mask for Text Encoder outputs) is skipped. + - Specify this option if you have a large number of cache files and the consistency check takes time. + - Even if this option is specified, the cache will be created if the file does not exist. + - `--skip_latents_validity_check` in SD3/FLUX.1 is deprecated. Please use `--skip_cache_check` instead. + Oct 12, 2024: - Multi-GPU training now works on Windows. Thanks to Akegarasu for PR [#1686](https://github.com/kohya-ss/sd-scripts/pull/1686)! diff --git a/fine_tune.py b/fine_tune.py index fd63385b3..cdc005d9a 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -59,7 +59,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. if cache_latents: latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) diff --git a/flux_train.py b/flux_train.py index ecc87c0a8..e18a92443 100644 --- a/flux_train.py +++ b/flux_train.py @@ -57,6 +57,10 @@ def train(args): deepspeed_utils.prepare_deepspeed_args(args) setup_logging(args, reset=True) + # temporary: backward compatibility for deprecated options. remove in the future + if not args.skip_cache_check: + args.skip_cache_check = args.skip_latents_validity_check + # assert ( # not args.weighted_captions # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" @@ -81,7 +85,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. if args.cache_latents: latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy( - args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) @@ -142,7 +146,7 @@ def train(args): if args.cache_text_encoder_outputs: strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( strategy_flux.FluxTextEncoderOutputsCachingStrategy( - args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False ) ) t5xxl_max_token_length = ( @@ -181,7 +185,7 @@ def train(args): # load VAE for caching latents ae = None if cache_latents: - ae = flux_utils.load_ae( args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) ae.to(accelerator.device, dtype=weight_dtype) ae.requires_grad_(False) ae.eval() @@ -229,7 +233,7 @@ def train(args): strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) with accelerator.autocast(): - train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator.is_main_process) + train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator) # cache sample prompt's embeddings to free text encoder's memory if args.sample_prompts is not None: @@ -952,7 +956,7 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--skip_latents_validity_check", action="store_true", - help="skip latents validity check / latentsの正当性チェックをスキップする", + help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", ) parser.add_argument( "--blocks_to_swap", diff --git a/flux_train_network.py b/flux_train_network.py index 5d14bd28e..3bd8316d4 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -188,8 +188,8 @@ def get_text_encoder_outputs_caching_strategy(self, args): # if the text encoders is trained, we need tokenization, so is_partial is True return strategy_flux.FluxTextEncoderOutputsCachingStrategy( args.cache_text_encoder_outputs_to_disk, - None, - False, + args.text_encoder_batch_size, + args.skip_cache_check, is_partial=self.train_clip_l or self.train_t5xxl, apply_t5_attn_mask=args.apply_t5_attn_mask, ) @@ -222,7 +222,7 @@ def cache_text_encoder_outputs_if_needed( text_encoders[1].to(weight_dtype) with accelerator.autocast(): - dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process) + dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) # cache sample prompts if args.sample_prompts is not None: diff --git a/library/train_util.py b/library/train_util.py index 67eaae41b..4e6b3408d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -31,6 +31,7 @@ import subprocess from io import BytesIO import toml + # from concurrent.futures import ThreadPoolExecutor, as_completed from tqdm import tqdm @@ -1192,7 +1193,7 @@ def __eq__(self, other): for condition, batch in tqdm(batches, smoothing=1, total=len(batches)): cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop) - def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): + def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator): r""" a brand new method to cache text encoder outputs. This method caches text encoder outputs with caching strategy. """ @@ -1207,15 +1208,25 @@ def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: boo # split by resolution batches = [] batch = [] - logger.info("checking cache validity...") - for info in tqdm(image_infos): - te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path) - # check disk cache exists and size of latents + # support multiple-gpus + num_processes = accelerator.num_processes + process_index = accelerator.process_index + + logger.info("checking cache validity...") + for i, info in enumerate(tqdm(image_infos)): + # check disk cache exists and size of text encoder outputs if caching_strategy.cache_to_disk: - info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability/main process + te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path) + info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability + + # if the modulo of num_processes is not equal to process_index, skip caching + # this makes each process cache different text encoder outputs + if i % num_processes != process_index: + continue + cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz) - if cache_available or not is_main_process: # do not add to batch + if cache_available: # do not add to batch continue batch.append(info) @@ -2420,6 +2431,7 @@ def new_cache_latents(self, model: Any, accelerator: Accelerator): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") dataset.new_cache_latents(model, accelerator) + accelerator.wait_for_everyone() def cache_text_encoder_outputs( self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True @@ -2437,10 +2449,11 @@ def cache_text_encoder_outputs_sd3( tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size ) - def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): + def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") - dataset.new_cache_text_encoder_outputs(models, is_main_process) + dataset.new_cache_text_encoder_outputs(models, accelerator) + accelerator.wait_for_everyone() def set_caching_mode(self, caching_mode): for dataset in self.datasets: @@ -4210,6 +4223,12 @@ def add_dataset_arguments( action="store_true", help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheする(augmentationは使用不可)", ) + parser.add_argument( + "--skip_cache_check", + action="store_true", + help="skip the content validation of cache (latent and text encoder output). Cache file existence check is always performed, and cache processing is performed if the file does not exist" + " / cacheの内容の検証をスキップする(latentとテキストエンコーダの出力)。キャッシュファイルの存在確認は常に行われ、ファイルがなければキャッシュ処理が行われる", + ) parser.add_argument( "--enable_bucket", action="store_true", @@ -5084,15 +5103,24 @@ def prepare_accelerator(args: argparse.Namespace): dynamo_backend = args.dynamo_backend kwargs_handlers = [ - InitProcessGroupKwargs( - backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", - init_method="env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None, - timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None - ) if torch.cuda.device_count() > 1 else None, - DistributedDataParallelKwargs( - gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, - static_graph=args.ddp_static_graph - ) if args.ddp_gradient_as_bucket_view or args.ddp_static_graph else None + ( + InitProcessGroupKwargs( + backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", + init_method=( + "env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None + ), + timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None, + ) + if torch.cuda.device_count() > 1 + else None + ), + ( + DistributedDataParallelKwargs( + gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph + ) + if args.ddp_gradient_as_bucket_view or args.ddp_static_graph + else None + ), ] kwargs_handlers = [i for i in kwargs_handlers if i is not None] deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args) diff --git a/sd3_train.py b/sd3_train.py index 5120105f2..7290956ad 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -57,6 +57,10 @@ def train(args): deepspeed_utils.prepare_deepspeed_args(args) setup_logging(args, reset=True) + # temporary: backward compatibility for deprecated options. remove in the future + if not args.skip_cache_check: + args.skip_cache_check = args.skip_latents_validity_check + assert ( not args.weighted_captions ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" @@ -103,7 +107,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. if args.cache_latents: latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy( - args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) @@ -312,7 +316,7 @@ def train(args): text_encoder_caching_strategy = strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, - False, + args.skip_cache_check, train_clip_g or train_clip_l or args.use_t5xxl_cache_only, args.apply_lg_attn_mask, args.apply_t5_attn_mask, @@ -325,7 +329,7 @@ def train(args): t5xxl.to(t5xxl_device, dtype=t5xxl_dtype) with accelerator.autocast(): - train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator.is_main_process) + train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator) # cache sample prompt's embeddings to free text encoder's memory if args.sample_prompts is not None: @@ -1052,7 +1056,12 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--skip_latents_validity_check", action="store_true", - help="skip latents validity check / latentsの正当性チェックをスキップする", + help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", + ) + parser.add_argument( + "--skip_cache_check", + action="store_true", + help="skip cache (latents and text encoder outputs) check / キャッシュ(latentsとtext encoder outputs)のチェックをスキップする", ) parser.add_argument( "--num_last_block_to_freeze", diff --git a/sdxl_train.py b/sdxl_train.py index aeff9c469..9b2d19165 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -131,7 +131,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. if args.cache_latents: latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) @@ -328,7 +328,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): text_encoder1.to(accelerator.device) text_encoder2.to(accelerator.device) with accelerator.autocast(): - train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator) accelerator.wait_for_everyone() diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index 67c8d52c8..74b3a64a4 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -84,7 +84,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) @@ -230,7 +230,7 @@ def unwrap_model(model): text_encoder1.to(accelerator.device) text_encoder2.to(accelerator.device) with accelerator.autocast(): - train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator) accelerator.wait_for_everyone() diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 9d1cfc63e..14ff7c240 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -93,7 +93,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) @@ -202,7 +202,7 @@ def train(args): text_encoder1.to(accelerator.device) text_encoder2.to(accelerator.device) with accelerator.autocast(): - train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator) accelerator.wait_for_everyone() @@ -431,7 +431,6 @@ def remove_model(old_ckpt_name): latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR - text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: # Text Encoder outputs are cached diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 20e32155c..4a16a4891 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -67,7 +67,7 @@ def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy): def get_latents_caching_strategy(self, args): latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) return latents_caching_strategy @@ -80,7 +80,7 @@ def get_models_for_text_encoding(self, args, accelerator, text_encoders): def get_text_encoder_outputs_caching_strategy(self, args): if args.cache_text_encoder_outputs: return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( - args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions + args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions ) else: return None @@ -102,9 +102,7 @@ def cache_text_encoder_outputs_if_needed( text_encoders[0].to(accelerator.device, dtype=weight_dtype) text_encoders[1].to(accelerator.device, dtype=weight_dtype) with accelerator.autocast(): - dataset.new_cache_text_encoder_outputs( - text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator.is_main_process - ) + dataset.new_cache_text_encoder_outputs(text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator) accelerator.wait_for_everyone() text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index cbfcef554..821a69558 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -49,7 +49,7 @@ def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy): def get_latents_caching_strategy(self, args): latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) return latents_caching_strategy diff --git a/train_db.py b/train_db.py index e49a7e70f..683b42332 100644 --- a/train_db.py +++ b/train_db.py @@ -64,7 +64,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) diff --git a/train_network.py b/train_network.py index 7437157b9..d5330aef4 100644 --- a/train_network.py +++ b/train_network.py @@ -116,7 +116,7 @@ def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> L def get_latents_caching_strategy(self, args): latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - True, args.cache_latents_to_disk, args.vae_batch_size, False + True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) return latents_caching_strategy diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 3b3d3393f..4d8a3abbf 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -114,7 +114,7 @@ def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> L def get_latents_caching_strategy(self, args): latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - True, args.cache_latents_to_disk, args.vae_batch_size, False + True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) return latents_caching_strategy From ecaea909b10fa8b3eb94a1cf57b26d5daba1683e Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sat, 12 Oct 2024 20:26:57 +0900 Subject: [PATCH 176/348] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 37fc911f6..9128bf8da 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ The command to install PyTorch is as follows: Oct 12, 2024: - Multi-GPU training now works on Windows. Thanks to Akegarasu for PR [#1686](https://github.com/kohya-ss/sd-scripts/pull/1686)! - - It should work with all training scripts, but it is unverified. + - In simple tests, SDXL and FLUX.1 LoRA training worked. FLUX.1 fine-tuning did not work, probably due to a PyTorch-related error. Other scripts are unverified. - Set up multi-GPU training with `accelerate config`. - Specify `--rdzv_backend=c10d` when launching `accelerate launch`. You can also edit `config.yaml` directly. ``` From e277b5789e791539b5e51187530f11bd94e24871 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 12 Oct 2024 21:49:07 +0900 Subject: [PATCH 177/348] Update FLUX.1 support for compact models --- README.md | 10 ++++++ flux_train.py | 12 +++---- flux_train_network.py | 2 +- library/flux_utils.py | 76 ++++++++++++++++++++++++++++++++++++------- 4 files changed, 82 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 9128bf8da..b64515a19 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,16 @@ The command to install PyTorch is as follows: ### Recent Updates +Oct 12, 2024 (update 1): + +- [Experimental] FLUX.1 fine-tuning and LoRA training now support "FLUX.1 __compact__" models. + - A compact model is a model that retains the FLUX.1 architecture but reduces the number of double/single blocks from the default 19/38. + - The model is automatically determined based on the keys in *.safetensors. + - Specifications for compact model safetensors: + - Please specify the block indices as consecutive numbers. An error will occur if there are missing numbers. For example, if you reduce the double blocks to 15, the maximum key will be `double_blocks.14.*`. The same applies to single blocks. + - LoRA training is unverified. + - The trained model can be used for inference with `flux_minimal_inference.py`. Other inference environments are unverified. + Oct 12, 2024: - Multi-GPU training now works on Windows. Thanks to Akegarasu for PR [#1686](https://github.com/kohya-ss/sd-scripts/pull/1686)! diff --git a/flux_train.py b/flux_train.py index ecc87c0a8..2fc13068e 100644 --- a/flux_train.py +++ b/flux_train.py @@ -137,7 +137,7 @@ def train(args): train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認 - _, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path) + _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) if args.debug_dataset: if args.cache_text_encoder_outputs: strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( @@ -181,7 +181,7 @@ def train(args): # load VAE for caching latents ae = None if cache_latents: - ae = flux_utils.load_ae( args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) ae.to(accelerator.device, dtype=weight_dtype) ae.requires_grad_(False) ae.eval() @@ -510,8 +510,8 @@ def wait_blocks_move(block_idx, futures): library.adafactor_fused.patch_adafactor_fused(optimizer) blocks_to_swap = args.blocks_to_swap - num_double_blocks = 19 # len(flux.double_blocks) - num_single_blocks = 38 # len(flux.single_blocks) + num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks) + num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks) num_block_units = num_double_blocks + num_single_blocks // 2 handled_unit_indices = set() @@ -603,8 +603,8 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): parameter_optimizer_map = {} blocks_to_swap = args.blocks_to_swap - num_double_blocks = 19 # len(flux.double_blocks) - num_single_blocks = 38 # len(flux.single_blocks) + num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks) + num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks) num_block_units = num_double_blocks + num_single_blocks // 2 n = 1 # only asynchronous purpose, no need to increase this number diff --git a/flux_train_network.py b/flux_train_network.py index 5d14bd28e..a24c1905b 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -139,7 +139,7 @@ def prepare_split_model(self, model, weight_dtype, accelerator): return flux_lower def get_tokenize_strategy(self, args): - _, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path) + _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) if args.t5xxl_max_token_length is None: if is_schnell: diff --git a/library/flux_utils.py b/library/flux_utils.py index 713814e28..7a1ec37b8 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -1,3 +1,4 @@ +from dataclasses import replace import json import os from typing import List, Optional, Tuple, Union @@ -43,8 +44,21 @@ def load_safetensors( return load_file(path) # prevent device invalid Error -def check_flux_state_dict_diffusers_schnell(ckpt_path: str) -> Tuple[bool, bool, List[str]]: - # check the state dict: Diffusers or BFL, dev or schnell +def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]: + """ + チェックポイントの状態を分析し、DiffusersかBFLか、devかschnellか、ブロック数を計算して返す。 + + Args: + ckpt_path (str): チェックポイントファイルまたはディレクトリのパス。 + + Returns: + Tuple[bool, bool, Tuple[int, int], List[str]]: + - bool: Diffusersかどうかを示すフラグ。 + - bool: Schnellかどうかを示すフラグ。 + - Tuple[int, int]: ダブルブロックとシングルブロックの数。 + - List[str]: チェックポイントに含まれるキーのリスト。 + """ + # check the state dict: Diffusers or BFL, dev or schnell, number of blocks logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell") if os.path.isdir(ckpt_path): # if ckpt_path is a directory, it is Diffusers @@ -61,19 +75,57 @@ def check_flux_state_dict_diffusers_schnell(ckpt_path: str) -> Tuple[bool, bool, is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys) - return is_diffusers, is_schnell, ckpt_paths + + # check number of double and single blocks + if not is_diffusers: + max_double_block_index = max( + [int(key.split(".")[1]) for key in keys if key.startswith("double_blocks.") and key.endswith(".img_attn.proj.bias")] + ) + max_single_block_index = max( + [int(key.split(".")[1]) for key in keys if key.startswith("single_blocks.") and key.endswith(".modulation.lin.bias")] + ) + else: + max_double_block_index = max( + [ + int(key.split(".")[1]) + for key in keys + if key.startswith("transformer_blocks.") and key.endswith(".attn.add_k_proj.bias") + ] + ) + max_single_block_index = max( + [ + int(key.split(".")[1]) + for key in keys + if key.startswith("single_transformer_blocks.") and key.endswith(".attn.to_k.bias") + ] + ) + + num_double_blocks = max_double_block_index + 1 + num_single_blocks = max_single_block_index + 1 + + return is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths def load_flow_model( ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False ) -> Tuple[bool, flux_models.Flux]: - is_diffusers, is_schnell, ckpt_paths = check_flux_state_dict_diffusers_schnell(ckpt_path) + is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path) name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL # build model logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint") with torch.device("meta"): - model = flux_models.Flux(flux_models.configs[name].params) + params = flux_models.configs[name].params + + # set the number of blocks + if params.depth != num_double_blocks: + logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}") + params = replace(params, depth=num_double_blocks) + if params.depth_single_blocks != num_single_blocks: + logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}") + params = replace(params, depth_single_blocks=num_single_blocks) + + model = flux_models.Flux(params) if dtype is not None: model = model.to(dtype) @@ -86,7 +138,7 @@ def load_flow_model( # convert Diffusers to BFL if is_diffusers: logger.info("Converting Diffusers to BFL") - sd = convert_diffusers_sd_to_bfl(sd) + sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks) logger.info("Converted Diffusers to BFL") info = model.load_state_dict(sd, strict=False, assign=True) @@ -349,16 +401,16 @@ def pack_latents(x: torch.Tensor) -> torch.Tensor: } -def make_diffusers_to_bfl_map() -> dict[str, tuple[int, str]]: +def make_diffusers_to_bfl_map(num_double_blocks: int, num_single_blocks: int) -> dict[str, tuple[int, str]]: # make reverse map from diffusers map diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key) - for b in range(NUM_DOUBLE_BLOCKS): + for b in range(num_double_blocks): for key, weights in BFL_TO_DIFFUSERS_MAP.items(): if key.startswith("double_blocks."): block_prefix = f"transformer_blocks.{b}." for i, weight in enumerate(weights): diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) - for b in range(NUM_SINGLE_BLOCKS): + for b in range(num_single_blocks): for key, weights in BFL_TO_DIFFUSERS_MAP.items(): if key.startswith("single_blocks."): block_prefix = f"single_transformer_blocks.{b}." @@ -371,8 +423,10 @@ def make_diffusers_to_bfl_map() -> dict[str, tuple[int, str]]: return diffusers_to_bfl_map -def convert_diffusers_sd_to_bfl(diffusers_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: - diffusers_to_bfl_map = make_diffusers_to_bfl_map() +def convert_diffusers_sd_to_bfl( + diffusers_sd: dict[str, torch.Tensor], num_double_blocks: int = NUM_DOUBLE_BLOCKS, num_single_blocks: int = NUM_SINGLE_BLOCKS +) -> dict[str, torch.Tensor]: + diffusers_to_bfl_map = make_diffusers_to_bfl_map(num_double_blocks, num_single_blocks) # iterate over three safetensors files to reduce memory usage flux_sd = {} From 74228c9953b4ba0f8b0d68e8f6c8a8a6a469c2f5 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sun, 13 Oct 2024 16:27:22 +0900 Subject: [PATCH 178/348] update cache_latents/text_encoder_outputs --- library/strategy_base.py | 2 +- tools/cache_latents.py | 147 +++++++++++------------ tools/cache_text_encoder_outputs.py | 178 ++++++++++++++++------------ 3 files changed, 166 insertions(+), 161 deletions(-) diff --git a/library/strategy_base.py b/library/strategy_base.py index 2bff4178a..363996cec 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -325,7 +325,7 @@ class TextEncoderOutputsCachingStrategy: def __init__( self, cache_to_disk: bool, - batch_size: int, + batch_size: Optional[int], skip_disk_cache_validity_check: bool, is_partial: bool = False, is_weighted: bool = False, diff --git a/tools/cache_latents.py b/tools/cache_latents.py index 2f0098b42..d8154ec31 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -9,7 +9,7 @@ import torch from tqdm import tqdm -from library import config_util +from library import config_util, flux_utils, strategy_base, strategy_flux, strategy_sd, strategy_sdxl from library import train_util from library import sdxl_train_util from library.config_util import ( @@ -17,42 +17,73 @@ BlueprintGenerator, ) from library.utils import setup_logging, add_logging_arguments + setup_logging() import logging logger = logging.getLogger(__name__) +def set_tokenize_strategy(is_sd: bool, is_sdxl: bool, is_flux: bool, args: argparse.Namespace) -> None: + if is_flux: + _, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path) + else: + is_schnell = False + + if is_sd or is_sdxl: + tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + elif is_sdxl: + tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + else: + if args.t5xxl_max_token_length is None: + if is_schnell: + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}") + tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + + def cache_to_disk(args: argparse.Namespace) -> None: setup_logging(args, reset=True) train_util.prepare_dataset_args(args, True) - # check cache latents arg - assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります" + # assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります" + args.cache_latents = True + args.cache_latents_to_disk = True use_dreambooth_method = args.in_json is None if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - # tokenizerを準備する:datasetを動かすために必要 - if args.sdxl: - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) - tokenizers = [tokenizer1, tokenizer2] + is_sd = not args.sdxl and not args.flux + is_sdxl = args.sdxl + is_flux = args.flux + + set_tokenize_strategy(is_sd, is_sdxl, is_flux, args) + + if is_sd or is_sdxl: + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(is_sd, True, args.vae_batch_size, args.skip_cache_check) else: - tokenizer = train_util.load_tokenizer(args) - tokenizers = [tokenizer] + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(True, args.vae_batch_size, args.skip_cache_check) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する + use_user_config = args.dataset_config is not None if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) - if args.dataset_config is not None: - logger.info(f"Load dataset config from {args.dataset_config}") + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if use_user_config: + logger.info(f"Loading dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "in_json"] + ignored = ["train_data_dir", "reg_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): logger.warning( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) @@ -83,17 +114,11 @@ def cache_to_disk(args: argparse.Namespace) -> None: ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers) - - # datasetのcache_latentsを呼ばなければ、生の画像が返る - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + # use arbitrary dataset class + train_dataset_group = train_util.load_arbitrary_dataset(args) # acceleratorを準備する logger.info("prepare accelerator") @@ -106,72 +131,27 @@ def cache_to_disk(args: argparse.Namespace) -> None: # モデルを読み込む logger.info("load model") - if args.sdxl: + if is_sd: + _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) + elif is_sdxl: (_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) else: - _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) + vae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + + if is_sd or is_sdxl: + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + vae.set_use_memory_efficient_attention_xformers(args.xformers) - if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える - vae.set_use_memory_efficient_attention_xformers(args.xformers) vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - # dataloaderを準備する - train_dataset_group.set_caching_mode("latents") - - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 - n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers - - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - - # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず - train_dataloader = accelerator.prepare(train_dataloader) - - # データ取得のためのループ - for batch in tqdm(train_dataloader): - b_size = len(batch["images"]) - vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size - flip_aug = batch["flip_aug"] - alpha_mask = batch["alpha_mask"] - random_crop = batch["random_crop"] - bucket_reso = batch["bucket_reso"] - - # バッチを分割して処理する - for i in range(0, b_size, vae_batch_size): - images = batch["images"][i : i + vae_batch_size] - absolute_paths = batch["absolute_paths"][i : i + vae_batch_size] - resized_sizes = batch["resized_sizes"][i : i + vae_batch_size] - - image_infos = [] - for i, (image, absolute_path, resized_size) in enumerate(zip(images, absolute_paths, resized_sizes)): - image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path) - image_info.image = image - image_info.bucket_reso = bucket_reso - image_info.resized_size = resized_size - image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz" - - if args.skip_existing: - if train_util.is_disk_cached_latents_is_expected( - image_info.bucket_reso, image_info.latents_npz, flip_aug, alpha_mask - ): - logger.warning(f"Skipping {image_info.latents_npz} because it already exists.") - continue - - image_infos.append(image_info) - - if len(image_infos) > 0: - train_util.cache_batch_latents(vae, True, image_infos, flip_aug, alpha_mask, random_crop) + # cache latents with dataset + # TODO use DataLoader to speed up + train_dataset_group.new_cache_latents(vae, accelerator) accelerator.wait_for_everyone() - accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") + accelerator.print(f"Finished caching latents to disk.") def setup_parser() -> argparse.ArgumentParser: @@ -182,7 +162,11 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_training_arguments(parser, True) train_util.add_dataset_arguments(parser, True, True, True) config_util.add_config_arguments(parser) + parser.add_argument( + "--ae", type=str, default=None, help="Autoencoder model of FLUX to use / 使用するFLUXのオートエンコーダモデル" + ) parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") + parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する") parser.add_argument( "--no_half_vae", action="store_true", @@ -191,7 +175,8 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--skip_existing", action="store_true", - help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", + help="[Deprecated] This option does not work. Existing .npz files are always checked. Use `--skip_cache_check` to skip the check." + " / [非推奨] このオプションは機能しません。既存の .npz は常に検証されます。`--skip_cache_check` で検証をスキップできます。", ) return parser diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py index a75d9da74..d294d46c4 100644 --- a/tools/cache_text_encoder_outputs.py +++ b/tools/cache_text_encoder_outputs.py @@ -9,55 +9,68 @@ import torch from tqdm import tqdm -from library import config_util +from library import ( + config_util, + flux_train_utils, + flux_utils, + sdxl_model_util, + strategy_base, + strategy_flux, + strategy_sd, + strategy_sdxl, +) from library import train_util from library import sdxl_train_util +from library import utils from library.config_util import ( ConfigSanitizer, BlueprintGenerator, ) from library.utils import setup_logging, add_logging_arguments +from tools import cache_latents + setup_logging() import logging + logger = logging.getLogger(__name__) + def cache_to_disk(args: argparse.Namespace) -> None: setup_logging(args, reset=True) train_util.prepare_dataset_args(args, True) - # check cache arg - assert ( - args.cache_text_encoder_outputs_to_disk - ), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります" - - # できるだけ準備はしておくが今のところSDXLのみしか動かない - assert ( - args.sdxl - ), "cache_text_encoder_outputs_to_disk is only available for SDXL / cache_text_encoder_outputs_to_diskはSDXLのみ利用可能です" + args.cache_text_encoder_outputs = True + args.cache_text_encoder_outputs_to_disk = True use_dreambooth_method = args.in_json is None if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - # tokenizerを準備する:datasetを動かすために必要 - if args.sdxl: - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) - tokenizers = [tokenizer1, tokenizer2] - else: - tokenizer = train_util.load_tokenizer(args) - tokenizers = [tokenizer] + is_sd = not args.sdxl and not args.flux + is_sdxl = args.sdxl + is_flux = args.flux + + assert ( + is_sdxl or is_flux + ), "Cache text encoder outputs to disk is only supported for SDXL and FLUX models / テキストエンコーダ出力のディスクキャッシュはSDXLまたはFLUXでのみ有効です" + assert ( + is_sdxl or args.weighted_captions is None + ), "Weighted captions are only supported for SDXL models / 重み付きキャプションはSDXLモデルでのみ有効です" + + cache_latents.set_tokenize_strategy(is_sd, is_sdxl, is_flux, args) # データセットを準備する + use_user_config = args.dataset_config is not None if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) - if args.dataset_config is not None: - logger.info(f"Load dataset config from {args.dataset_config}") + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if use_user_config: + logger.info(f"Loading dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "in_json"] + ignored = ["train_data_dir", "reg_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): logger.warning( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) @@ -88,15 +101,11 @@ def cache_to_disk(args: argparse.Namespace) -> None: ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers) - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + # use arbitrary dataset class + train_dataset_group = train_util.load_arbitrary_dataset(args) # acceleratorを準備する logger.info("prepare accelerator") @@ -105,66 +114,68 @@ def cache_to_disk(args: argparse.Namespace) -> None: # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, _ = train_util.prepare_dtype(args) + t5xxl_dtype = utils.str_to_dtype(args.t5xxl_dtype, weight_dtype) # モデルを読み込む logger.info("load model") - if args.sdxl: - (_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) + if is_sdxl: + _, text_encoder1, text_encoder2, _, _, _, _ = sdxl_train_util.load_target_model( + args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype + ) + text_encoder1.to(accelerator.device, weight_dtype) + text_encoder2.to(accelerator.device, weight_dtype) text_encoders = [text_encoder1, text_encoder2] else: - text_encoder1, _, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) - text_encoders = [text_encoder1] + clip_l = flux_utils.load_clip_l( + args.clip_l, weight_dtype, accelerator.device, disable_mmap=args.disable_mmap_load_safetensors + ) + + t5xxl = flux_utils.load_t5xxl(args.t5xxl, None, accelerator.device, disable_mmap=args.disable_mmap_load_safetensors) + + if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") + elif t5xxl.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 T5XXL model") + + if t5xxl_dtype != t5xxl_dtype: + if t5xxl.dtype == torch.float8_e4m3fn and t5xxl_dtype.itemsize() >= 2: + logger.warning( + "The loaded model is fp8, but the specified T5XXL dtype is larger than fp8. This may cause a performance drop." + " / ロードされたモデルはfp8ですが、指定されたT5XXLのdtypeがfp8より高精度です。精度低下が発生する可能性があります。" + ) + logger.info(f"Casting T5XXL model to {t5xxl_dtype}") + t5xxl.to(t5xxl_dtype) + + text_encoders = [clip_l, t5xxl] for text_encoder in text_encoders: - text_encoder.to(accelerator.device, dtype=weight_dtype) text_encoder.requires_grad_(False) text_encoder.eval() - # dataloaderを準備する - train_dataset_group.set_caching_mode("text") - - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 - n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers - - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) + # build text encoder outputs caching strategy + if is_sdxl: + text_encoder_outputs_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions + ) + else: + text_encoder_outputs_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + is_partial=False, + apply_t5_attn_mask=args.apply_t5_attn_mask, + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy) + + # build text encoding strategy + if is_sdxl: + text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy() + else: + text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) - # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず - train_dataloader = accelerator.prepare(train_dataloader) - - # データ取得のためのループ - for batch in tqdm(train_dataloader): - absolute_paths = batch["absolute_paths"] - input_ids1_list = batch["input_ids1_list"] - input_ids2_list = batch["input_ids2_list"] - - image_infos = [] - for absolute_path, input_ids1, input_ids2 in zip(absolute_paths, input_ids1_list, input_ids2_list): - image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path) - image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + train_util.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX - image_info - - if args.skip_existing: - if os.path.exists(image_info.text_encoder_outputs_npz): - logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.") - continue - - image_info.input_ids1 = input_ids1 - image_info.input_ids2 = input_ids2 - image_infos.append(image_info) - - if len(image_infos) > 0: - b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos]) - b_input_ids2 = torch.stack([image_info.input_ids2 for image_info in image_infos]) - train_util.cache_batch_text_encoder_outputs( - image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, b_input_ids2, weight_dtype - ) + # cache text encoder outputs + train_dataset_group.new_cache_text_encoder_outputs(text_encoders, accelerator) accelerator.wait_for_everyone() accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") @@ -179,11 +190,20 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_dataset_arguments(parser, True, True, True) config_util.add_config_arguments(parser) sdxl_train_util.add_sdxl_training_arguments(parser) + flux_train_utils.add_flux_train_arguments(parser) parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") + parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する") + parser.add_argument( + "--t5xxl_dtype", + type=str, + default=None, + help="T5XXL model dtype, default: None (use mixed precision dtype) / T5XXLモデルのdtype, デフォルト: None (mixed precisionのdtypeを使用)", + ) parser.add_argument( "--skip_existing", action="store_true", - help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", + help="[Deprecated] This option does not work. Existing .npz files are always checked. Use `--skip_cache_check` to skip the check." + " / [非推奨] このオプションは機能しません。既存の .npz は常に検証されます。`--skip_cache_check` で検証をスキップできます。", ) return parser From 2244cf5b835cc35179f29b1babb4a2d19f54bfae Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sun, 13 Oct 2024 18:22:19 +0900 Subject: [PATCH 179/348] load images in parallel when caching latents --- library/train_util.py | 93 ++++++++++++++++++++++++------------------- 1 file changed, 53 insertions(+), 40 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 4e6b3408d..1db470d63 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3,6 +3,7 @@ import argparse import ast import asyncio +from concurrent.futures import Future, ThreadPoolExecutor import datetime import importlib import json @@ -1058,7 +1059,6 @@ def __eq__(self, other): and self.random_crop == other.random_crop ) - batches: List[Tuple[Condition, List[ImageInfo]]] = [] batch: List[ImageInfo] = [] current_condition = None @@ -1066,57 +1066,70 @@ def __eq__(self, other): num_processes = accelerator.num_processes process_index = accelerator.process_index - logger.info("checking cache validity...") - for i, info in enumerate(tqdm(image_infos)): - subset = self.image_to_subset[info.image_key] + # define a function to submit a batch to cache + def submit_batch(batch, cond): + for info in batch: + if info.image is not None and isinstance(info.image, Future): + info.image = info.image.result() # future to image + caching_strategy.cache_batch_latents(model, batch, cond.flip_aug, cond.alpha_mask, cond.random_crop) - if info.latents_npz is not None: # fine tuning dataset - continue + # define ThreadPoolExecutor to load images in parallel + max_workers = min(os.cpu_count(), len(image_infos)) + max_workers = max(1, max_workers // num_processes) # consider multi-gpu + max_workers = min(max_workers, caching_strategy.batch_size) # max_workers should be less than batch_size + executor = ThreadPoolExecutor(max_workers) - # check disk cache exists and size of latents - if caching_strategy.cache_to_disk: - # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix - info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size) + try: + # iterate images + logger.info("caching latents...") + for i, info in enumerate(tqdm(image_infos)): + subset = self.image_to_subset[info.image_key] - # if the modulo of num_processes is not equal to process_index, skip caching - # this makes each process cache different latents - if i % num_processes != process_index: + if info.latents_npz is not None: # fine tuning dataset continue - # print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}") + # check disk cache exists and size of latents + if caching_strategy.cache_to_disk: + # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix + info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size) - cache_available = caching_strategy.is_disk_cached_latents_expected( - info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask - ) - if cache_available: # do not add to batch - continue + # if the modulo of num_processes is not equal to process_index, skip caching + # this makes each process cache different latents + if i % num_processes != process_index: + continue - # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty - condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) - if len(batch) > 0 and current_condition != condition: - batches.append((current_condition, batch)) - batch = [] + # print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}") - batch.append(info) - current_condition = condition + cache_available = caching_strategy.is_disk_cached_latents_expected( + info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask + ) + if cache_available: # do not add to batch + continue - # if number of data in batch is enough, flush the batch - if len(batch) >= caching_strategy.batch_size: - batches.append((current_condition, batch)) - batch = [] - current_condition = None + # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty + condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) + if len(batch) > 0 and current_condition != condition: + submit_batch(batch, current_condition) + batch = [] - if len(batch) > 0: - batches.append((current_condition, batch)) + if info.image is None: + # load image in parallel + info.image = executor.submit(load_image, info.absolute_path, condition.alpha_mask) - if len(batches) == 0: - logger.info("no latents to cache") - return + batch.append(info) + current_condition = condition - # iterate batches: batch doesn't have image here. image will be loaded in cache_batch_latents and discarded - logger.info("caching latents...") - for condition, batch in tqdm(batches, smoothing=1, total=len(batches)): - caching_strategy.cache_batch_latents(model, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop) + # if number of data in batch is enough, flush the batch + if len(batch) >= caching_strategy.batch_size: + submit_batch(batch, current_condition) + batch = [] + current_condition = None + + if len(batch) > 0: + submit_batch(batch, current_condition) + + finally: + executor.shutdown() def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと From bfc3a65acda7f90abef9c16db279d2952f73fb77 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sun, 13 Oct 2024 19:08:16 +0900 Subject: [PATCH 180/348] fix to work cache latents/text encoder outputs --- library/train_util.py | 11 +++++++---- tools/cache_latents.py | 11 ++++++----- tools/cache_text_encoder_outputs.py | 18 +++++++++++++----- 3 files changed, 26 insertions(+), 14 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 1db470d63..926609267 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4064,15 +4064,18 @@ def verify_command_line_training_args(args: argparse.Namespace): ) +def enable_high_vram(args: argparse.Namespace): + if args.highvram: + logger.info("highvram is enabled / highvramが有効です") + global HIGH_VRAM + HIGH_VRAM = True + def verify_training_args(args: argparse.Namespace): r""" Verify training arguments. Also reflect highvram option to global variable 学習用引数を検証する。あわせて highvram オプションの指定をグローバル変数に反映する """ - if args.highvram: - print("highvram is enabled / highvramが有効です") - global HIGH_VRAM - HIGH_VRAM = True + enable_high_vram(args) if args.v_parameterization and not args.v2: logger.warning( diff --git a/tools/cache_latents.py b/tools/cache_latents.py index d8154ec31..e2faa58a7 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -9,7 +9,7 @@ import torch from tqdm import tqdm -from library import config_util, flux_utils, strategy_base, strategy_flux, strategy_sd, strategy_sdxl +from library import config_util, flux_train_utils, flux_utils, strategy_base, strategy_flux, strategy_sd, strategy_sdxl from library import train_util from library import sdxl_train_util from library.config_util import ( @@ -30,7 +30,7 @@ def set_tokenize_strategy(is_sd: bool, is_sdxl: bool, is_flux: bool, args: argpa else: is_schnell = False - if is_sd or is_sdxl: + if is_sd: tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) elif is_sdxl: tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) @@ -51,6 +51,7 @@ def set_tokenize_strategy(is_sd: bool, is_sdxl: bool, is_flux: bool, args: argpa def cache_to_disk(args: argparse.Namespace) -> None: setup_logging(args, reset=True) train_util.prepare_dataset_args(args, True) + train_util.enable_high_vram(args) # assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります" args.cache_latents = True @@ -161,10 +162,10 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_training_arguments(parser, True) train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_masked_loss_arguments(parser) config_util.add_config_arguments(parser) - parser.add_argument( - "--ae", type=str, default=None, help="Autoencoder model of FLUX to use / 使用するFLUXのオートエンコーダモデル" - ) + flux_train_utils.add_flux_train_arguments(parser) + parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する") parser.add_argument( diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py index d294d46c4..7be9ad781 100644 --- a/tools/cache_text_encoder_outputs.py +++ b/tools/cache_text_encoder_outputs.py @@ -27,7 +27,7 @@ BlueprintGenerator, ) from library.utils import setup_logging, add_logging_arguments -from tools import cache_latents +from cache_latents import set_tokenize_strategy setup_logging() import logging @@ -38,6 +38,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: setup_logging(args, reset=True) train_util.prepare_dataset_args(args, True) + train_util.enable_high_vram(args) args.cache_text_encoder_outputs = True args.cache_text_encoder_outputs_to_disk = True @@ -57,8 +58,8 @@ def cache_to_disk(args: argparse.Namespace) -> None: assert ( is_sdxl or args.weighted_captions is None ), "Weighted captions are only supported for SDXL models / 重み付きキャプションはSDXLモデルでのみ有効です" - - cache_latents.set_tokenize_strategy(is_sd, is_sdxl, is_flux, args) + + set_tokenize_strategy(is_sd, is_sdxl, is_flux, args) # データセットを準備する use_user_config = args.dataset_config is not None @@ -178,7 +179,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: train_dataset_group.new_cache_text_encoder_outputs(text_encoders, accelerator) accelerator.wait_for_everyone() - accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") + accelerator.print(f"Finished caching text encoder outputs to disk.") def setup_parser() -> argparse.ArgumentParser: @@ -188,9 +189,10 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_training_arguments(parser, True) train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_masked_loss_arguments(parser) config_util.add_config_arguments(parser) - sdxl_train_util.add_sdxl_training_arguments(parser) flux_train_utils.add_flux_train_arguments(parser) + parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する") parser.add_argument( @@ -205,6 +207,12 @@ def setup_parser() -> argparse.ArgumentParser: help="[Deprecated] This option does not work. Existing .npz files are always checked. Use `--skip_cache_check` to skip the check." " / [非推奨] このオプションは機能しません。既存の .npz は常に検証されます。`--skip_cache_check` で検証をスキップできます。", ) + parser.add_argument( + "--weighted_captions", + action="store_true", + default=False, + help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意", + ) return parser From 2d5f7fa709c31d07a1bb44b5be391c29b77d3cfc Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sun, 13 Oct 2024 19:23:21 +0900 Subject: [PATCH 181/348] update README --- README.md | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 544c665de..7fae50d1a 100644 --- a/README.md +++ b/README.md @@ -11,10 +11,15 @@ The command to install PyTorch is as follows: ### Recent Updates -Oct 12, 2024 (update 1): +Oct 13, 2024: + +- Fixed an issue where it took a long time to load the image size when initializing the dataset, especially when the number of images in the dataset was large. - During multi-GPU training, caching of latents and Text Encoder outputs is now done in multi-GPU. -- `--text_encoder_batch_size` option is enabled for FLUX.1 LoRA training and fine tuning. This option specifies the batch size for caching Text Encoder outputs (not for training). The default is same as the dataset batch size. If you have enough VRAM, you can increase the batch size to speed up the caching. + - Please make sure that `--highvram` and `--vae_batch_size` are specified correctly. If you have enough VRAM, you can increase the batch size to speed up the caching. + - `--text_encoder_batch_size` option is enabled for FLUX.1 LoRA training and fine tuning. This option specifies the batch size for caching Text Encoder outputs (not for training). The default is same as the dataset batch size. If you have enough VRAM, you can increase the batch size to speed up the caching. + - Multi-threading is also implemented for caching of latents. This may speed up the caching process about 5% (depends on the environment). + - `tools/cache_latents.py` and `tools/cache_text_encoder_outputs.py` also have been updated to support multi-GPU caching. - `--skip_cache_check` option is added to each training script. - When specified, the consistency check of the cache file `*.npz` contents (e.g., image size and flip for latents, mask for Text Encoder outputs) is skipped. - Specify this option if you have a large number of cache files and the consistency check takes time. From 2500f5a79806fdbe74c43db24a95ee19329a8fcc Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 15 Oct 2024 07:16:34 +0900 Subject: [PATCH 182/348] fix latents caching not working closes #1696 --- fine_tune.py | 2 +- flux_train.py | 2 +- sd3_train.py | 2 +- sdxl_train.py | 2 +- sdxl_train_control_net.py | 2 +- train_db.py | 2 +- train_textual_inversion.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index cdc005d9a..0b7cc5100 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -177,7 +177,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + train_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device) diff --git a/flux_train.py b/flux_train.py index 46a8babdb..91ae3af57 100644 --- a/flux_train.py +++ b/flux_train.py @@ -190,7 +190,7 @@ def train(args): ae.requires_grad_(False) ae.eval() - train_dataset_group.new_cache_latents(ae, accelerator.is_main_process) + train_dataset_group.new_cache_latents(ae, accelerator) ae.to("cpu") # if no sampling, vae can be deleted clean_memory_on_device(accelerator.device) diff --git a/sd3_train.py b/sd3_train.py index 7290956ad..ef18c32c4 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -243,7 +243,7 @@ def train(args): vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + train_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") # if no sampling, vae can be deleted clean_memory_on_device(accelerator.device) diff --git a/sdxl_train.py b/sdxl_train.py index 9b2d19165..79a2fbb6e 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -272,7 +272,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + train_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device) diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index 74b3a64a4..24080afbd 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -209,7 +209,7 @@ def unwrap_model(model): vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + train_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device) diff --git a/train_db.py b/train_db.py index 683b42332..4a58e27b0 100644 --- a/train_db.py +++ b/train_db.py @@ -156,7 +156,7 @@ def train(args): vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + train_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 4d8a3abbf..77b5d717a 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -378,7 +378,7 @@ def train(self, args): vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + train_dataset_group.new_cache_latents(vae, accelerator) clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() From 3cc5b8db99c66b9e205c4fd4a5f969090c51ef58 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 18 Oct 2024 20:57:13 +0900 Subject: [PATCH 183/348] Diff Output Preserv loss for SDXL --- library/config_util.py | 17 +++++++---------- library/train_util.py | 17 ++++++++++++++++- sdxl_train_network.py | 20 +++++++++++++++++++- train_network.py | 35 +++++++++++++++++++++++++---------- 4 files changed, 67 insertions(+), 22 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index f8cdfe60a..fc1fbf46d 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -10,13 +10,7 @@ from pathlib import Path # from toolz import curry -from typing import ( - List, - Optional, - Sequence, - Tuple, - Union, -) +from typing import Dict, List, Optional, Sequence, Tuple, Union import toml import voluptuous @@ -78,6 +72,7 @@ class BaseSubsetParams: caption_tag_dropout_rate: float = 0.0 token_warmup_min: int = 1 token_warmup_step: float = 0 + custom_attributes: Optional[Dict[str, Any]] = None @dataclass @@ -197,6 +192,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] "token_warmup_step": Any(float, int), "caption_prefix": str, "caption_suffix": str, + "custom_attributes": dict, } # DO means DropOut DO_SUBSET_ASCENDABLE_SCHEMA = { @@ -538,9 +534,10 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu flip_aug: {subset.flip_aug} face_crop_aug_range: {subset.face_crop_aug_range} random_crop: {subset.random_crop} - token_warmup_min: {subset.token_warmup_min}, - token_warmup_step: {subset.token_warmup_step}, - alpha_mask: {subset.alpha_mask}, + token_warmup_min: {subset.token_warmup_min} + token_warmup_step: {subset.token_warmup_step} + alpha_mask: {subset.alpha_mask} + custom_attributes: {subset.custom_attributes} """ ), " ", diff --git a/library/train_util.py b/library/train_util.py index 4a446e81c..7d3fce5b2 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -396,6 +396,7 @@ def __init__( caption_suffix: Optional[str], token_warmup_min: int, token_warmup_step: Union[float, int], + custom_attributes: Optional[Dict[str, Any]] = None, ) -> None: self.image_dir = image_dir self.alpha_mask = alpha_mask if alpha_mask is not None else False @@ -419,6 +420,8 @@ def __init__( self.token_warmup_min = token_warmup_min # step=0におけるタグの数 self.token_warmup_step = token_warmup_step # N(N<1ならN*max_train_steps)ステップ目でタグの数が最大になる + self.custom_attributes = custom_attributes if custom_attributes is not None else {} + self.img_count = 0 @@ -449,6 +452,7 @@ def __init__( caption_suffix, token_warmup_min, token_warmup_step, + custom_attributes: Optional[Dict[str, Any]] = None, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -473,6 +477,7 @@ def __init__( caption_suffix, token_warmup_min, token_warmup_step, + custom_attributes=custom_attributes, ) self.is_reg = is_reg @@ -512,6 +517,7 @@ def __init__( caption_suffix, token_warmup_min, token_warmup_step, + custom_attributes: Optional[Dict[str, Any]] = None, ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" @@ -536,6 +542,7 @@ def __init__( caption_suffix, token_warmup_min, token_warmup_step, + custom_attributes=custom_attributes, ) self.metadata_file = metadata_file @@ -1474,11 +1481,14 @@ def __getitem__(self, index): target_sizes_hw = [] flippeds = [] # 変数名が微妙 text_encoder_outputs_list = [] + custom_attributes = [] for image_key in bucket[image_index : image_index + bucket_batch_size]: image_info = self.image_data[image_key] subset = self.image_to_subset[image_key] + custom_attributes.append(subset.custom_attributes) + # in case of fine tuning, is_reg is always False loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) @@ -1646,7 +1656,9 @@ def none_or_stack_elements(tensors_list, converter): return None return [torch.stack([converter(x[i]) for x in tensors_list]) for i in range(len(tensors_list[0]))] + # set example example = {} + example["custom_attributes"] = custom_attributes # may be list of empty dict example["loss_weights"] = torch.FloatTensor(loss_weights) example["text_encoder_outputs_list"] = none_or_stack_elements(text_encoder_outputs_list, torch.FloatTensor) example["input_ids_list"] = none_or_stack_elements(input_ids_list, lambda x: x) @@ -2630,7 +2642,9 @@ def debug_dataset(train_dataset, show_input_ids=False): f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}' ) if "network_multipliers" in example: - print(f"network multiplier: {example['network_multipliers'][j]}") + logger.info(f"network multiplier: {example['network_multipliers'][j]}") + if "custom_attributes" in example: + logger.info(f"custom attributes: {example['custom_attributes'][j]}") # if show_input_ids: # logger.info(f"input ids: {iid}") @@ -4091,6 +4105,7 @@ def enable_high_vram(args: argparse.Namespace): global HIGH_VRAM HIGH_VRAM = True + def verify_training_args(args: argparse.Namespace): r""" Verify training arguments. Also reflect highvram option to global variable diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 4a16a4891..d45df6e05 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -1,4 +1,5 @@ import argparse +from typing import List, Optional import torch from accelerate import Accelerator @@ -172,7 +173,18 @@ def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, wei return encoder_hidden_states1, encoder_hidden_states2, pool2 - def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + def call_unet( + self, + args, + accelerator, + unet, + noisy_latents, + timesteps, + text_conds, + batch, + weight_dtype, + indices: Optional[List[int]] = None, + ): noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype # get size embeddings @@ -186,6 +198,12 @@ def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_cond vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + if indices is not None and len(indices) > 0: + noisy_latents = noisy_latents[indices] + timesteps = timesteps[indices] + text_embedding = text_embedding[indices] + vector_embedding = vector_embedding[indices] + noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) return noise_pred diff --git a/train_network.py b/train_network.py index d5330aef4..ef766737d 100644 --- a/train_network.py +++ b/train_network.py @@ -143,7 +143,7 @@ def cache_text_encoder_outputs_if_needed(self, args, accelerator, unet, vae, tex for t_enc in text_encoders: t_enc.to(accelerator.device, dtype=weight_dtype) - def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype, **kwargs): noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample return noise_pred @@ -218,6 +218,30 @@ def get_noise_pred_and_target( else: target = noise + # differential output preservation + if "custom_attributes" in batch: + diff_output_pr_indices = [] + for i, custom_attributes in enumerate(batch["custom_attributes"]): + if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: + diff_output_pr_indices.append(i) + + if len(diff_output_pr_indices) > 0: + network.set_multiplier(0.0) + with torch.no_grad(), accelerator.autocast(): + noise_pred_prior = self.call_unet( + args, + accelerator, + unet, + noisy_latents, + timesteps, + text_encoder_conds, + batch, + weight_dtype, + indices=diff_output_pr_indices, + ) + network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step + target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) + return noise_pred, target, timesteps, huber_c, None def post_process_loss(self, loss, args, timesteps, noise_scheduler): @@ -1123,15 +1147,6 @@ def remove_model(old_ckpt_name): with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: - # # SD only - # encoded_text_encoder_conds = get_weighted_text_embeddings( - # tokenizers[0], - # text_encoder, - # batch["captions"], - # accelerator.device, - # args.max_token_length // 75 if args.max_token_length else 1, - # clip_skip=args.clip_skip, - # ) input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights( tokenize_strategy, From d8d7142665a8f6b2d43827c9b3a6a2de009c09cb Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 18 Oct 2024 23:16:30 +0900 Subject: [PATCH 184/348] fix to work caching latents #1696 --- sdxl_train_control_net_lllite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 14ff7c240..913b1d435 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -181,7 +181,7 @@ def train(args): vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + train_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device) From ef70aa7b42b5c923cc1a8594b2f30487a2b4f700 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Fri, 18 Oct 2024 23:39:48 +0900 Subject: [PATCH 185/348] add FLUX.1 support --- README.md | 19 +++++++ flux_train_network.py | 123 ++++++++++++++++++++++++++++-------------- 2 files changed, 103 insertions(+), 39 deletions(-) diff --git a/README.md b/README.md index 7fae50d1a..59f70ebcd 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,25 @@ The command to install PyTorch is as follows: ### Recent Updates +Oct 19, 2024: + +- Added an implementation of Differential Output Preservation (temporary name) for SDXL/FLUX.1 LoRA training. + - A method to make the output of LoRA closer to the output when LoRA is not applied, with captions that do not contain trigger words. + - Define a Dataset subset for the regularization image (`is_reg = true`) with `.toml`. Add `custom_attributes.diff_output_preservation = true`. + - See [dataset configuration](docs/config_README-en.md) for the regularization dataset. + - Specify "number of training images x number of epochs >= number of regularization images x number of epochs". + - Specify a large value for `--prior_loss_weight` option (not dataset config). We recommend 10-1000. + - Set the loss in the training without using the regularization image to be close to the loss in the training using DOP. +``` +[[datasets.subsets]] +image_dir = "path/to/image/dir" +num_repeats = 1 +is_reg = true +custom_attributes.diff_output_preservation = true # Add this +``` + + + Oct 13, 2024: - Fixed an issue where it took a long time to load the image size when initializing the dataset, especially when the number of images in the dataset was large. diff --git a/flux_train_network.py b/flux_train_network.py index aa92fe3ae..8431a6dc9 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -373,33 +373,13 @@ def get_noise_pred_and_target( if not args.apply_t5_attn_mask: t5_attn_mask = None - if not args.split_mode: - # normal forward - with accelerator.autocast(): - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) - model_pred = unet( - img=packed_noisy_model_input, - img_ids=img_ids, - txt=t5_out, - txt_ids=txt_ids, - y=l_pooled, - timesteps=timesteps / 1000, - guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, - ) - else: - # split forward to reduce memory usage - assert network.train_blocks == "single", "train_blocks must be single for split mode" - with accelerator.autocast(): - # move flux lower to cpu, and then move flux upper to gpu - unet.to("cpu") - clean_memory_on_device(accelerator.device) - self.flux_upper.to(accelerator.device) - - # upper model does not require grad - with torch.no_grad(): - intermediate_img, intermediate_txt, vec, pe = self.flux_upper( - img=packed_noisy_model_input, + def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): + if not args.split_mode: + # normal forward + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = unet( + img=img, img_ids=img_ids, txt=t5_out, txt_ids=txt_ids, @@ -408,18 +388,52 @@ def get_noise_pred_and_target( guidance=guidance_vec, txt_attention_mask=t5_attn_mask, ) - - # move flux upper back to cpu, and then move flux lower to gpu - self.flux_upper.to("cpu") - clean_memory_on_device(accelerator.device) - unet.to(accelerator.device) - - # lower model requires grad - intermediate_img.requires_grad_(True) - intermediate_txt.requires_grad_(True) - vec.requires_grad_(True) - pe.requires_grad_(True) - model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) + else: + # split forward to reduce memory usage + assert network.train_blocks == "single", "train_blocks must be single for split mode" + with accelerator.autocast(): + # move flux lower to cpu, and then move flux upper to gpu + unet.to("cpu") + clean_memory_on_device(accelerator.device) + self.flux_upper.to(accelerator.device) + + # upper model does not require grad + with torch.no_grad(): + intermediate_img, intermediate_txt, vec, pe = self.flux_upper( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + + # move flux upper back to cpu, and then move flux lower to gpu + self.flux_upper.to("cpu") + clean_memory_on_device(accelerator.device) + unet.to(accelerator.device) + + # lower model requires grad + intermediate_img.requires_grad_(True) + intermediate_txt.requires_grad_(True) + vec.requires_grad_(True) + pe.requires_grad_(True) + model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) + + return model_pred + + model_pred = call_dit( + img=packed_noisy_model_input, + img_ids=img_ids, + t5_out=t5_out, + txt_ids=txt_ids, + l_pooled=l_pooled, + timesteps=timesteps, + guidance_vec=guidance_vec, + t5_attn_mask=t5_attn_mask, + ) # unpack latents model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) @@ -430,6 +444,37 @@ def get_noise_pred_and_target( # flow matching loss: this is different from SD3 target = noise - latents + # differential output preservation + if "custom_attributes" in batch: + diff_output_pr_indices = [] + for i, custom_attributes in enumerate(batch["custom_attributes"]): + if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: + diff_output_pr_indices.append(i) + + if len(diff_output_pr_indices) > 0: + network.set_multiplier(0.0) + with torch.no_grad(), accelerator.autocast(): + model_pred_prior = call_dit( + img=packed_noisy_model_input[diff_output_pr_indices], + img_ids=img_ids[diff_output_pr_indices], + t5_out=t5_out[diff_output_pr_indices], + txt_ids=txt_ids[diff_output_pr_indices], + l_pooled=l_pooled[diff_output_pr_indices], + timesteps=timesteps[diff_output_pr_indices], + guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None, + t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None, + ) + network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step + + model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width) + model_pred_prior, _ = flux_train_utils.apply_model_prediction_type( + args, + model_pred_prior, + noisy_model_input[diff_output_pr_indices], + sigmas[diff_output_pr_indices] if sigmas is not None else None, + ) + target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) + return model_pred, target, timesteps, None, weighting def post_process_loss(self, loss, args, timesteps, noise_scheduler): From 2c45d979e696fd4412ae1336feaee3bc9b967af4 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sat, 19 Oct 2024 19:21:12 +0900 Subject: [PATCH 186/348] update README, remove unnecessary autocast --- README.md | 10 ++++------ flux_train_network.py | 2 +- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 59f70ebcd..32ee38573 100644 --- a/README.md +++ b/README.md @@ -13,13 +13,13 @@ The command to install PyTorch is as follows: Oct 19, 2024: -- Added an implementation of Differential Output Preservation (temporary name) for SDXL/FLUX.1 LoRA training. +- Added an implementation of Differential Output Preservation (temporary name) for SDXL/FLUX.1 LoRA training. SD1/2 is not tested yet. This is an experimental feature. - A method to make the output of LoRA closer to the output when LoRA is not applied, with captions that do not contain trigger words. - Define a Dataset subset for the regularization image (`is_reg = true`) with `.toml`. Add `custom_attributes.diff_output_preservation = true`. - See [dataset configuration](docs/config_README-en.md) for the regularization dataset. - - Specify "number of training images x number of epochs >= number of regularization images x number of epochs". - - Specify a large value for `--prior_loss_weight` option (not dataset config). We recommend 10-1000. - - Set the loss in the training without using the regularization image to be close to the loss in the training using DOP. + - Specify "number of training images x number of repeats >= number of regularization images x number of repeats". + - Specify a large value for `--prior_loss_weight` option (not dataset config). The appropriate value is unknown, but try around 10-100. Note that the default is 1.0. + - You may want to start with 2/3 to 3/4 of the loss value when DOP is not applied. If it is 1/2, DOP may not be working. ``` [[datasets.subsets]] image_dir = "path/to/image/dir" @@ -28,8 +28,6 @@ is_reg = true custom_attributes.diff_output_preservation = true # Add this ``` - - Oct 13, 2024: - Fixed an issue where it took a long time to load the image size when initializing the dataset, especially when the number of images in the dataset was large. diff --git a/flux_train_network.py b/flux_train_network.py index 8431a6dc9..9cc8811b5 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -453,7 +453,7 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t if len(diff_output_pr_indices) > 0: network.set_multiplier(0.0) - with torch.no_grad(), accelerator.autocast(): + with torch.no_grad(): model_pred_prior = call_dit( img=packed_noisy_model_input[diff_output_pr_indices], img_ids=img_ids[diff_output_pr_indices], From 7fe8e162cb54ccf259eead1cca0ebdcc4e2b77fe Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 20 Oct 2024 08:45:27 +0900 Subject: [PATCH 187/348] fix to work ControlNetSubset with custom_attributes --- library/train_util.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/library/train_util.py b/library/train_util.py index 7d3fce5b2..462c7a9a2 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -578,6 +578,7 @@ def __init__( caption_suffix, token_warmup_min, token_warmup_step, + custom_attributes: Optional[Dict[str, Any]] = None, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -602,6 +603,7 @@ def __init__( caption_suffix, token_warmup_min, token_warmup_step, + custom_attributes=custom_attributes, ) self.conditioning_data_dir = conditioning_data_dir From 138dac4aea57716e2f23580305f6e40836a87228 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 20 Oct 2024 09:22:38 +0900 Subject: [PATCH 188/348] update README --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 32ee38573..532c3368f 100644 --- a/README.md +++ b/README.md @@ -18,8 +18,9 @@ Oct 19, 2024: - Define a Dataset subset for the regularization image (`is_reg = true`) with `.toml`. Add `custom_attributes.diff_output_preservation = true`. - See [dataset configuration](docs/config_README-en.md) for the regularization dataset. - Specify "number of training images x number of repeats >= number of regularization images x number of repeats". - - Specify a large value for `--prior_loss_weight` option (not dataset config). The appropriate value is unknown, but try around 10-100. Note that the default is 1.0. - - You may want to start with 2/3 to 3/4 of the loss value when DOP is not applied. If it is 1/2, DOP may not be working. + - The weights of DOP is specified by `--prior_loss_weight` option (not dataset config). + - The appropriate value is still unknown. For FLUX, according to the comments in the [PR](https://github.com/kohya-ss/sd-scripts/pull/1710), the value may be 1 (thanks to dxqbYD!). For SDXL, a larger value may be needed (10-100 may be good starting points). + - It may be good to adjust the value so that the loss is about half to three-quarters of the loss when DOP is not applied. ``` [[datasets.subsets]] image_dir = "path/to/image/dir" @@ -28,6 +29,7 @@ is_reg = true custom_attributes.diff_output_preservation = true # Add this ``` + Oct 13, 2024: - Fixed an issue where it took a long time to load the image size when initializing the dataset, especially when the number of images in the dataset was large. From 623017f71695bcee18f36f5a1f57514974d9350d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 24 Oct 2024 19:49:28 +0900 Subject: [PATCH 189/348] refactor SD3 CLIP to transformers etc. --- flux_train.py | 4 +- flux_train_network.py | 2 +- library/flux_train_utils.py | 3 +- library/flux_utils.py | 59 +-- library/sai_model_spec.py | 9 +- library/sd3_models.py | 1000 ++--------------------------------- library/sd3_train_utils.py | 244 +++++---- library/sd3_utils.py | 503 +++++------------- library/strategy_sd3.py | 184 ++++--- library/train_util.py | 31 ++ library/utils.py | 42 +- sd3_minimal_inference.py | 390 +++++++------- sd3_train.py | 738 ++++++++++++++------------ 13 files changed, 1130 insertions(+), 2079 deletions(-) diff --git a/flux_train.py b/flux_train.py index 91ae3af57..79c44d7b4 100644 --- a/flux_train.py +++ b/flux_train.py @@ -29,7 +29,7 @@ from accelerate.utils import set_seed from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux -from library.sd3_train_utils import load_prompts, FlowMatchEulerDiscreteScheduler +from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler import library.train_util as train_util @@ -241,7 +241,7 @@ def train(args): text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() - prompts = load_prompts(args.sample_prompts) + prompts = train_util.load_prompts(args.sample_prompts) sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs with accelerator.autocast(), torch.no_grad(): for prompt_dict in prompts: diff --git a/flux_train_network.py b/flux_train_network.py index 9cc8811b5..cffeb3b19 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -231,7 +231,7 @@ def cache_text_encoder_outputs_if_needed( tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() - prompts = sd3_train_utils.load_prompts(args.sample_prompts) + prompts = train_util.load_prompts(args.sample_prompts) sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs with accelerator.autocast(), torch.no_grad(): for prompt_dict in prompts: diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index b3c9184f2..fa673a2f0 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -15,7 +15,6 @@ from safetensors.torch import save_file from library import flux_models, flux_utils, strategy_base, train_util -from library.sd3_train_utils import load_prompts from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -70,7 +69,7 @@ def sample_images( text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) - prompts = load_prompts(args.sample_prompts) + prompts = train_util.load_prompts(args.sample_prompts) save_dir = args.output_dir + "/sample" os.makedirs(save_dir, exist_ok=True) diff --git a/library/flux_utils.py b/library/flux_utils.py index 7a1ec37b8..86a2ec600 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -10,40 +10,21 @@ from accelerate import init_empty_weights from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config -from library import flux_models - -from library.utils import setup_logging, MemoryEfficientSafeOpen +from library.utils import setup_logging setup_logging() import logging logger = logging.getLogger(__name__) +from library import flux_models +from library.utils import load_safetensors + MODEL_VERSION_FLUX_V1 = "flux1" MODEL_NAME_DEV = "dev" MODEL_NAME_SCHNELL = "schnell" -# temporary copy from sd3_utils TODO refactor -def load_safetensors( - path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32 -): - if disable_mmap: - # return safetensors.torch.load(open(path, "rb").read()) - # use experimental loader - logger.info(f"Loading without mmap (experimental)") - state_dict = {} - with MemoryEfficientSafeOpen(path) as f: - for key in f.keys(): - state_dict[key] = f.get_tensor(key).to(device, dtype=dtype) - return state_dict - else: - try: - return load_file(path, device=device) - except: - return load_file(path) # prevent device invalid Error - - def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]: """ チェックポイントの状態を分析し、DiffusersかBFLか、devかschnellか、ブロック数を計算して返す。 @@ -161,8 +142,14 @@ def load_ae( return ae -def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> CLIPTextModel: - logger.info("Building CLIP") +def load_clip_l( + ckpt_path: Optional[str], + dtype: torch.dtype, + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[dict] = None, +) -> CLIPTextModel: + logger.info("Building CLIP-L") CLIPL_CONFIG = { "_name_or_path": "clip-vit-large-patch14/", "architectures": ["CLIPModel"], @@ -255,15 +242,22 @@ def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.dev with init_empty_weights(): clip = CLIPTextModel._from_config(config) - logger.info(f"Loading state dict from {ckpt_path}") - sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + if state_dict is not None: + sd = state_dict + else: + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) info = clip.load_state_dict(sd, strict=False, assign=True) - logger.info(f"Loaded CLIP: {info}") + logger.info(f"Loaded CLIP-L: {info}") return clip def load_t5xxl( - ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False + ckpt_path: str, + dtype: Optional[torch.dtype], + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[dict] = None, ) -> T5EncoderModel: T5_CONFIG_JSON = """ { @@ -303,8 +297,11 @@ def load_t5xxl( with init_empty_weights(): t5xxl = T5EncoderModel._from_config(config) - logger.info(f"Loading state dict from {ckpt_path}") - sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + if state_dict is not None: + sd = state_dict + else: + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) info = t5xxl.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded T5xxl: {info}") return t5xxl diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index ad72ec00d..8896c047e 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -57,8 +57,8 @@ ARCH_SD_V2_512 = "stable-diffusion-v2-512" ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v" ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base" -ARCH_SD3_M = "stable-diffusion-3-medium" -ARCH_SD3_UNKNOWN = "stable-diffusion-3" +ARCH_SD3_M = "stable-diffusion-3" # may be followed by "-m" or "-5-large" etc. +# ARCH_SD3_UNKNOWN = "stable-diffusion-3" ARCH_FLUX_1_DEV = "flux-1-dev" ARCH_FLUX_1_UNKNOWN = "flux-1" @@ -140,10 +140,7 @@ def build_metadata( if sdxl: arch = ARCH_SD_XL_V1_BASE elif sd3 is not None: - if sd3 == "m": - arch = ARCH_SD3_M - else: - arch = ARCH_SD3_UNKNOWN + arch = ARCH_SD3_M + "-" + sd3 elif flux is not None: if flux == "dev": arch = ARCH_FLUX_1_DEV diff --git a/library/sd3_models.py b/library/sd3_models.py index ec704dcba..c81aa4794 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -4,6 +4,7 @@ # and some module/classes are contributed from KohakuBlueleaf. Thanks for the contribution! from ast import Tuple +from dataclasses import dataclass from functools import partial import math from types import SimpleNamespace @@ -15,6 +16,7 @@ import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from transformers import CLIPTokenizer, T5TokenizerFast + from .utils import setup_logging setup_logging() @@ -35,139 +37,21 @@ memory_efficient_attention = None -# region tokenizer -class SDTokenizer: - def __init__( - self, max_length=77, pad_with_end=True, tokenizer=None, has_start_token=True, pad_to_max_length=True, min_length=None - ): - """ - サブクラスで各種の設定を行ってる。このクラスはその設定に基づき重み付きのトークン化を行うようだ。 - Some settings are done in subclasses. This class seems to perform tokenization with weights based on those settings. - """ - self.tokenizer: CLIPTokenizer = tokenizer - self.max_length = max_length - self.min_length = min_length - empty = self.tokenizer("")["input_ids"] - if has_start_token: - self.tokens_start = 1 - self.start_token = empty[0] - self.end_token = empty[1] - else: - self.tokens_start = 0 - self.start_token = None - self.end_token = empty[0] - self.pad_with_end = pad_with_end - self.pad_to_max_length = pad_to_max_length - vocab = self.tokenizer.get_vocab() - self.inv_vocab = {v: k for k, v in vocab.items()} - self.max_word_length = 8 - - def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: - """ - Tokenize the text without weights. - """ - if type(text) == str: - text = [text] - batch_tokens = self.tokenizer(text, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt") - # return tokens["input_ids"] - - pad_token = self.end_token if self.pad_with_end else 0 - for tokens in batch_tokens["input_ids"]: - assert tokens[0] == self.start_token, f"tokens[0]: {tokens[0]}, start_token: {self.start_token}" - - def tokenize_with_weights(self, text: str, truncate_to_max_length=True, truncate_length=None): - """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. - The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3.""" - """ - ja: テキストをトークン化し、重み値を持ちます - すべての値に1.0を仮定し、他の機能を無視します。 - 詳細は参考実装には関係なく、重み自体はSD3に対して弱い影響しかありません。へぇ~ - """ - if self.pad_with_end: - pad_token = self.end_token - else: - pad_token = 0 - batch = [] - if self.start_token is not None: - batch.append((self.start_token, 1.0)) - to_tokenize = text.replace("\n", " ").split(" ") - to_tokenize = [x for x in to_tokenize if x != ""] - for word in to_tokenize: - batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start : -1]]) - batch.append((self.end_token, 1.0)) - print(len(batch), self.max_length, self.min_length) - if self.pad_to_max_length: - batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch))) - if self.min_length is not None and len(batch) < self.min_length: - batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch))) - - # truncate to max_length - print( - f"batch: {batch}, max_length: {self.max_length}, truncate: {truncate_to_max_length}, truncate_length: {truncate_length}" - ) - if truncate_to_max_length and len(batch) > self.max_length: - batch = batch[: self.max_length] - if truncate_length is not None and len(batch) > truncate_length: - batch = batch[:truncate_length] - - return [batch] - - -class T5XXLTokenizer(SDTokenizer): - """Wraps the T5 Tokenizer from HF into the SDTokenizer interface""" - - def __init__(self): - super().__init__( - pad_with_end=False, - tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"), - has_start_token=False, - pad_to_max_length=False, - max_length=99999999, - min_length=77, - ) - - -class SDXLClipGTokenizer(SDTokenizer): - def __init__(self, tokenizer): - super().__init__(pad_with_end=False, tokenizer=tokenizer) - - -class SD3Tokenizer: - def __init__(self, t5xxl=True, t5xxl_max_length: Optional[int] = 256): - if t5xxl_max_length is None: - t5xxl_max_length = 256 - - # TODO cache tokenizer settings locally or hold them in the repo like ComfyUI - clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") - self.clip_l = SDTokenizer(tokenizer=clip_tokenizer) - self.clip_g = SDXLClipGTokenizer(clip_tokenizer) - # self.clip_l = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") - # self.clip_g = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") - self.t5xxl = T5XXLTokenizer() if t5xxl else None - # t5xxl has 99999999 max length, clip has 77 - self.t5xxl_max_length = t5xxl_max_length - - def tokenize_with_weights(self, text: str): - return ( - self.clip_l.tokenize_with_weights(text), - self.clip_g.tokenize_with_weights(text), - ( - self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.t5xxl_max_length) - if self.t5xxl is not None - else None - ), - ) - - def tokenize(self, text: str): - return ( - self.clip_l.tokenize(text), - self.clip_g.tokenize(text), - (self.t5xxl.tokenize(text) if self.t5xxl is not None else None), - ) - +# region mmdit -# endregion -# region mmdit +@dataclass +class SD3Params: + patch_size: int + depth: int + num_patches: int + pos_embed_max_size: int + adm_in_channels: int + qk_norm: Optional[str] + x_block_self_attn_layers: List[int] + context_embedder_in_features: int + context_embedder_out_features: int + model_type: str def get_2d_sincos_pos_embed( @@ -286,10 +170,6 @@ def timestep_embedding(t, dim, max_period=10000): return embedding -def rmsnorm(x, eps=1e-6): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) - - class PatchEmbed(nn.Module): def __init__( self, @@ -301,8 +181,9 @@ def __init__( flatten=True, bias=True, strict_img_size=True, - dynamic_img_pad=True, + dynamic_img_pad=False, ): + # dynamic_img_pad and norm is omitted in SD3.5 super().__init__() self.patch_size = patch_size self.flatten = flatten @@ -432,6 +313,10 @@ def forward(self, x): return self.mlp(x) +def rmsnorm(x, eps=1e-6): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + class RMSNorm(torch.nn.Module): def __init__( self, @@ -604,53 +489,6 @@ def attention(q, k, v, head_dim, mask=None, scale=None, mode="xformers"): return scores -class SelfAttention(AttentionLinears): - def __init__(self, dim, num_heads=8, mode="xformers"): - super().__init__(dim, num_heads, qkv_bias=True, pre_only=False) - assert mode in MEMORY_LAYOUTS - self.head_dim = dim // num_heads - self.attn_mode = mode - - def set_attn_mode(self, mode): - self.attn_mode = mode - - def forward(self, x): - q, k, v = self.pre_attention(x) - attn_score = attention(q, k, v, self.head_dim, mode=self.attn_mode) - return self.post_attention(attn_score) - - -class TransformerBlock(nn.Module): - def __init__(self, context_size, mode="xformers"): - super().__init__() - self.context_size = context_size - self.norm1 = nn.LayerNorm(context_size, elementwise_affine=False, eps=1e-6) - self.attn = SelfAttention(context_size, mode=mode) - self.norm2 = nn.LayerNorm(context_size, elementwise_affine=False, eps=1e-6) - self.mlp = MLP( - in_features=context_size, - hidden_features=context_size * 4, - act_layer=lambda: nn.GELU(approximate="tanh"), - ) - - def forward(self, x): - x = x + self.attn(self.norm1(x)) - x = x + self.mlp(self.norm2(x)) - return x - - -class Transformer(nn.Module): - def __init__(self, context_size, num_layers, mode="xformers"): - super().__init__() - self.layers = nn.ModuleList([TransformerBlock(context_size, mode) for _ in range(num_layers)]) - self.norm = nn.LayerNorm(context_size, elementwise_affine=False, eps=1e-6) - - def forward(self, x): - for layer in self.layers: - x = layer(x) - return self.norm(x) - - # DismantledBlock in mmdit.py class SingleDiTBlock(nn.Module): """ @@ -823,7 +661,8 @@ def __init__( mlp_ratio: float = 4.0, learn_sigma: bool = False, adm_in_channels: Optional[int] = None, - context_embedder_config: Optional[Dict] = None, + context_embedder_in_features: Optional[int] = None, + context_embedder_out_features: Optional[int] = None, use_checkpoint: bool = False, register_length: int = 0, attn_mode: str = "torch", @@ -837,10 +676,10 @@ def __init__( num_patches=None, qk_norm: Optional[str] = None, qkv_bias: bool = True, - context_processor_layers=None, - context_size=4096, + model_type: str = "sd3m", ): super().__init__() + self._model_type = model_type self.learn_sigma = learn_sigma self.in_channels = in_channels default_out_channels = in_channels * 2 if learn_sigma else in_channels @@ -875,12 +714,11 @@ def __init__( assert isinstance(adm_in_channels, int) self.y_embedder = Embedder(adm_in_channels, self.hidden_size) - if context_processor_layers is not None: - self.context_processor = Transformer(context_size, context_processor_layers, attn_mode) + if context_embedder_in_features is not None: + self.context_embedder = nn.Linear(context_embedder_in_features, context_embedder_out_features) else: - self.context_processor = None + self.context_embedder = nn.Identity() - self.context_embedder = nn.Linear(context_size, self.hidden_size) self.register_length = register_length if self.register_length > 0: self.register = nn.Parameter(torch.randn(1, register_length, self.hidden_size)) @@ -922,7 +760,7 @@ def __init__( @property def model_type(self): - return "m" # only support medium + return self._model_type @property def device(self): @@ -1024,9 +862,6 @@ def forward( y: (N, D) tensor of class labels """ - if self.context_processor is not None: - context = self.context_processor(context) - B, C, H, W = x.shape x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device).to(dtype=x.dtype) c = self.t_embedder(t, dtype=x.dtype) # (N, D) @@ -1052,22 +887,21 @@ def forward( return x[:, :, :H, :W] -def create_mmdit_sd3_medium_configs(attn_mode: str): - # {'patch_size': 2, 'depth': 24, 'num_patches': 36864, - # 'pos_embed_max_size': 192, 'adm_in_channels': 2048, 'context_embedder': - # {'target': 'torch.nn.Linear', 'params': {'in_features': 4096, 'out_features': 1536}}} +def create_sd3_mmdit(params: SD3Params, attn_mode: str = "torch") -> MMDiT: mmdit = MMDiT( input_size=None, - pos_embed_max_size=192, - patch_size=2, + pos_embed_max_size=params.pos_embed_max_size, + patch_size=params.patch_size, in_channels=16, - adm_in_channels=2048, - depth=24, + adm_in_channels=params.adm_in_channels, + context_embedder_in_features=params.context_embedder_in_features, + context_embedder_out_features=params.context_embedder_out_features, + depth=params.depth, mlp_ratio=4, - qk_norm=None, - num_patches=36864, - context_size=4096, + qk_norm=params.qk_norm, + num_patches=params.num_patches, attn_mode=attn_mode, + model_type=params.model_type, ) return mmdit @@ -1075,7 +909,6 @@ def create_mmdit_sd3_medium_configs(attn_mode: str): # endregion # region VAE -# TODO support xformers VAE_SCALE_FACTOR = 1.5305 VAE_SHIFT_FACTOR = 0.0609 @@ -1322,759 +1155,4 @@ def process_out(latent): return (latent / VAE_SCALE_FACTOR) + VAE_SHIFT_FACTOR -class VAEOutput: - def __init__(self, latent): - self.latent = latent - - @property - def latent_dist(self): - return self - - def sample(self): - return self.latent - - -class VAEWrapper: - def __init__(self, vae): - self.vae = vae - - @property - def device(self): - return self.vae.device - - @property - def dtype(self): - return self.vae.dtype - - # latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") - def encode(self, image): - return VAEOutput(self.vae.encode(image)) - - -# endregion - - -# region Text Encoder -class CLIPAttention(torch.nn.Module): - def __init__(self, embed_dim, heads, dtype, device, mode="xformers"): - super().__init__() - self.heads = heads - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) - self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) - self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) - self.attn_mode = mode - - def set_attn_mode(self, mode): - self.attn_mode = mode - - def forward(self, x, mask=None): - q = self.q_proj(x) - k = self.k_proj(x) - v = self.v_proj(x) - out = attention(q, k, v, self.heads, mask, mode=self.attn_mode) - return self.out_proj(out) - - -ACTIVATIONS = { - "quick_gelu": lambda: (lambda a: a * torch.sigmoid(1.702 * a)), - # "gelu": torch.nn.functional.gelu, - "gelu": lambda: nn.GELU(), -} - - -class CLIPLayer(torch.nn.Module): - def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device): - super().__init__() - self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) - self.self_attn = CLIPAttention(embed_dim, heads, dtype, device) - self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) - # # self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device) - # self.mlp = Mlp( - # embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation], dtype=dtype, device=device - # ) - self.mlp = MLP(embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation]) - self.mlp.to(device=device, dtype=dtype) - - def forward(self, x, mask=None): - x += self.self_attn(self.layer_norm1(x), mask) - x += self.mlp(self.layer_norm2(x)) - return x - - -class CLIPEncoder(torch.nn.Module): - def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device): - super().__init__() - self.layers = torch.nn.ModuleList( - [CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) for i in range(num_layers)] - ) - - def forward(self, x, mask=None, intermediate_output=None): - if intermediate_output is not None: - if intermediate_output < 0: - intermediate_output = len(self.layers) + intermediate_output - intermediate = None - for i, l in enumerate(self.layers): - x = l(x, mask) - if i == intermediate_output: - intermediate = x.clone() - return x, intermediate - - -class CLIPEmbeddings(torch.nn.Module): - def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None): - super().__init__() - self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device) - self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device) - - def forward(self, input_tokens): - return self.token_embedding(input_tokens) + self.position_embedding.weight - - -class CLIPTextModel_(torch.nn.Module): - def __init__(self, config_dict, dtype, device): - num_layers = config_dict["num_hidden_layers"] - embed_dim = config_dict["hidden_size"] - heads = config_dict["num_attention_heads"] - intermediate_size = config_dict["intermediate_size"] - intermediate_activation = config_dict["hidden_act"] - super().__init__() - self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device) - self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) - self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device) - - def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True): - x = self.embeddings(input_tokens) - - if x.dtype == torch.bfloat16: - causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=torch.float32, device=x.device).fill_(float("-inf")).triu_(1) - causal_mask = causal_mask.to(dtype=x.dtype) - else: - causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) - - x, i = self.encoder(x, mask=causal_mask, intermediate_output=intermediate_output) - x = self.final_layer_norm(x) - if i is not None and final_layer_norm_intermediate: - i = self.final_layer_norm(i) - pooled_output = x[ - torch.arange(x.shape[0], device=x.device), - input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1), - ] - return x, i, pooled_output - - -class CLIPTextModel(torch.nn.Module): - def __init__(self, config_dict, dtype, device): - super().__init__() - self.num_layers = config_dict["num_hidden_layers"] - self.text_model = CLIPTextModel_(config_dict, dtype, device) - embed_dim = config_dict["hidden_size"] - self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device) - self.text_projection.weight.copy_(torch.eye(embed_dim)) - self.dtype = dtype - - def get_input_embeddings(self): - return self.text_model.embeddings.token_embedding - - def set_input_embeddings(self, embeddings): - self.text_model.embeddings.token_embedding = embeddings - - def forward(self, *args, **kwargs): - x = self.text_model(*args, **kwargs) - out = self.text_projection(x[2]) - return (x[0], x[1], out, x[2]) - - -class ClipTokenWeightEncoder: - # def encode_token_weights(self, token_weight_pairs): - # tokens = list(map(lambda a: a[0], token_weight_pairs[0])) - # out, pooled = self([tokens]) - # if pooled is not None: - # first_pooled = pooled[0:1] - # else: - # first_pooled = pooled - # output = [out[0:1]] - # return torch.cat(output, dim=-2), first_pooled - - # fix to support batched inputs - # : Union[List[Tuple[torch.Tensor, torch.Tensor]], List[List[Tuple[torch.Tensor, torch.Tensor]]]] - def encode_token_weights(self, list_of_token_weight_pairs): - has_batch = isinstance(list_of_token_weight_pairs[0][0], list) - - if has_batch: - list_of_tokens = [] - for pairs in list_of_token_weight_pairs: - tokens = [a[0] for a in pairs[0]] # I'm not sure why this is [0] - list_of_tokens.append(tokens) - else: - if isinstance(list_of_token_weight_pairs[0], torch.Tensor): - list_of_tokens = [list(list_of_token_weight_pairs[0])] - else: - list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]] - - out, pooled = self(list_of_tokens) - if has_batch: - return out, pooled - else: - if pooled is not None: - first_pooled = pooled[0:1] - else: - first_pooled = pooled - output = [out[0:1]] - return torch.cat(output, dim=-2), first_pooled - - -class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): - """Uses the CLIP transformer encoder for text (from huggingface)""" - - LAYERS = ["last", "pooled", "hidden"] - - def __init__( - self, - device="cpu", - max_length=77, - layer="last", - layer_idx=None, - textmodel_json_config=None, - dtype=None, - model_class=CLIPTextModel, - special_tokens={"start": 49406, "end": 49407, "pad": 49407}, - layer_norm_hidden_state=True, - return_projected_pooled=True, - ): - super().__init__() - assert layer in self.LAYERS - self.transformer = model_class(textmodel_json_config, dtype, device) - self.num_layers = self.transformer.num_layers - self.max_length = max_length - self.transformer = self.transformer.eval() - for param in self.parameters(): - param.requires_grad = False - self.layer = layer - self.layer_idx = None - self.special_tokens = special_tokens - self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) - self.layer_norm_hidden_state = layer_norm_hidden_state - self.return_projected_pooled = return_projected_pooled - if layer == "hidden": - assert layer_idx is not None - assert abs(layer_idx) < self.num_layers - self.set_clip_options({"layer": layer_idx}) - self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled) - - @property - def device(self): - return next(self.parameters()).device - - @property - def dtype(self): - return next(self.parameters()).dtype - - def gradient_checkpointing_enable(self): - logger.warning("Gradient checkpointing is not supported for this model") - - def set_attn_mode(self, mode): - raise NotImplementedError("This model does not support setting the attention mode") - - def set_clip_options(self, options): - layer_idx = options.get("layer", self.layer_idx) - self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) - if layer_idx is None or abs(layer_idx) > self.num_layers: - self.layer = "last" - else: - self.layer = "hidden" - self.layer_idx = layer_idx - - def forward(self, tokens): - backup_embeds = self.transformer.get_input_embeddings() - device = backup_embeds.weight.device - tokens = torch.LongTensor(tokens).to(device) - outputs = self.transformer( - tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state - ) - self.transformer.set_input_embeddings(backup_embeds) - if self.layer == "last": - z = outputs[0] - else: - z = outputs[1] - pooled_output = None - if len(outputs) >= 3: - if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None: - pooled_output = outputs[3].float() - elif outputs[2] is not None: - pooled_output = outputs[2].float() - return z.float(), pooled_output - - def set_attn_mode(self, mode): - clip_text_model = self.transformer.text_model - for layer in clip_text_model.encoder.layers: - layer.self_attn.set_attn_mode(mode) - - -class SDXLClipG(SDClipModel): - """Wraps the CLIP-G model into the SD-CLIP-Model interface""" - - def __init__(self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None): - if layer == "penultimate": - layer = "hidden" - layer_idx = -2 - super().__init__( - device=device, - layer=layer, - layer_idx=layer_idx, - textmodel_json_config=config, - dtype=dtype, - special_tokens={"start": 49406, "end": 49407, "pad": 0}, - layer_norm_hidden_state=False, - ) - - def set_attn_mode(self, mode): - clip_text_model = self.transformer.text_model - for layer in clip_text_model.encoder.layers: - layer.self_attn.set_attn_mode(mode) - - -class T5XXLModel(SDClipModel): - """Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience""" - - def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None): - super().__init__( - device=device, - layer=layer, - layer_idx=layer_idx, - textmodel_json_config=config, - dtype=dtype, - special_tokens={"end": 1, "pad": 0}, - model_class=T5, - ) - - def set_attn_mode(self, mode): - t5: T5 = self.transformer - for t5block in t5.encoder.block: - t5block: T5Block - t5layer: T5LayerSelfAttention = t5block.layer[0] - t5SaSa: T5Attention = t5layer.SelfAttention - t5SaSa.set_attn_mode(mode) - - -################################################################################################# -### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl -################################################################################################# - -""" -class T5XXLTokenizer(SDTokenizer): - ""Wraps the T5 Tokenizer from HF into the SDTokenizer interface"" - - def __init__(self): - super().__init__( - pad_with_end=False, - tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"), - has_start_token=False, - pad_to_max_length=False, - max_length=99999999, - min_length=77, - ) -""" - - -class T5LayerNorm(torch.nn.Module): - def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None): - super().__init__() - self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device)) - self.variance_epsilon = eps - - # def forward(self, x): - # variance = x.pow(2).mean(-1, keepdim=True) - # x = x * torch.rsqrt(variance + self.variance_epsilon) - # return self.weight.to(device=x.device, dtype=x.dtype) * x - - # copy from transformers' T5LayerNorm - def forward(self, hidden_states): - # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean - # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated - # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for - # half-precision inputs is done in fp32 - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - - # convert into half-precision if necessary - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) - - return self.weight * hidden_states - - -class T5DenseGatedActDense(torch.nn.Module): - def __init__(self, model_dim, ff_dim, dtype, device): - super().__init__() - self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) - self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) - self.wo = nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device) - - def forward(self, x): - hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh") - hidden_linear = self.wi_1(x) - x = hidden_gelu * hidden_linear - x = self.wo(x) - return x - - -class T5LayerFF(torch.nn.Module): - def __init__(self, model_dim, ff_dim, dtype, device): - super().__init__() - self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device) - self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) - - def forward(self, x): - forwarded_states = self.layer_norm(x) - forwarded_states = self.DenseReluDense(forwarded_states) - x += forwarded_states - return x - - -class T5Attention(torch.nn.Module): - def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device): - super().__init__() - # Mesh TensorFlow initialization to avoid scaling before softmax - self.q = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.k = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.v = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.o = nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device) - self.num_heads = num_heads - self.relative_attention_bias = None - if relative_attention_bias: - self.relative_attention_num_buckets = 32 - self.relative_attention_max_distance = 128 - self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device) - - self.attn_mode = "xformers" # TODO 何とかする - - def set_attn_mode(self, mode): - self.attn_mode = mode - - @staticmethod - def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): - """ - Adapted from Mesh Tensorflow: - https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 - - Translate relative position to a bucket number for relative attention. The relative position is defined as - memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to - position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for - small absolute relative_position and larger buckets for larger absolute relative_positions. All relative - positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. - This should allow for more graceful generalization to longer sequences than the model has been trained on - - Args: - relative_position: an int32 Tensor - bidirectional: a boolean - whether the attention is bidirectional - num_buckets: an integer - max_distance: an integer - - Returns: - a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) - """ - relative_buckets = 0 - if bidirectional: - num_buckets //= 2 - relative_buckets += (relative_position > 0).to(torch.long) * num_buckets - relative_position = torch.abs(relative_position) - else: - relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) - # now relative_position is in the range [0, inf) - # half of the buckets are for exact increments in positions - max_exact = num_buckets // 2 - is_small = relative_position < max_exact - # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance - relative_position_if_large = max_exact + ( - torch.log(relative_position.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) - ).to(torch.long) - relative_position_if_large = torch.min( - relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) - ) - relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) - return relative_buckets - - def compute_bias(self, query_length, key_length, device): - """Compute binned relative position bias""" - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] - memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] - relative_position = memory_position - context_position # shape (query_length, key_length) - relative_position_bucket = self._relative_position_bucket( - relative_position, # shape (query_length, key_length) - bidirectional=True, - num_buckets=self.relative_attention_num_buckets, - max_distance=self.relative_attention_max_distance, - ) - values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) - values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) - return values - - def forward(self, x, past_bias=None): - q = self.q(x) - k = self.k(x) - v = self.v(x) - if self.relative_attention_bias is not None: - past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device) - if past_bias is not None: - mask = past_bias - out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask, mode=self.attn_mode) - return self.o(out), past_bias - - -class T5LayerSelfAttention(torch.nn.Module): - def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device): - super().__init__() - self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device) - self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) - - def forward(self, x, past_bias=None): - output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias) - x += output - return x, past_bias - - -class T5Block(torch.nn.Module): - def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device): - super().__init__() - self.layer = torch.nn.ModuleList() - self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device)) - self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device)) - - def forward(self, x, past_bias=None): - x, past_bias = self.layer[0](x, past_bias) - - # copy from transformers' T5Block - # clamp inf values to enable fp16 training - if x.dtype == torch.float16: - clamp_value = torch.where( - torch.isinf(x).any(), - torch.finfo(x.dtype).max - 1000, - torch.finfo(x.dtype).max, - ) - x = torch.clamp(x, min=-clamp_value, max=clamp_value) - - x = self.layer[-1](x) - # clamp inf values to enable fp16 training - if x.dtype == torch.float16: - clamp_value = torch.where( - torch.isinf(x).any(), - torch.finfo(x.dtype).max - 1000, - torch.finfo(x.dtype).max, - ) - x = torch.clamp(x, min=-clamp_value, max=clamp_value) - - return x, past_bias - - -class T5Stack(torch.nn.Module): - def __init__(self, num_layers, model_dim, inner_dim, ff_dim, num_heads, vocab_size, dtype, device): - super().__init__() - self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device) - self.block = torch.nn.ModuleList( - [ - T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) - for i in range(num_layers) - ] - ) - self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) - - def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True): - intermediate = None - x = self.embed_tokens(input_ids) - past_bias = None - for i, l in enumerate(self.block): - # uncomment to debug layerwise output: fp16 may cause issues - # print(i, x.mean(), x.std()) - x, past_bias = l(x, past_bias) - if i == intermediate_output: - intermediate = x.clone() - # print(x.mean(), x.std()) - x = self.final_layer_norm(x) - if intermediate is not None and final_layer_norm_intermediate: - intermediate = self.final_layer_norm(intermediate) - # print(x.mean(), x.std()) - return x, intermediate - - -class T5(torch.nn.Module): - def __init__(self, config_dict, dtype, device): - super().__init__() - self.num_layers = config_dict["num_layers"] - self.encoder = T5Stack( - self.num_layers, - config_dict["d_model"], - config_dict["d_model"], - config_dict["d_ff"], - config_dict["num_heads"], - config_dict["vocab_size"], - dtype, - device, - ) - self.dtype = dtype - - def get_input_embeddings(self): - return self.encoder.embed_tokens - - def set_input_embeddings(self, embeddings): - self.encoder.embed_tokens = embeddings - - def forward(self, *args, **kwargs): - return self.encoder(*args, **kwargs) - - -def create_clip_l(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[str, torch.Tensor]] = None): - r""" - state_dict is not loaded, but updated with missing keys - """ - CLIPL_CONFIG = { - "hidden_act": "quick_gelu", - "hidden_size": 768, - "intermediate_size": 3072, - "num_attention_heads": 12, - "num_hidden_layers": 12, - } - with torch.no_grad(): - clip_l = SDClipModel( - layer="hidden", - layer_idx=-2, - device=device, - dtype=dtype, - layer_norm_hidden_state=False, - return_projected_pooled=False, - textmodel_json_config=CLIPL_CONFIG, - ) - clip_l.gradient_checkpointing_enable() - if state_dict is not None: - # update state_dict if provided to include logit_scale and text_projection.weight avoid errors - if "logit_scale" not in state_dict: - state_dict["logit_scale"] = clip_l.logit_scale - if "transformer.text_projection.weight" not in state_dict: - state_dict["transformer.text_projection.weight"] = clip_l.transformer.text_projection.weight - return clip_l - - -def create_clip_g(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[str, torch.Tensor]] = None): - r""" - state_dict is not loaded, but updated with missing keys - """ - CLIPG_CONFIG = { - "hidden_act": "gelu", - "hidden_size": 1280, - "intermediate_size": 5120, - "num_attention_heads": 20, - "num_hidden_layers": 32, - } - with torch.no_grad(): - clip_g = SDXLClipG(CLIPG_CONFIG, device=device, dtype=dtype) - if state_dict is not None: - if "logit_scale" not in state_dict: - state_dict["logit_scale"] = clip_g.logit_scale - return clip_g - - -def create_t5xxl(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> T5XXLModel: - T5_CONFIG = {"d_ff": 10240, "d_model": 4096, "num_heads": 64, "num_layers": 24, "vocab_size": 32128} - with torch.no_grad(): - t5 = T5XXLModel(T5_CONFIG, dtype=dtype, device=device) - if state_dict is not None: - if "logit_scale" not in state_dict: - state_dict["logit_scale"] = t5.logit_scale - if "transformer.shared.weight" in state_dict: - state_dict.pop("transformer.shared.weight") - return t5 - - -""" - # snippet for using the T5 model from transformers - - from transformers import T5EncoderModel, T5Config - import accelerate - import json - - T5_CONFIG_JSON = "" -{ - "architectures": [ - "T5EncoderModel" - ], - "classifier_dropout": 0.0, - "d_ff": 10240, - "d_kv": 64, - "d_model": 4096, - "decoder_start_token_id": 0, - "dense_act_fn": "gelu_new", - "dropout_rate": 0.1, - "eos_token_id": 1, - "feed_forward_proj": "gated-gelu", - "initializer_factor": 1.0, - "is_encoder_decoder": true, - "is_gated_act": true, - "layer_norm_epsilon": 1e-06, - "model_type": "t5", - "num_decoder_layers": 24, - "num_heads": 64, - "num_layers": 24, - "output_past": true, - "pad_token_id": 0, - "relative_attention_max_distance": 128, - "relative_attention_num_buckets": 32, - "tie_word_embeddings": false, - "torch_dtype": "float16", - "transformers_version": "4.41.2", - "use_cache": true, - "vocab_size": 32128 -} -"" - config = json.loads(T5_CONFIG_JSON) - config = T5Config(**config) - - # model = T5EncoderModel.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="text_encoder_3") - # print(model.config) - # # model(**load_model.config) - - # with accelerate.init_empty_weights(): - model = T5EncoderModel._from_config(config) # , torch_dtype=dtype) - for key in list(state_dict.keys()): - if key.startswith("transformer."): - new_key = key[len("transformer.") :] - state_dict[new_key] = state_dict.pop(key) - - info = model.load_state_dict(state_dict) - print(info) - model.set_attn_mode = lambda x: None - # model.to("cpu") - - _self = model - - def enc(list_of_token_weight_pairs): - has_batch = isinstance(list_of_token_weight_pairs[0][0], list) - - if has_batch: - list_of_tokens = [] - for pairs in list_of_token_weight_pairs: - tokens = [a[0] for a in pairs[0]] # I'm not sure why this is [0] - list_of_tokens.append(tokens) - else: - list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]] - - list_of_tokens = np.array(list_of_tokens) - list_of_tokens = torch.from_numpy(list_of_tokens).to("cuda", dtype=torch.long) - out = _self(list_of_tokens) - pooled = None - if has_batch: - return out, pooled - else: - if pooled is not None: - first_pooled = pooled[0:1] - else: - first_pooled = pooled - return out[0], first_pooled - # output = [out[0:1]] - # return torch.cat(output, dim=-2), first_pooled - - model.encode_token_weights = enc - - return model -""" - # endregion diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index e819d440c..9282482d9 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -11,8 +11,8 @@ from accelerate import Accelerator, PartialState from tqdm import tqdm from PIL import Image +from transformers import CLIPTextModelWithProjection, T5EncoderModel -from library import sd3_models, sd3_utils, strategy_base, train_util from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -28,60 +28,16 @@ logger = logging.getLogger(__name__) -from .sdxl_train_util import match_mixed_precision - - -def load_target_model( - model_type: str, - args: argparse.Namespace, - state_dict: dict, - accelerator: Accelerator, - attn_mode: str, - model_dtype: Optional[torch.dtype], - device: Optional[torch.device], -) -> Union[ - sd3_models.MMDiT, - Optional[sd3_models.SDClipModel], - Optional[sd3_models.SDXLClipG], - Optional[sd3_models.T5XXLModel], - sd3_models.SDVAE, -]: - loading_device = device if device is not None else (accelerator.device if args.lowram else "cpu") - - for pi in range(accelerator.state.num_processes): - if pi == accelerator.state.local_process_index: - logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") - - if model_type == "mmdit": - model = sd3_utils.load_mmdit(state_dict, attn_mode, model_dtype, loading_device) - elif model_type == "clip_l": - model = sd3_utils.load_clip_l(state_dict, args.clip_l, attn_mode, model_dtype, loading_device) - elif model_type == "clip_g": - model = sd3_utils.load_clip_g(state_dict, args.clip_g, attn_mode, model_dtype, loading_device) - elif model_type == "t5xxl": - model = sd3_utils.load_t5xxl(state_dict, args.t5xxl, attn_mode, model_dtype, loading_device) - elif model_type == "vae": - model = sd3_utils.load_vae(state_dict, args.vae, model_dtype, loading_device) - else: - raise ValueError(f"Unknown model type: {model_type}") - - # work on low-ram device: models are already loaded on accelerator.device, but we ensure they are on device - if args.lowram: - model = model.to(accelerator.device) - - clean_memory_on_device(accelerator.device) - accelerator.wait_for_everyone() - - return model +from library import sd3_models, sd3_utils, strategy_base, train_util def save_models( ckpt_path: str, - mmdit: sd3_models.MMDiT, - vae: sd3_models.SDVAE, - clip_l: sd3_models.SDClipModel, - clip_g: sd3_models.SDXLClipG, - t5xxl: Optional[sd3_models.T5XXLModel], + mmdit: Optional[sd3_models.MMDiT], + vae: Optional[sd3_models.SDVAE], + clip_l: Optional[CLIPTextModelWithProjection], + clip_g: Optional[CLIPTextModelWithProjection], + t5xxl: Optional[T5EncoderModel], sai_metadata: Optional[dict], save_dtype: Optional[torch.dtype] = None, ): @@ -101,14 +57,25 @@ def update_sd(prefix, sd): update_sd("model.diffusion_model.", mmdit.state_dict()) update_sd("first_stage_model.", vae.state_dict()) + # do not support unified checkpoint format for now + # if clip_l is not None: + # update_sd("text_encoders.clip_l.", clip_l.state_dict()) + # if clip_g is not None: + # update_sd("text_encoders.clip_g.", clip_g.state_dict()) + # if t5xxl is not None: + # update_sd("text_encoders.t5xxl.", t5xxl.state_dict()) + + save_file(state_dict, ckpt_path, metadata=sai_metadata) + if clip_l is not None: - update_sd("text_encoders.clip_l.", clip_l.state_dict()) + clip_l_path = ckpt_path.replace(".safetensors", "_clip_l.safetensors") + save_file(clip_l.state_dict(), clip_l_path) if clip_g is not None: - update_sd("text_encoders.clip_g.", clip_g.state_dict()) + clip_g_path = ckpt_path.replace(".safetensors", "_clip_g.safetensors") + save_file(clip_g.state_dict(), clip_g_path) if t5xxl is not None: - update_sd("text_encoders.t5xxl.", t5xxl.state_dict()) - - save_file(state_dict, ckpt_path, metadata=sai_metadata) + t5xxl_path = ckpt_path.replace(".safetensors", "_t5xxl.safetensors") + save_file(t5xxl.state_dict(), t5xxl_path) def save_sd3_model_on_train_end( @@ -116,9 +83,9 @@ def save_sd3_model_on_train_end( save_dtype: torch.dtype, epoch: int, global_step: int, - clip_l: sd3_models.SDClipModel, - clip_g: sd3_models.SDXLClipG, - t5xxl: Optional[sd3_models.T5XXLModel], + clip_l: Optional[CLIPTextModelWithProjection], + clip_g: Optional[CLIPTextModelWithProjection], + t5xxl: Optional[T5EncoderModel], mmdit: sd3_models.MMDiT, vae: sd3_models.SDVAE, ): @@ -141,9 +108,9 @@ def save_sd3_model_on_epoch_end_or_stepwise( epoch: int, num_train_epochs: int, global_step: int, - clip_l: sd3_models.SDClipModel, - clip_g: sd3_models.SDXLClipG, - t5xxl: Optional[sd3_models.T5XXLModel], + clip_l: Optional[CLIPTextModelWithProjection], + clip_g: Optional[CLIPTextModelWithProjection], + t5xxl: Optional[T5EncoderModel], mmdit: sd3_models.MMDiT, vae: sd3_models.SDVAE, ): @@ -208,23 +175,27 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): help="T5-XXL model path. if not specified, use ckpt's state_dict / T5-XXLモデルのパス。指定しない場合はckptのstate_dictを使用", ) parser.add_argument( - "--save_clip", action="store_true", help="save CLIP models to checkpoint / CLIPモデルをチェックポイントに保存する" + "--save_clip", + action="store_true", + help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません", ) parser.add_argument( - "--save_t5xxl", action="store_true", help="save T5-XXL model to checkpoint / T5-XXLモデルをチェックポイントに保存する" + "--save_t5xxl", + action="store_true", + help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません", ) parser.add_argument( "--t5xxl_device", type=str, default=None, - help="T5-XXL device. if not specified, use accelerator's device / T5-XXLデバイス。指定しない場合はacceleratorのデバイスを使用", + help="[DOES NOT WORK] not supported yet. T5-XXL device. if not specified, use accelerator's device / T5-XXLデバイス。指定しない場合はacceleratorのデバイスを使用", ) parser.add_argument( "--t5xxl_dtype", type=str, default=None, - help="T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtype(mixed precisionから)を使用", + help="[DOES NOT WORK] not supported yet. T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtype(mixed precisionから)を使用", ) # copy from Diffusers @@ -233,16 +204,25 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): type=str, default="logit_normal", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"], + help="weighting scheme for timestep distribution and loss / タイムステップ分布と損失のための重み付けスキーム", ) parser.add_argument( - "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + "--logit_mean", + type=float, + default=0.0, + help="mean to use when using the `'logit_normal'` weighting scheme for timestep distribution. / タイムステップ分布のために`'logit_normal'`重み付けスキームを使用する場合の平均", + ) + parser.add_argument( + "--logit_std", + type=float, + default=1.0, + help="std to use when using the `'logit_normal'` weighting scheme for timestep distribution. / タイムステップ分布のために`'logit_normal'`重み付けスキームを使用する場合のstd", ) - parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.") parser.add_argument( "--mode_scale", type=float, default=1.29, - help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`. / モード重み付けスキームのスケール。`'mode'`を`weighting_scheme`として使用する場合のみ有効", ) @@ -283,7 +263,7 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin # temporary copied from sd3_minimal_inferece.py -def get_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps): +def get_all_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps): start = sampling.timestep(sampling.sigma_max) end = sampling.timestep(sampling.sigma_min) timesteps = torch.linspace(start, end, steps) @@ -327,7 +307,7 @@ def do_sample( model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3 - sigmas = get_sigmas(model_sampling, steps).to(device) + sigmas = get_all_sigmas(model_sampling, steps).to(device) noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas)) @@ -371,37 +351,6 @@ def do_sample( return x -def load_prompts(prompt_file: str) -> List[Dict]: - # read prompts - if prompt_file.endswith(".txt"): - with open(prompt_file, "r", encoding="utf-8") as f: - lines = f.readlines() - prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"] - elif prompt_file.endswith(".toml"): - with open(prompt_file, "r", encoding="utf-8") as f: - data = toml.load(f) - prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]] - elif prompt_file.endswith(".json"): - with open(prompt_file, "r", encoding="utf-8") as f: - prompts = json.load(f) - - # preprocess prompts - for i in range(len(prompts)): - prompt_dict = prompts[i] - if isinstance(prompt_dict, str): - from library.train_util import line_to_prompt_dict - - prompt_dict = line_to_prompt_dict(prompt_dict) - prompts[i] = prompt_dict - assert isinstance(prompt_dict, dict) - - # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. - prompt_dict["enum"] = i - prompt_dict.pop("subset", None) - - return prompts - - def sample_images( accelerator: Accelerator, args: argparse.Namespace, @@ -440,7 +389,7 @@ def sample_images( text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) - prompts = load_prompts(args.sample_prompts) + prompts = train_util.load_prompts(args.sample_prompts) save_dir = args.output_dir + "/sample" os.makedirs(save_dir, exist_ok=True) @@ -510,7 +459,7 @@ def sample_image_inference( accelerator: Accelerator, args: argparse.Namespace, mmdit: sd3_models.MMDiT, - text_encoders: List[Union[sd3_models.SDClipModel, sd3_models.SDXLClipG, sd3_models.T5XXLModel]], + text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]], vae: sd3_models.SDVAE, save_dir, prompt_dict, @@ -568,7 +517,7 @@ def sample_image_inference( l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(prompt) te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens]) - lg_out, t5_out, pooled = te_outputs + lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = te_outputs cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) # encode negative prompts @@ -578,7 +527,7 @@ def sample_image_inference( l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(negative_prompt) neg_te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens]) - lg_out, t5_out, pooled = neg_te_outputs + lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = neg_te_outputs neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) # sample image @@ -609,14 +558,9 @@ def sample_image_inference( wandb_tracker = accelerator.get_tracker("wandb") import wandb + # not to commit images to avoid inconsistency between training and logging steps - wandb_tracker.log( - {f"sample_{i}": wandb.Image( - image, - caption=prompt # positive prompt as a caption - )}, - commit=False - ) + wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption # region Diffusers @@ -886,4 +830,78 @@ def __len__(self): return self.config.num_train_timesteps +def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(device) + timesteps = timesteps.to(device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None +): + """Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + +def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): + """Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + elif weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) + return weighting + + +def get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, device, dtype +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + bsz = latents.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler.config.num_train_timesteps).long() + timesteps = noise_scheduler.timesteps[indices].to(device=device) + + # Add noise according to flow matching. + sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + + return noisy_model_input, timesteps, sigmas + + # endregion diff --git a/library/sd3_utils.py b/library/sd3_utils.py index 5849518fb..9ad995d81 100644 --- a/library/sd3_utils.py +++ b/library/sd3_utils.py @@ -1,9 +1,12 @@ +from dataclasses import dataclass import math -from typing import Dict, Optional, Union +import re +from typing import Dict, List, Optional, Union import torch import safetensors from safetensors.torch import load_file from accelerate import init_empty_weights +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPConfig, CLIPTextConfig from .utils import setup_logging @@ -19,18 +22,61 @@ # region models +# TODO remove dependency on flux_utils +from library.utils import load_safetensors +from library.flux_utils import load_t5xxl as flux_utils_load_t5xxl -def load_safetensors(path: str, dvc: Union[str, torch.device], disable_mmap: bool = False): - if disable_mmap: - return safetensors.torch.load(open(path, "rb").read()) + +def analyze_state_dict_state(state_dict: Dict, prefix: str = ""): + logger.info(f"Analyzing state dict state...") + + # analyze configs + patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2] + depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64 + num_patches = state_dict[f"{prefix}pos_embed"].shape[1] + pos_embed_max_size = round(math.sqrt(num_patches)) + adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1] + context_shape = state_dict[f"{prefix}context_embedder.weight"].shape + qk_norm = "rms" if f"{prefix}joint_blocks.0.context_block.attn.ln_k.weight" in state_dict.keys() else None + + # x_block_self_attn_layers.append(int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1])) + x_block_self_attn_layers = [] + re_attn = re.compile(r".(\d+).x_block.attn2.ln_k.weight") + for key in list(state_dict.keys()): + m = re_attn.match(key) + if m: + x_block_self_attn_layers.append(int(m.group(1))) + + assert len(x_block_self_attn_layers) == 0, "x_block_self_attn_layers is not supported" + + context_embedder_in_features = context_shape[1] + context_embedder_out_features = context_shape[0] + + # only supports 3-5-large and 3-medium + if qk_norm is not None: + model_type = "3-5-large" else: - try: - return load_file(path, device=dvc) - except: - return load_file(path) # prevent device invalid Error + model_type = "3-medium" + + params = sd3_models.SD3Params( + patch_size=patch_size, + depth=depth, + num_patches=num_patches, + pos_embed_max_size=pos_embed_max_size, + adm_in_channels=adm_in_channels, + qk_norm=qk_norm, + x_block_self_attn_layers=x_block_self_attn_layers, + context_embedder_in_features=context_embedder_in_features, + context_embedder_out_features=context_embedder_out_features, + model_type=model_type, + ) + logger.info(f"Analyzed state dict state: {params}") + return params -def load_mmdit(state_dict: Dict, attn_mode: str, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device]): +def load_mmdit( + state_dict: Dict, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device], attn_mode: str = "torch" +) -> sd3_models.MMDiT: mmdit_sd = {} mmdit_prefix = "model.diffusion_model." @@ -40,8 +86,9 @@ def load_mmdit(state_dict: Dict, attn_mode: str, dtype: Optional[Union[str, torc # load MMDiT logger.info("Building MMDit") + params = analyze_state_dict_state(mmdit_sd) with init_empty_weights(): - mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode) + mmdit = sd3_models.create_sd3_mmdit(params, attn_mode) logger.info("Loading state dict...") info = sdxl_model_util._load_state_dict_on_device(mmdit, mmdit_sd, device, dtype) @@ -50,20 +97,14 @@ def load_mmdit(state_dict: Dict, attn_mode: str, dtype: Optional[Union[str, torc def load_clip_l( - state_dict: Dict, clip_l_path: Optional[str], - attn_mode: str, - clip_dtype: Optional[Union[str, torch.dtype]], + dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device], disable_mmap: bool = False, + state_dict: Optional[Dict] = None, ): clip_l_sd = None - if clip_l_path: - logger.info(f"Loading clip_l from {clip_l_path}...") - clip_l_sd = load_safetensors(clip_l_path, device, disable_mmap) - for key in list(clip_l_sd.keys()): - clip_l_sd["transformer." + key] = clip_l_sd.pop(key) - else: + if clip_l_path is None: if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict: # found clip_l: remove prefix "text_encoders.clip_l." logger.info("clip_l is included in the checkpoint") @@ -72,34 +113,58 @@ def load_clip_l( for k in list(state_dict.keys()): if k.startswith(prefix): clip_l_sd[k[len(prefix) :]] = state_dict.pop(k) + elif clip_l_path is None: + logger.info("clip_l is not included in the checkpoint and clip_l_path is not provided") + return None + + # load clip_l + logger.info("Building CLIP-L") + config = CLIPTextConfig( + vocab_size=49408, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=77, + hidden_act="quick_gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=768, + # torch_dtype="float32", + # transformers_version="4.25.0.dev0", + ) + with init_empty_weights(): + clip = CLIPTextModelWithProjection(config) if clip_l_sd is None: - clip_l = None - else: - logger.info("Building ClipL") - clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd) - logger.info("Loading state dict...") - info = clip_l.load_state_dict(clip_l_sd) - logger.info(f"Loaded ClipL: {info}") - clip_l.set_attn_mode(attn_mode) - return clip_l + logger.info(f"Loading state dict from {clip_l_path}") + clip_l_sd = load_safetensors(clip_l_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + + if "text_projection.weight" not in clip_l_sd: + logger.info("Adding text_projection.weight to clip_l_sd") + clip_l_sd["text_projection.weight"] = torch.eye(768, dtype=dtype, device=device) + + info = clip.load_state_dict(clip_l_sd, strict=False, assign=True) + logger.info(f"Loaded CLIP-L: {info}") + return clip def load_clip_g( - state_dict: Dict, clip_g_path: Optional[str], - attn_mode: str, - clip_dtype: Optional[Union[str, torch.dtype]], + dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device], disable_mmap: bool = False, + state_dict: Optional[Dict] = None, ): clip_g_sd = None - if clip_g_path: - logger.info(f"Loading clip_g from {clip_g_path}...") - clip_g_sd = load_safetensors(clip_g_path, device, disable_mmap) - for key in list(clip_g_sd.keys()): - clip_g_sd["transformer." + key] = clip_g_sd.pop(key) - else: + if state_dict is not None: if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict: # found clip_g: remove prefix "text_encoders.clip_g." logger.info("clip_g is included in the checkpoint") @@ -108,34 +173,53 @@ def load_clip_g( for k in list(state_dict.keys()): if k.startswith(prefix): clip_g_sd[k[len(prefix) :]] = state_dict.pop(k) + elif clip_g_path is None: + logger.info("clip_g is not included in the checkpoint and clip_g_path is not provided") + return None + + # load clip_g + logger.info("Building CLIP-G") + config = CLIPTextConfig( + vocab_size=49408, + hidden_size=1280, + intermediate_size=5120, + num_hidden_layers=32, + num_attention_heads=20, + max_position_embeddings=77, + hidden_act="gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=1280, + # torch_dtype="float32", + # transformers_version="4.25.0.dev0", + ) + with init_empty_weights(): + clip = CLIPTextModelWithProjection(config) if clip_g_sd is None: - clip_g = None - else: - logger.info("Building ClipG") - clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd) - logger.info("Loading state dict...") - info = clip_g.load_state_dict(clip_g_sd) - logger.info(f"Loaded ClipG: {info}") - clip_g.set_attn_mode(attn_mode) - return clip_g + logger.info(f"Loading state dict from {clip_g_path}") + clip_g_sd = load_safetensors(clip_g_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + info = clip.load_state_dict(clip_g_sd, strict=False, assign=True) + logger.info(f"Loaded CLIP-G: {info}") + return clip def load_t5xxl( - state_dict: Dict, t5xxl_path: Optional[str], - attn_mode: str, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device], disable_mmap: bool = False, + state_dict: Optional[Dict] = None, ): t5xxl_sd = None - if t5xxl_path: - logger.info(f"Loading t5xxl from {t5xxl_path}...") - t5xxl_sd = load_safetensors(t5xxl_path, device, disable_mmap) - for key in list(t5xxl_sd.keys()): - t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key) - else: + if state_dict is not None: if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict: # found t5xxl: remove prefix "text_encoders.t5xxl." logger.info("t5xxl is included in the checkpoint") @@ -144,29 +228,19 @@ def load_t5xxl( for k in list(state_dict.keys()): if k.startswith(prefix): t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k) + elif t5xxl_path is None: + logger.info("t5xxl is not included in the checkpoint and t5xxl_path is not provided") + return None - if t5xxl_sd is None: - t5xxl = None - else: - logger.info("Building T5XXL") - - # workaround for T5XXL model creation: create with fp16 takes too long TODO support virtual device - t5xxl = sd3_models.create_t5xxl(device, torch.float32, t5xxl_sd) - t5xxl.to(dtype=dtype) - - logger.info("Loading state dict...") - info = t5xxl.load_state_dict(t5xxl_sd) - logger.info(f"Loaded T5XXL: {info}") - t5xxl.set_attn_mode(attn_mode) - return t5xxl + return flux_utils_load_t5xxl(t5xxl_path, dtype, device, disable_mmap, state_dict=t5xxl_sd) def load_vae( - state_dict: Dict, vae_path: Optional[str], vae_dtype: Optional[Union[str, torch.dtype]], device: Optional[Union[str, torch.device]], disable_mmap: bool = False, + state_dict: Optional[Dict] = None, ): vae_sd = {} if vae_path: @@ -181,299 +255,15 @@ def load_vae( vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k) logger.info("Building VAE") - vae = sd3_models.SDVAE() + vae = sd3_models.SDVAE(vae_dtype, device) logger.info("Loading state dict...") info = vae.load_state_dict(vae_sd) logger.info(f"Loaded VAE: {info}") - vae.to(device=device, dtype=vae_dtype) + vae.to(device=device, dtype=vae_dtype) # make sure it's in the right device and dtype return vae -def load_models( - ckpt_path: str, - clip_l_path: str, - clip_g_path: str, - t5xxl_path: str, - vae_path: str, - attn_mode: str, - device: Union[str, torch.device], - weight_dtype: Optional[Union[str, torch.dtype]] = None, - disable_mmap: bool = False, - clip_dtype: Optional[Union[str, torch.dtype]] = None, - t5xxl_device: Optional[Union[str, torch.device]] = None, - t5xxl_dtype: Optional[Union[str, torch.dtype]] = None, - vae_dtype: Optional[Union[str, torch.dtype]] = None, -): - """ - Load SD3 models from checkpoint files. - - Args: - ckpt_path: Path to the SD3 checkpoint file. - clip_l_path: Path to the clip_l checkpoint file. - clip_g_path: Path to the clip_g checkpoint file. - t5xxl_path: Path to the t5xxl checkpoint file. - vae_path: Path to the VAE checkpoint file. - attn_mode: Attention mode for MMDiT model. - device: Device for MMDiT model. - weight_dtype: Default dtype of weights for all models. This is weight dtype, so the model dtype may be different. - disable_mmap: Disable memory mapping when loading state dict. - clip_dtype: Dtype for Clip models, or None to use default dtype. - t5xxl_device: Device for T5XXL model to load T5XXL in another device (eg. gpu). Default is None to use device. - t5xxl_dtype: Dtype for T5XXL model, or None to use default dtype. - vae_dtype: Dtype for VAE model, or None to use default dtype. - - Returns: - Tuple of MMDiT, ClipL, ClipG, T5XXL, and VAE models. - """ - - # In SD1/2 and SDXL, the model is created with empty weights and then loaded with state dict. - # However, in SD3, Clip and T5XXL models are created with dtype, so we need to set dtype before loading state dict. - # Therefore, we need clip_dtype and t5xxl_dtype. - - def load_state_dict(path: str, dvc: Union[str, torch.device] = device): - if disable_mmap: - return safetensors.torch.load(open(path, "rb").read()) - else: - try: - return load_file(path, device=dvc) - except: - return load_file(path) # prevent device invalid Error - - t5xxl_device = t5xxl_device or device - clip_dtype = clip_dtype or weight_dtype or torch.float32 - t5xxl_dtype = t5xxl_dtype or weight_dtype or torch.float32 - vae_dtype = vae_dtype or weight_dtype or torch.float32 - - logger.info(f"Loading SD3 models from {ckpt_path}...") - state_dict = load_state_dict(ckpt_path) - - # load clip_l - clip_l_sd = None - if clip_l_path: - logger.info(f"Loading clip_l from {clip_l_path}...") - clip_l_sd = load_state_dict(clip_l_path) - for key in list(clip_l_sd.keys()): - clip_l_sd["transformer." + key] = clip_l_sd.pop(key) - else: - if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict: - # found clip_l: remove prefix "text_encoders.clip_l." - logger.info("clip_l is included in the checkpoint") - clip_l_sd = {} - prefix = "text_encoders.clip_l." - for k in list(state_dict.keys()): - if k.startswith(prefix): - clip_l_sd[k[len(prefix) :]] = state_dict.pop(k) - - # load clip_g - clip_g_sd = None - if clip_g_path: - logger.info(f"Loading clip_g from {clip_g_path}...") - clip_g_sd = load_state_dict(clip_g_path) - for key in list(clip_g_sd.keys()): - clip_g_sd["transformer." + key] = clip_g_sd.pop(key) - else: - if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict: - # found clip_g: remove prefix "text_encoders.clip_g." - logger.info("clip_g is included in the checkpoint") - clip_g_sd = {} - prefix = "text_encoders.clip_g." - for k in list(state_dict.keys()): - if k.startswith(prefix): - clip_g_sd[k[len(prefix) :]] = state_dict.pop(k) - - # load t5xxl - t5xxl_sd = None - if t5xxl_path: - logger.info(f"Loading t5xxl from {t5xxl_path}...") - t5xxl_sd = load_state_dict(t5xxl_path, t5xxl_device) - for key in list(t5xxl_sd.keys()): - t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key) - else: - if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict: - # found t5xxl: remove prefix "text_encoders.t5xxl." - logger.info("t5xxl is included in the checkpoint") - t5xxl_sd = {} - prefix = "text_encoders.t5xxl." - for k in list(state_dict.keys()): - if k.startswith(prefix): - t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k) - - # MMDiT and VAE - vae_sd = {} - if vae_path: - logger.info(f"Loading VAE from {vae_path}...") - vae_sd = load_state_dict(vae_path) - else: - # remove prefix "first_stage_model." - vae_sd = {} - vae_prefix = "first_stage_model." - for k in list(state_dict.keys()): - if k.startswith(vae_prefix): - vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k) - - mmdit_prefix = "model.diffusion_model." - for k in list(state_dict.keys()): - if k.startswith(mmdit_prefix): - state_dict[k[len(mmdit_prefix) :]] = state_dict.pop(k) - else: - state_dict.pop(k) # remove other keys - - # load MMDiT - logger.info("Building MMDit") - with init_empty_weights(): - mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode) - - logger.info("Loading state dict...") - info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, weight_dtype) - logger.info(f"Loaded MMDiT: {info}") - - # load ClipG and ClipL - if clip_l_sd is None: - clip_l = None - else: - logger.info("Building ClipL") - clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd) - logger.info("Loading state dict...") - info = clip_l.load_state_dict(clip_l_sd) - logger.info(f"Loaded ClipL: {info}") - clip_l.set_attn_mode(attn_mode) - - if clip_g_sd is None: - clip_g = None - else: - logger.info("Building ClipG") - clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd) - logger.info("Loading state dict...") - info = clip_g.load_state_dict(clip_g_sd) - logger.info(f"Loaded ClipG: {info}") - clip_g.set_attn_mode(attn_mode) - - # load T5XXL - if t5xxl_sd is None: - t5xxl = None - else: - logger.info("Building T5XXL") - t5xxl = sd3_models.create_t5xxl(t5xxl_device, t5xxl_dtype, t5xxl_sd) - logger.info("Loading state dict...") - info = t5xxl.load_state_dict(t5xxl_sd) - logger.info(f"Loaded T5XXL: {info}") - t5xxl.set_attn_mode(attn_mode) - - # load VAE - logger.info("Building VAE") - vae = sd3_models.SDVAE() - logger.info("Loading state dict...") - info = vae.load_state_dict(vae_sd) - logger.info(f"Loaded VAE: {info}") - vae.to(device=device, dtype=vae_dtype) - - return mmdit, clip_l, clip_g, t5xxl, vae - - # endregion -# region utils - - -def get_cond( - prompt: str, - tokenizer: sd3_models.SD3Tokenizer, - clip_l: sd3_models.SDClipModel, - clip_g: sd3_models.SDXLClipG, - t5xxl: Optional[sd3_models.T5XXLModel] = None, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, -): - l_tokens, g_tokens, t5_tokens = tokenizer.tokenize_with_weights(prompt) - print(t5_tokens) - return get_cond_from_tokens(l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, device=device, dtype=dtype) - - -def get_cond_from_tokens( - l_tokens, - g_tokens, - t5_tokens, - clip_l: sd3_models.SDClipModel, - clip_g: sd3_models.SDXLClipG, - t5xxl: Optional[sd3_models.T5XXLModel] = None, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, -): - l_out, l_pooled = clip_l.encode_token_weights(l_tokens) - g_out, g_pooled = clip_g.encode_token_weights(g_tokens) - lg_out = torch.cat([l_out, g_out], dim=-1) - lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) - if device is not None: - lg_out = lg_out.to(device=device) - l_pooled = l_pooled.to(device=device) - g_pooled = g_pooled.to(device=device) - if dtype is not None: - lg_out = lg_out.to(dtype=dtype) - l_pooled = l_pooled.to(dtype=dtype) - g_pooled = g_pooled.to(dtype=dtype) - - # t5xxl may be in another device (eg. cpu) - if t5_tokens is None: - t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype) - else: - t5_out, _ = t5xxl.encode_token_weights(t5_tokens) # t5_out is [1, 77, 4096], t5_pooled is None - if device is not None: - t5_out = t5_out.to(device=device) - if dtype is not None: - t5_out = t5_out.to(dtype=dtype) - - # return torch.cat([lg_out, t5_out], dim=-2), torch.cat((l_pooled, g_pooled), dim=-1) - return lg_out, t5_out, torch.cat((l_pooled, g_pooled), dim=-1) - - -# used if other sd3 models is available -r""" -def get_sd3_configs(state_dict: Dict): - # Important configuration values can be quickly determined by checking shapes in the source file - # Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change) - # prefix = "model.diffusion_model." - prefix = "" - - patch_size = state_dict[prefix + "x_embedder.proj.weight"].shape[2] - depth = state_dict[prefix + "x_embedder.proj.weight"].shape[0] // 64 - num_patches = state_dict[prefix + "pos_embed"].shape[1] - pos_embed_max_size = round(math.sqrt(num_patches)) - adm_in_channels = state_dict[prefix + "y_embedder.mlp.0.weight"].shape[1] - context_shape = state_dict[prefix + "context_embedder.weight"].shape - context_embedder_config = { - "target": "torch.nn.Linear", - "params": {"in_features": context_shape[1], "out_features": context_shape[0]}, - } - return { - "patch_size": patch_size, - "depth": depth, - "num_patches": num_patches, - "pos_embed_max_size": pos_embed_max_size, - "adm_in_channels": adm_in_channels, - "context_embedder": context_embedder_config, - } - - -def create_mmdit_from_sd3_checkpoint(state_dict: Dict, attn_mode: str = "xformers"): - "" - Doesn't load state dict. - "" - sd3_configs = get_sd3_configs(state_dict) - - mmdit = sd3_models.MMDiT( - input_size=None, - pos_embed_max_size=sd3_configs["pos_embed_max_size"], - patch_size=sd3_configs["patch_size"], - in_channels=16, - adm_in_channels=sd3_configs["adm_in_channels"], - depth=sd3_configs["depth"], - mlp_ratio=4, - qk_norm=None, - num_patches=sd3_configs["num_patches"], - context_size=4096, - attn_mode=attn_mode, - ) - return mmdit -""" class ModelSamplingDiscreteFlow: @@ -509,6 +299,3 @@ def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): # assert max_denoise is False, "max_denoise not implemented" # max_denoise is always True, I'm not sure why it's there return sigma * noise + (1.0 - sigma) * latent_image - - -# endregion diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py index 9fde02084..dd08cf004 100644 --- a/library/strategy_sd3.py +++ b/library/strategy_sd3.py @@ -3,7 +3,7 @@ from typing import Any, List, Optional, Tuple, Union import torch import numpy as np -from transformers import CLIPTokenizer, T5TokenizerFast +from transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, CLIPTextModelWithProjection, T5EncoderModel from library import sd3_utils, train_util from library import sd3_models @@ -48,45 +48,79 @@ def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: class Sd3TextEncodingStrategy(TextEncodingStrategy): - def __init__(self) -> None: - pass + def __init__(self, apply_lg_attn_mask: Optional[bool] = None, apply_t5_attn_mask: Optional[bool] = None) -> None: + """ + Args: + apply_t5_attn_mask: Default value for apply_t5_attn_mask. + """ + self.apply_lg_attn_mask = apply_lg_attn_mask + self.apply_t5_attn_mask = apply_t5_attn_mask def encode_tokens( self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], - apply_lg_attn_mask: bool = False, - apply_t5_attn_mask: bool = False, + apply_lg_attn_mask: Optional[bool] = False, + apply_t5_attn_mask: Optional[bool] = False, ) -> List[torch.Tensor]: """ returned embeddings are not masked """ clip_l, clip_g, t5xxl = models + clip_l: CLIPTextModel + clip_g: CLIPTextModelWithProjection + t5xxl: T5EncoderModel + + if apply_lg_attn_mask is None: + apply_lg_attn_mask = self.apply_lg_attn_mask + if apply_t5_attn_mask is None: + apply_t5_attn_mask = self.apply_t5_attn_mask l_tokens, g_tokens, t5_tokens = tokens[:3] - l_attn_mask, g_attn_mask, t5_attn_mask = tokens[3:] if len(tokens) > 3 else [None, None, None] + + if len(tokens) > 3: + l_attn_mask, g_attn_mask, t5_attn_mask = tokens[3:] + if not apply_lg_attn_mask: + l_attn_mask = None + g_attn_mask = None + else: + l_attn_mask = l_attn_mask.to(clip_l.device) + g_attn_mask = g_attn_mask.to(clip_g.device) + if not apply_t5_attn_mask: + t5_attn_mask = None + else: + t5_attn_mask = t5_attn_mask.to(t5xxl.device) + else: + l_attn_mask = None + g_attn_mask = None + t5_attn_mask = None + if l_tokens is None: assert g_tokens is None, "g_tokens must be None if l_tokens is None" lg_out = None + lg_pooled = None else: - assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" - l_out, l_pooled = clip_l(l_tokens) - g_out, g_pooled = clip_g(g_tokens) - if apply_lg_attn_mask: - l_out = l_out * l_attn_mask.to(l_out.device).unsqueeze(-1) - g_out = g_out * g_attn_mask.to(g_out.device).unsqueeze(-1) - lg_out = torch.cat([l_out, g_out], dim=-1) + with torch.no_grad(): + assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" + prompt_embeds = clip_l(l_tokens.to(clip_l.device), l_attn_mask, output_hidden_states=True) + l_pooled = prompt_embeds[0] + l_out = prompt_embeds.hidden_states[-2] + + prompt_embeds = clip_g(g_tokens.to(clip_g.device), g_attn_mask, output_hidden_states=True) + g_pooled = prompt_embeds[0] + g_out = prompt_embeds.hidden_states[-2] + + lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None + lg_out = torch.cat([l_out, g_out], dim=-1) if t5xxl is not None and t5_tokens is not None: - t5_out, _ = t5xxl(t5_tokens) # t5_out is [1, max length, 4096] - if apply_t5_attn_mask: - t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1) + with torch.no_grad(): + t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), t5_attn_mask, return_dict=False, output_hidden_states=True) else: t5_out = None - lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None - return [lg_out, t5_out, lg_pooled] + return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] # masks are used for attention masking in transformer def concat_encodings( self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor @@ -132,39 +166,38 @@ def is_disk_cached_outputs_expected(self, npz_path: str): return False if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used return False - # t5xxl is optional + if "apply_lg_attn_mask" not in npz: + return False + if "t5_out" not in npz: + return False + if "t5_attn_mask" not in npz: + return False + npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"] + if npz_apply_lg_attn_mask != self.apply_lg_attn_mask: + return False + if "apply_t5_attn_mask" not in npz: + return False + npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"] + if npz_apply_t5_attn_mask != self.apply_t5_attn_mask: + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e return True - def mask_lg_attn(self, lg_out: np.ndarray, l_attn_mask: np.ndarray, g_attn_mask: np.ndarray) -> np.ndarray: - l_out = lg_out[..., :768] - g_out = lg_out[..., 768:] # 1280 - l_out = l_out * np.expand_dims(l_attn_mask, -1) # l_out = l_out * l_attn_mask. - g_out = g_out * np.expand_dims(g_attn_mask, -1) # g_out = g_out * g_attn_mask. - return np.concatenate([l_out, g_out], axis=-1) - - def mask_t5_attn(self, t5_out: np.ndarray, t5_attn_mask: np.ndarray) -> np.ndarray: - return t5_out * np.expand_dims(t5_attn_mask, -1) - def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: data = np.load(npz_path) lg_out = data["lg_out"] lg_pooled = data["lg_pooled"] - t5_out = data["t5_out"] if "t5_out" in data else None - - if self.apply_lg_attn_mask: - l_attn_mask = data["clip_l_attn_mask"] - g_attn_mask = data["clip_g_attn_mask"] - lg_out = self.mask_lg_attn(lg_out, l_attn_mask, g_attn_mask) + t5_out = data["t5_out"] - if self.apply_t5_attn_mask and t5_out is not None: - t5_attn_mask = data["t5_attn_mask"] - t5_out = self.mask_t5_attn(t5_out, t5_attn_mask) + l_attn_mask = data["clip_l_attn_mask"] + g_attn_mask = data["clip_g_attn_mask"] + t5_attn_mask = data["t5_attn_mask"] - return [lg_out, t5_out, lg_pooled] + # apply_t5_attn_mask and apply_lg_attn_mask are same as self.apply_t5_attn_mask and self.apply_lg_attn_mask + return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] def cache_batch_outputs( self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List @@ -174,7 +207,7 @@ def cache_batch_outputs( tokens_and_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): - lg_out, t5_out, lg_pooled = sd3_text_encoding_strategy.encode_tokens( + lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = sd3_text_encoding_strategy.encode_tokens( tokenize_strategy, models, tokens_and_masks, self.apply_lg_attn_mask, self.apply_t5_attn_mask ) @@ -182,38 +215,41 @@ def cache_batch_outputs( lg_out = lg_out.float() if lg_pooled.dtype == torch.bfloat16: lg_pooled = lg_pooled.float() - if t5_out is not None and t5_out.dtype == torch.bfloat16: + if t5_out.dtype == torch.bfloat16: t5_out = t5_out.float() lg_out = lg_out.cpu().numpy() lg_pooled = lg_pooled.cpu().numpy() - if t5_out is not None: - t5_out = t5_out.cpu().numpy() + t5_out = t5_out.cpu().numpy() + + l_attn_mask = tokens_and_masks[3].cpu().numpy() + g_attn_mask = tokens_and_masks[4].cpu().numpy() + t5_attn_mask = tokens_and_masks[5].cpu().numpy() for i, info in enumerate(infos): lg_out_i = lg_out[i] - t5_out_i = t5_out[i] if t5_out is not None else None + t5_out_i = t5_out[i] lg_pooled_i = lg_pooled[i] + l_attn_mask_i = l_attn_mask[i] + g_attn_mask_i = g_attn_mask[i] + t5_attn_mask_i = t5_attn_mask[i] + apply_lg_attn_mask = self.apply_lg_attn_mask + apply_t5_attn_mask = self.apply_t5_attn_mask if self.cache_to_disk: - clip_l_attn_mask, clip_g_attn_mask, t5_attn_mask = tokens_and_masks[3:6] - clip_l_attn_mask_i = clip_l_attn_mask[i].cpu().numpy() - clip_g_attn_mask_i = clip_g_attn_mask[i].cpu().numpy() - t5_attn_mask_i = t5_attn_mask[i].cpu().numpy() if t5_attn_mask is not None else None # shouldn't be None - kwargs = {} - if t5_out is not None: - kwargs["t5_out"] = t5_out_i np.savez( info.text_encoder_outputs_npz, lg_out=lg_out_i, lg_pooled=lg_pooled_i, - clip_l_attn_mask=clip_l_attn_mask_i, - clip_g_attn_mask=clip_g_attn_mask_i, + t5_out=t5_out_i, + clip_l_attn_mask=l_attn_mask_i, + clip_g_attn_mask=g_attn_mask_i, t5_attn_mask=t5_attn_mask_i, - **kwargs, + apply_lg_attn_mask=apply_lg_attn_mask, + apply_t5_attn_mask=apply_t5_attn_mask, ) else: - info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i) + info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i) class Sd3LatentsCachingStrategy(LatentsCachingStrategy): @@ -246,41 +282,3 @@ def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask if not train_util.HIGH_VRAM: train_util.clean_memory_on_device(vae.device) - - -if __name__ == "__main__": - # test code for Sd3TokenizeStrategy - # tokenizer = sd3_models.SD3Tokenizer() - strategy = Sd3TokenizeStrategy(256) - text = "hello world" - - l_tokens, g_tokens, t5_tokens = strategy.tokenize(text) - # print(l_tokens.shape) - print(l_tokens) - print(g_tokens) - print(t5_tokens) - - texts = ["hello world", "the quick brown fox jumps over the lazy dog"] - l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt") - g_tokens_2 = strategy.clip_g(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt") - t5_tokens_2 = strategy.t5xxl( - texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt" - ) - print(l_tokens_2) - print(g_tokens_2) - print(t5_tokens_2) - - # compare - print(torch.allclose(l_tokens, l_tokens_2["input_ids"][0])) - print(torch.allclose(g_tokens, g_tokens_2["input_ids"][0])) - print(torch.allclose(t5_tokens, t5_tokens_2["input_ids"][0])) - - text = ",".join(["hello world! this is long text"] * 50) - l_tokens, g_tokens, t5_tokens = strategy.tokenize(text) - print(l_tokens) - print(g_tokens) - print(t5_tokens) - - print(f"model max length l: {strategy.clip_l.model_max_length}") - print(f"model max length g: {strategy.clip_g.model_max_length}") - print(f"model max length t5: {strategy.t5xxl.model_max_length}") diff --git a/library/train_util.py b/library/train_util.py index 462c7a9a2..9ea1eec0e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5967,6 +5967,37 @@ def line_to_prompt_dict(line: str) -> dict: return prompt_dict +def load_prompts(prompt_file: str) -> List[Dict]: + # read prompts + if prompt_file.endswith(".txt"): + with open(prompt_file, "r", encoding="utf-8") as f: + lines = f.readlines() + prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"] + elif prompt_file.endswith(".toml"): + with open(prompt_file, "r", encoding="utf-8") as f: + data = toml.load(f) + prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]] + elif prompt_file.endswith(".json"): + with open(prompt_file, "r", encoding="utf-8") as f: + prompts = json.load(f) + + # preprocess prompts + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + from library.train_util import line_to_prompt_dict + + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) + + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. + prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + + return prompts + + def sample_images_common( pipe_class, accelerator: Accelerator, diff --git a/library/utils.py b/library/utils.py index 8a0c782c0..ca0f904d2 100644 --- a/library/utils.py +++ b/library/utils.py @@ -13,12 +13,16 @@ import cv2 from PIL import Image import numpy as np +from safetensors.torch import load_file def fire_in_thread(f, *args, **kwargs): threading.Thread(target=f, args=args, kwargs=kwargs).start() +# region Logging + + def add_logging_arguments(parser): parser.add_argument( "--console_log_level", @@ -85,6 +89,11 @@ def setup_logging(args=None, log_level=None, reset=False): logger.info(msg_init) +# endregion + +# region PyTorch utils + + def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype: """ Convert a string to a torch.dtype @@ -304,6 +313,35 @@ def _convert_float8(byte_tensor, dtype_str, shape): # return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape) raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)") +def load_safetensors( + path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32 +) -> dict[str, torch.Tensor]: + if disable_mmap: + # return safetensors.torch.load(open(path, "rb").read()) + # use experimental loader + # logger.info(f"Loading without mmap (experimental)") + state_dict = {} + with MemoryEfficientSafeOpen(path) as f: + for key in f.keys(): + state_dict[key] = f.get_tensor(key).to(device, dtype=dtype) + return state_dict + else: + try: + state_dict = load_file(path, device=device) + except: + state_dict = load_file(path) # prevent device invalid Error + if dtype is not None: + for key in state_dict.keys(): + state_dict[key] = state_dict[key].to(dtype=dtype) + return state_dict + + + +# endregion + +# region Image utils + + def pil_resize(image, size, interpolation=Image.LANCZOS): has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False @@ -323,9 +361,9 @@ def pil_resize(image, size, interpolation=Image.LANCZOS): return resized_cv2 -# TODO make inf_utils.py - +# endregion +# TODO make inf_utils.py # region Gradual Latent hires fix diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py index 630da7e08..d099fe18d 100644 --- a/sd3_minimal_inference.py +++ b/sd3_minimal_inference.py @@ -12,6 +12,7 @@ from safetensors.torch import safe_open, load_file from tqdm import tqdm from PIL import Image +from transformers import CLIPTextModelWithProjection, T5EncoderModel from library.device_utils import init_ipex, get_preferred_device @@ -25,11 +26,14 @@ logger = logging.getLogger(__name__) from library import sd3_models, sd3_utils, strategy_sd3 +from library.utils import load_safetensors -def get_noise(seed, latent): - generator = torch.manual_seed(seed) - return torch.randn(latent.size(), dtype=torch.float32, layout=latent.layout, generator=generator, device="cpu").to(latent.dtype) +def get_noise(seed, latent, device="cpu"): + # generator = torch.manual_seed(seed) + generator = torch.Generator(device) + generator.manual_seed(seed) + return torch.randn(latent.size(), dtype=latent.dtype, layout=latent.layout, generator=generator, device=device) def get_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps): @@ -59,7 +63,7 @@ def do_sample( neg_cond: Tuple[torch.Tensor, torch.Tensor], mmdit: sd3_models.MMDiT, steps: int, - guidance_scale: float, + cfg_scale: float, dtype: torch.dtype, device: str, ): @@ -71,7 +75,7 @@ def do_sample( latent = latent.to(dtype).to(device) - noise = get_noise(seed, latent).to(device) + noise = get_noise(seed, latent, device) model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3 @@ -105,7 +109,7 @@ def do_sample( batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) pos_out, neg_out = batched.chunk(2) - denoised = neg_out + (pos_out - neg_out) * guidance_scale + denoised = neg_out + (pos_out - neg_out) * cfg_scale # print(denoised.shape) # d = to_d(x, sigma_hat, denoised) @@ -122,20 +126,89 @@ def do_sample( x = x.to(dtype) latent = x - scale_factor = 1.5305 - shift_factor = 0.0609 - # def process_out(self, latent): - # return (latent / self.scale_factor) + self.shift_factor - latent = (latent / scale_factor) + shift_factor + latent = vae.process_out(latent) return latent +def generate_image( + mmdit: sd3_models.MMDiT, + vae: sd3_models.SDVAE, + clip_l: CLIPTextModelWithProjection, + clip_g: CLIPTextModelWithProjection, + t5xxl: T5EncoderModel, + steps: int, + prompt: str, + seed: int, + target_width: int, + target_height: int, + device: str, + negative_prompt: str, + cfg_scale: float, +): + # prepare embeddings + logger.info("Encoding prompts...") + + # TODO support one-by-one offloading + clip_l.to(device) + clip_g.to(device) + t5xxl.to(device) + + with torch.no_grad(): + tokens_and_masks = tokenize_strategy.tokenize(prompt) + lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask + ) + cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + + tokens_and_masks = tokenize_strategy.tokenize(negative_prompt) + lg_out, t5_out, pooled, neg_l_attn_mask, neg_g_attn_mask, neg_t5_attn_mask = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask + ) + neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + + # attn masks are not used currently + + if args.offload: + clip_l.to("cpu") + clip_g.to("cpu") + t5xxl.to("cpu") + + # generate image + logger.info("Generating image...") + mmdit.to(device) + latent_sampled = do_sample(target_height, target_width, None, seed, cond, neg_cond, mmdit, steps, cfg_scale, sd3_dtype, device) + if args.offload: + mmdit.to("cpu") + + # latent to image + vae.to(device) + with torch.no_grad(): + image = vae.decode(latent_sampled) + + if args.offload: + vae.to("cpu") + + image = image.float() + image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] + decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2) + decoded_np = decoded_np.astype(np.uint8) + out_image = Image.fromarray(decoded_np) + + # save image + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png") + out_image.save(output_path) + + logger.info(f"Saved image to {output_path}") + + if __name__ == "__main__": target_height = 1024 target_width = 1024 # steps = 50 # 28 # 50 - guidance_scale = 5 + # cfg_scale = 5 # seed = 1 # None # 1 device = get_preferred_device() @@ -145,15 +218,17 @@ def do_sample( parser.add_argument("--clip_g", type=str, required=False) parser.add_argument("--clip_l", type=str, required=False) parser.add_argument("--t5xxl", type=str, required=False) - parser.add_argument("--t5xxl_token_length", type=int, default=77, help="t5xxl token length, default: 77") + parser.add_argument("--t5xxl_token_length", type=int, default=256, help="t5xxl token length, default: 256") parser.add_argument("--apply_lg_attn_mask", action="store_true") parser.add_argument("--apply_t5_attn_mask", action="store_true") parser.add_argument("--prompt", type=str, default="A photo of a cat") # parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders parser.add_argument("--negative_prompt", type=str, default="") + parser.add_argument("--cfg_scale", type=float, default=5.0) + parser.add_argument("--offload", action="store_true", help="Offload to CPU") parser.add_argument("--output_dir", type=str, default=".") - parser.add_argument("--do_not_use_t5xxl", action="store_true") - parser.add_argument("--attn_mode", type=str, default="torch", help="torch (SDPA) or xformers. default: torch") + # parser.add_argument("--do_not_use_t5xxl", action="store_true") + # parser.add_argument("--attn_mode", type=str, default="torch", help="torch (SDPA) or xformers. default: torch") parser.add_argument("--fp16", action="store_true") parser.add_argument("--bf16", action="store_true") parser.add_argument("--seed", type=int, default=1) @@ -165,7 +240,9 @@ def do_sample( # default=[], # help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)", # ) - # parser.add_argument("--interactive", action="store_true") + parser.add_argument("--width", type=int, default=target_width) + parser.add_argument("--height", type=int, default=target_height) + parser.add_argument("--interactive", action="store_true") args = parser.parse_args() seed = args.seed @@ -177,185 +254,126 @@ def do_sample( elif args.bf16: sd3_dtype = torch.bfloat16 - # TODO test with separated safetenors files for each model + loading_device = "cpu" if args.offload else device # load state dict logger.info(f"Loading SD3 models from {args.ckpt_path}...") - state_dict = load_file(args.ckpt_path) - - if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict: - # found clip_g: remove prefix "text_encoders.clip_g." - logger.info("clip_g is included in the checkpoint") - clip_g_sd = {} - prefix = "text_encoders.clip_g." - for k, v in list(state_dict.items()): - if k.startswith(prefix): - clip_g_sd[k[len(prefix) :]] = state_dict.pop(k) - else: - logger.info(f"Lodaing clip_g from {args.clip_g}...") - clip_g_sd = load_file(args.clip_g) - for key in list(clip_g_sd.keys()): - clip_g_sd["transformer." + key] = clip_g_sd.pop(key) - - if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict: - # found clip_l: remove prefix "text_encoders.clip_l." - logger.info("clip_l is included in the checkpoint") - clip_l_sd = {} - prefix = "text_encoders.clip_l." - for k, v in list(state_dict.items()): - if k.startswith(prefix): - clip_l_sd[k[len(prefix) :]] = state_dict.pop(k) - else: - logger.info(f"Lodaing clip_l from {args.clip_l}...") - clip_l_sd = load_file(args.clip_l) - for key in list(clip_l_sd.keys()): - clip_l_sd["transformer." + key] = clip_l_sd.pop(key) - - if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict: - # found t5xxl: remove prefix "text_encoders.t5xxl." - logger.info("t5xxl is included in the checkpoint") - if not args.do_not_use_t5xxl: - t5xxl_sd = {} - prefix = "text_encoders.t5xxl." - for k, v in list(state_dict.items()): - if k.startswith(prefix): - t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k) - else: - logger.info("but not used") - for key in list(state_dict.keys()): - if key.startswith("text_encoders.t5xxl."): - state_dict.pop(key) - t5xxl_sd = None - elif args.t5xxl: - assert not args.do_not_use_t5xxl, "t5xxl is not used but specified" - logger.info(f"Lodaing t5xxl from {args.t5xxl}...") - t5xxl_sd = load_file(args.t5xxl) - for key in list(t5xxl_sd.keys()): - t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key) - else: - logger.info("t5xxl is not used") - t5xxl_sd = None - - use_t5xxl = t5xxl_sd is not None - - # MMDiT and VAE - vae_sd = {} - vae_prefix = "first_stage_model." - mmdit_prefix = "model.diffusion_model." - for k, v in list(state_dict.items()): - if k.startswith(vae_prefix): - vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k) - elif k.startswith(mmdit_prefix): - state_dict[k[len(mmdit_prefix) :]] = state_dict.pop(k) - - # load tokenizers - logger.info("Loading tokenizers...") - tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length) - - # load models - # logger.info("Create MMDiT from SD3 checkpoint...") - # mmdit = sd3_utils.create_mmdit_from_sd3_checkpoint(state_dict) - logger.info("Create MMDiT") - mmdit = sd3_models.create_mmdit_sd3_medium_configs(args.attn_mode) - - logger.info("Loading state dict...") - info = mmdit.load_state_dict(state_dict) - logger.info(f"Loaded MMDiT: {info}") - - logger.info(f"Move MMDiT to {device} and {sd3_dtype}...") - mmdit.to(device, dtype=sd3_dtype) - mmdit.eval() - - # load VAE - logger.info("Create VAE") - vae = sd3_models.SDVAE() - logger.info("Loading state dict...") - info = vae.load_state_dict(vae_sd) - logger.info(f"Loaded VAE: {info}") - - logger.info(f"Move VAE to {device} and {sd3_dtype}...") - vae.to(device, dtype=sd3_dtype) - vae.eval() + # state_dict = load_file(args.ckpt_path) + state_dict = load_safetensors(args.ckpt_path, loading_device, disable_mmap=True, dtype=sd3_dtype) # load text encoders - logger.info("Create clip_l") - clip_l = sd3_models.create_clip_l(device, sd3_dtype, clip_l_sd) + clip_l = sd3_utils.load_clip_l(args.clip_l, sd3_dtype, loading_device, state_dict=state_dict) + clip_g = sd3_utils.load_clip_g(args.clip_g, sd3_dtype, loading_device, state_dict=state_dict) + t5xxl = sd3_utils.load_t5xxl(args.t5xxl, sd3_dtype, loading_device, state_dict=state_dict) - logger.info("Loading state dict...") - info = clip_l.load_state_dict(clip_l_sd) - logger.info(f"Loaded clip_l: {info}") + # MMDiT and VAE + vae = sd3_utils.load_vae(None, sd3_dtype, loading_device, state_dict=state_dict) + mmdit = sd3_utils.load_mmdit(state_dict, sd3_dtype, loading_device) + + clip_l.to(sd3_dtype) + clip_g.to(sd3_dtype) + t5xxl.to(sd3_dtype) + vae.to(sd3_dtype) + mmdit.to(sd3_dtype) + if not args.offload: + # make sure to move to the device: some tensors are created in the constructor on the CPU + clip_l.to(device) + clip_g.to(device) + t5xxl.to(device) + vae.to(device) + mmdit.to(device) - logger.info(f"Move clip_l to {device} and {sd3_dtype}...") - clip_l.to(device, dtype=sd3_dtype) clip_l.eval() - logger.info(f"Set attn_mode to {args.attn_mode}...") - clip_l.set_attn_mode(args.attn_mode) - - logger.info("Create clip_g") - clip_g = sd3_models.create_clip_g(device, sd3_dtype, clip_g_sd) - - logger.info("Loading state dict...") - info = clip_g.load_state_dict(clip_g_sd) - logger.info(f"Loaded clip_g: {info}") - - logger.info(f"Move clip_g to {device} and {sd3_dtype}...") - clip_g.to(device, dtype=sd3_dtype) clip_g.eval() - logger.info(f"Set attn_mode to {args.attn_mode}...") - clip_g.set_attn_mode(args.attn_mode) - - if use_t5xxl: - logger.info("Create t5xxl") - t5xxl = sd3_models.create_t5xxl(device, sd3_dtype, t5xxl_sd) - - logger.info("Loading state dict...") - info = t5xxl.load_state_dict(t5xxl_sd) - logger.info(f"Loaded t5xxl: {info}") - - logger.info(f"Move t5xxl to {device} and {sd3_dtype}...") - t5xxl.to(device, dtype=sd3_dtype) - # t5xxl.to("cpu", dtype=torch.float32) # run on CPU - t5xxl.eval() - logger.info(f"Set attn_mode to {args.attn_mode}...") - t5xxl.set_attn_mode(args.attn_mode) - else: - t5xxl = None + t5xxl.eval() + mmdit.eval() + vae.eval() - # prepare embeddings - logger.info("Encoding prompts...") + # load tokenizers + logger.info("Loading tokenizers...") + tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length) encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy() - tokens_and_masks = tokenize_strategy.tokenize(args.prompt) - lg_out, t5_out, pooled = encoding_strategy.encode_tokens( - tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask - ) - cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) - - tokens_and_masks = tokenize_strategy.tokenize(args.negative_prompt) - lg_out, t5_out, pooled = encoding_strategy.encode_tokens( - tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask - ) - neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) - - # generate image - logger.info("Generating image...") - latent_sampled = do_sample( - target_height, target_width, None, seed, cond, neg_cond, mmdit, steps, guidance_scale, sd3_dtype, device - ) - - # latent to image - with torch.no_grad(): - image = vae.decode(latent_sampled) - image = image.float() - image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] - decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2) - decoded_np = decoded_np.astype(np.uint8) - out_image = Image.fromarray(decoded_np) - - # save image - output_dir = args.output_dir - os.makedirs(output_dir, exist_ok=True) - output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png") - out_image.save(output_path) - - logger.info(f"Saved image to {output_path}") + if not args.interactive: + generate_image( + mmdit, + vae, + clip_l, + clip_g, + t5xxl, + args.steps, + args.prompt, + args.seed, + args.width, + args.height, + device, + args.negative_prompt, + args.cfg_scale, + ) + else: + # loop for interactive + width = args.width + height = args.height + steps = None + cfg_scale = args.cfg_scale + + while True: + print( + "Enter prompt (empty to exit). Options: --w --h --s --d " + " --n , `--n -` for empty negative prompt" + "Options are kept for the next prompt. Current options:" + f" width={width}, height={height}, steps={steps}, seed={seed}, cfg_scale={cfg_scale}" + ) + prompt = input() + if prompt == "": + break + + # parse options + options = prompt.split("--") + prompt = options[0].strip() + seed = None + negative_prompt = None + for opt in options[1:]: + try: + opt = opt.strip() + if opt.startswith("w"): + width = int(opt[1:].strip()) + elif opt.startswith("h"): + height = int(opt[1:].strip()) + elif opt.startswith("s"): + steps = int(opt[1:].strip()) + elif opt.startswith("d"): + seed = int(opt[1:].strip()) + # elif opt.startswith("m"): + # mutipliers = opt[1:].strip().split(",") + # if len(mutipliers) != len(lora_models): + # logger.error(f"Invalid number of multipliers, expected {len(lora_models)}") + # continue + # for i, lora_model in enumerate(lora_models): + # lora_model.set_multiplier(float(mutipliers[i])) + elif opt.startswith("n"): + negative_prompt = opt[1:].strip() + if negative_prompt == "-": + negative_prompt = "" + elif opt.startswith("c"): + cfg_scale = float(opt[1:].strip()) + except ValueError as e: + logger.error(f"Invalid option: {opt}, {e}") + + generate_image( + mmdit, + vae, + clip_l, + clip_g, + t5xxl, + steps if steps is not None else args.steps, + prompt, + seed if seed is not None else args.seed, + width, + height, + device, + negative_prompt if negative_prompt is not None else args.negative_prompt, + cfg_scale, + ) + + logger.info("Done!") diff --git a/sd3_train.py b/sd3_train.py index ef18c32c4..6336b4cf9 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -1,6 +1,7 @@ # training with captions import argparse +from concurrent.futures import ThreadPoolExecutor import copy import math import os @@ -11,6 +12,7 @@ from tqdm import tqdm import torch +from library import utils from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -38,7 +40,7 @@ ConfigSanitizer, BlueprintGenerator, ) -import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments # from library.custom_train_functions import ( # apply_snr_weight, @@ -61,23 +63,13 @@ def train(args): if not args.skip_cache_check: args.skip_cache_check = args.skip_latents_validity_check - assert ( - not args.weighted_captions - ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + # assert ( + # not args.weighted_captions + # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" # assert ( # not args.train_text_encoder or not args.cache_text_encoder_outputs # ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" - # # training text encoder is not supported - # assert ( - # not args.train_text_encoder - # ), "training text encoder is not supported currently / text encoderの学習は現在サポートされていません" - - # # training without text encoder cache is not supported: because T5XXL must be cached - # assert ( - # args.cache_text_encoder_outputs - # ), "training without text encoder cache is not supported currently / text encoderのキャッシュなしの学習は現在サポートされていません" - assert not args.train_text_encoder or (args.use_t5xxl_cache_only or not args.cache_text_encoder_outputs), ( "when training text encoder, text encoder outputs must not be cached (except for T5XXL)" + " / text encoderの学習時はtext encoderの出力はキャッシュできません(t5xxlのみキャッシュすることは可能です)" @@ -90,13 +82,13 @@ def train(args): ) args.cache_text_encoder_outputs = True - # if args.block_lr: - # block_lrs = [float(lr) for lr in args.block_lr.split(",")] - # assert ( - # len(block_lrs) == UNET_NUM_BLOCKS_FOR_BLOCK_LR - # ), f"block_lr must have {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / block_lrは{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値を指定してください" - # else: - # block_lrs = None + if args.train_t5xxl: + assert ( + args.train_text_encoder + ), "when training T5XXL, text encoder (CLIP-L/G) must be trained / T5XXLを学習するときはtext encoder (CLIP-L/G)も学習する必要があります" + assert ( + not args.cache_text_encoder_outputs + ), "when training T5XXL, t5xxl output must not be cached / T5XXLを学習するときはt5xxlの出力をキャッシュできません" cache_latents = args.cache_latents use_dreambooth_method = args.in_json is None @@ -111,11 +103,6 @@ def train(args): ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) - # load tokenizer and prepare tokenize strategy - sd3_tokenizer = sd3_models.SD3Tokenizer(t5xxl_max_length=args.t5xxl_max_token_length) - sd3_tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length) - strategy_base.TokenizeStrategy.set_strategy(sd3_tokenize_strategy) - # データセットを準備する if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) @@ -156,10 +143,10 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=[sd3_tokenizer]) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, [sd3_tokenizer]) + train_dataset_group = train_util.load_arbitrary_dataset(args) current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -205,72 +192,56 @@ def train(args): # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, save_dtype = train_util.prepare_dtype(args) - vae_dtype = weight_dtype # torch.float32 if args.no_half_vae else weight_dtype # SD3 VAE works with fp16 - - t5xxl_dtype = weight_dtype - if args.t5xxl_dtype is not None: - if args.t5xxl_dtype == "fp16": - t5xxl_dtype = torch.float16 - elif args.t5xxl_dtype == "bf16": - t5xxl_dtype = torch.bfloat16 - elif args.t5xxl_dtype == "fp32" or args.t5xxl_dtype == "float": - t5xxl_dtype = torch.float32 - else: - raise ValueError(f"unexpected t5xxl_dtype: {args.t5xxl_dtype}") - t5xxl_device = accelerator.device if args.t5xxl_device is None else args.t5xxl_device - - clip_dtype = weight_dtype # if not args.train_text_encoder else None # モデルを読み込む - attn_mode = "xformers" if args.xformers else "torch" - - assert ( - attn_mode == "torch" - ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" - - # SD3 state dict may contain multiple models, so we need to load it and extract one by one. annoying. - logger.info(f"Loading SD3 models from {args.pretrained_model_name_or_path}") - device_to_load = accelerator.device if args.lowram else "cpu" - sd3_state_dict = sd3_utils.load_safetensors( - args.pretrained_model_name_or_path, device_to_load, args.disable_mmap_load_safetensors - ) - # load VAE for caching latents - vae: sd3_models.SDVAE = None - if cache_latents: - vae = sd3_train_utils.load_target_model("vae", args, sd3_state_dict, accelerator, attn_mode, vae_dtype, device_to_load) - vae.to(accelerator.device, dtype=vae_dtype) - vae.requires_grad_(False) - vae.eval() - - train_dataset_group.new_cache_latents(vae, accelerator) + # t5xxl_dtype = weight_dtype + # if args.t5xxl_dtype is not None: + # if args.t5xxl_dtype == "fp16": + # t5xxl_dtype = torch.float16 + # elif args.t5xxl_dtype == "bf16": + # t5xxl_dtype = torch.bfloat16 + # elif args.t5xxl_dtype == "fp32" or args.t5xxl_dtype == "float": + # t5xxl_dtype = torch.float32 + # else: + # raise ValueError(f"unexpected t5xxl_dtype: {args.t5xxl_dtype}") + # t5xxl_device = accelerator.device if args.t5xxl_device is None else args.t5xxl_device + # clip_dtype = weight_dtype # if not args.train_text_encoder else None + + # if clip_l is not specified, the checkpoint must contain clip_l, so we load state dict here + # if full_fp16/bf16, model_dtype is casted to fp16/bf16. If not, model_dtype is None (float32). + # by loading with model_dtype, we can reduce memory usage. + model_dtype = match_mixed_precision(args, weight_dtype) # None (default) or fp16/bf16 (full_xxxx) + if args.clip_l is None: + sd3_state_dict = utils.load_safetensors( + args.pretrained_model_name_or_path, "cpu", args.disable_mmap_load_safetensors, model_dtype + ) + else: + sd3_state_dict = None - vae.to("cpu") # if no sampling, vae can be deleted - clean_memory_on_device(accelerator.device) + # load tokenizer and prepare tokenize strategy + if args.t5xxl_max_token_length is None: + t5xxl_max_token_length = 256 # default value for T5XXL + else: + t5xxl_max_token_length = args.t5xxl_max_token_length - accelerator.wait_for_everyone() + sd3_tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(t5xxl_max_token_length) + strategy_base.TokenizeStrategy.set_strategy(sd3_tokenize_strategy) # load clip_l, clip_g, t5xxl for caching text encoder outputs - # # models are usually loaded on CPU and moved to GPU later. This is to avoid OOM on GPU0. - # mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model( - # args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype - # ) - clip_l = sd3_train_utils.load_target_model("clip_l", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load) - clip_g = sd3_train_utils.load_target_model("clip_g", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load) - assert clip_l is not None, "clip_l is required / clip_lは必須です" - assert clip_g is not None, "clip_g is required / clip_gは必須です" - - t5xxl = sd3_train_utils.load_target_model("t5xxl", args, sd3_state_dict, accelerator, attn_mode, t5xxl_dtype, device_to_load) - # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype) - - # should be deleted after caching text encoder outputs when not training text encoder - # this strategy should not be used other than this process - text_encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy() + # clip_l = sd3_train_utils.load_target_model("clip_l", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load) + # clip_g = sd3_train_utils.load_target_model("clip_g", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load) + clip_l = sd3_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors, state_dict=sd3_state_dict) + clip_g = sd3_utils.load_clip_g(args.clip_g, weight_dtype, "cpu", args.disable_mmap_load_safetensors, state_dict=sd3_state_dict) + t5xxl = sd3_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors, state_dict=sd3_state_dict) + assert clip_l is not None and clip_g is not None and t5xxl is not None, "clip_l, clip_g, t5xxl must be specified" + + # prepare text encoding strategy + text_encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy(args.apply_lg_attn_mask, args.apply_t5_attn_mask) strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) # 学習を準備する:モデルを適切な状態にする - train_clip_l = False - train_clip_g = False + train_clip = False train_t5xxl = False if args.train_text_encoder: @@ -278,99 +249,135 @@ def train(args): if args.gradient_checkpointing: clip_l.gradient_checkpointing_enable() clip_g.gradient_checkpointing_enable() + if args.train_t5xxl: + t5xxl.gradient_checkpointing_enable() + lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train - train_clip_l = lr_te1 != 0 - train_clip_g = lr_te2 != 0 + lr_t5xxl = args.learning_rate_te3 if args.learning_rate_te3 is not None else args.learning_rate # 0 means not train + train_clip = lr_te1 != 0 or lr_te2 != 0 + train_t5xxl = lr_t5xxl != 0 and args.train_t5xxl - if not train_clip_l: - clip_l.to(weight_dtype) - if not train_clip_g: - clip_g.to(weight_dtype) - clip_l.requires_grad_(train_clip_l) - clip_g.requires_grad_(train_clip_g) - clip_l.train(train_clip_l) - clip_g.train(train_clip_g) + clip_l.to(weight_dtype) + clip_g.to(weight_dtype) + t5xxl.to(weight_dtype) + clip_l.requires_grad_(train_clip) + clip_g.requires_grad_(train_clip) + t5xxl.requires_grad_(train_t5xxl) else: + print("disable text encoder training") clip_l.to(weight_dtype) clip_g.to(weight_dtype) + t5xxl.to(weight_dtype) clip_l.requires_grad_(False) clip_g.requires_grad_(False) - clip_l.eval() - clip_g.eval() - - if t5xxl is not None: - t5xxl.to(t5xxl_dtype) t5xxl.requires_grad_(False) - t5xxl.eval() + lr_te1 = 0 + lr_te2 = 0 + lr_t5xxl = 0 # cache text encoder outputs sample_prompts_te_outputs = None if args.cache_text_encoder_outputs: - # Text Encodes are eval and no grad here clip_l.to(accelerator.device) clip_g.to(accelerator.device) - if t5xxl is not None: - t5xxl.to(t5xxl_device) + t5xxl.to(accelerator.device) + clip_l.eval() + clip_g.eval() + t5xxl.eval() text_encoder_caching_strategy = strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, - train_clip_g or train_clip_l or args.use_t5xxl_cache_only, + train_clip or args.use_t5xxl_cache_only, # if clip is trained or t5xxl is cached, caching is partial args.apply_lg_attn_mask, args.apply_t5_attn_mask, ) strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) - clip_l.to(accelerator.device, dtype=weight_dtype) - clip_g.to(accelerator.device, dtype=weight_dtype) - if t5xxl is not None: - t5xxl.to(t5xxl_device, dtype=t5xxl_dtype) - with accelerator.autocast(): train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator) # cache sample prompt's embeddings to free text encoder's memory if args.sample_prompts is not None: logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") - prompts = sd3_train_utils.load_prompts(args.sample_prompts) + prompts = train_util.load_prompts(args.sample_prompts) sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs with accelerator.autocast(), torch.no_grad(): for prompt_dict in prompts: for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: if p not in sample_prompts_te_outputs: logger.info(f"cache Text Encoder outputs for prompt: {p}") - tokens_list = sd3_tokenize_strategy.tokenize(p) + tokens_and_masks = sd3_tokenize_strategy.tokenize(p) sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( sd3_tokenize_strategy, [clip_l, clip_g, t5xxl], - tokens_list, + tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask, ) accelerator.wait_for_everyone() + # now we can delete Text Encoders to free memory + if args.use_t5xxl_cache_only: + clip_l = None + clip_g = None + t5xxl = None + + clean_memory_on_device(accelerator.device) + + # load VAE for caching latents + if sd3_state_dict is None: + sd3_state_dict = utils.load_safetensors( + args.pretrained_model_name_or_path, "cpu", args.disable_mmap_load_safetensors, model_dtype + ) + + vae = sd3_utils.load_vae(args.vae, weight_dtype, "cpu", args.disable_mmap_load_safetensors, state_dict=sd3_state_dict) + if cache_latents: + # vae = sd3_train_utils.load_target_model("vae", args, sd3_state_dict, accelerator, attn_mode, vae_dtype, device_to_load) + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() + + train_dataset_group.new_cache_latents(vae, accelerator) + + vae.to("cpu") # if no sampling, vae can be deleted + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + # load MMDIT - # if full_fp16/bf16, model_dtype is casted to fp16/bf16. If not, model_dtype is None (float32). - # by loading with model_dtype, we can reduce memory usage. - model_dtype = match_mixed_precision(args, weight_dtype) # None (default) or fp16/bf16 (full_xxxx) - mmdit = sd3_train_utils.load_target_model("mmdit", args, sd3_state_dict, accelerator, attn_mode, model_dtype, device_to_load) + mmdit = sd3_utils.load_mmdit( + sd3_state_dict, + model_dtype, + "cpu", + ) + + # attn_mode = "xformers" if args.xformers else "torch" + # assert ( + # attn_mode == "torch" + # ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" + + # SD3 state dict may contain multiple models, so we need to load it and extract one by one. annoying. + logger.info(f"Loading SD3 models from {args.pretrained_model_name_or_path}") + device_to_load = accelerator.device if args.lowram else "cpu" + sd3_state_dict = utils.load_safetensors(args.pretrained_model_name_or_path, device_to_load, args.disable_mmap_load_safetensors) + if args.gradient_checkpointing: mmdit.enable_gradient_checkpointing() train_mmdit = args.learning_rate != 0 mmdit.requires_grad_(train_mmdit) if not train_mmdit: - mmdit.to(accelerator.device, dtype=weight_dtype) # because of mmdie will not be prepared + mmdit.to(accelerator.device, dtype=weight_dtype) # because of mmdit will not be prepared if not cache_latents: - # load VAE here if not cached - vae = sd3_train_utils.load_target_model("vae", args, sd3_state_dict, accelerator, attn_mode, vae_dtype, device_to_load) + # move to accelerator device vae.requires_grad_(False) vae.eval() - vae.to(accelerator.device, dtype=vae_dtype) + vae.to(accelerator.device, dtype=weight_dtype) mmdit.requires_grad_(train_mmdit) if not train_mmdit: @@ -394,19 +401,24 @@ def train(args): training_models = [] params_to_optimize = [] - # if train_unet: + param_names = [] training_models.append(mmdit) - # if block_lrs is None: params_to_optimize.append({"params": list(filter(lambda p: p.requires_grad, mmdit.parameters())), "lr": args.learning_rate}) - # else: - # params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs)) - - # if train_clip_l: - # training_models.append(clip_l) - # params_to_optimize.append({"params": list(clip_l.parameters()), "lr": args.learning_rate_te1 or args.learning_rate}) - # if train_clip_g: - # training_models.append(clip_g) - # params_to_optimize.append({"params": list(clip_g.parameters()), "lr": args.learning_rate_te2 or args.learning_rate}) + param_names.append([n for n, _ in mmdit.named_parameters()]) + + if train_clip: + if lr_te1 > 0: + training_models.append(clip_l) + params_to_optimize.append({"params": list(clip_l.parameters()), "lr": args.learning_rate_te1 or args.learning_rate}) + param_names.append([n for n, _ in clip_l.named_parameters()]) + if lr_te2 > 0: + training_models.append(clip_g) + params_to_optimize.append({"params": list(clip_g.parameters()), "lr": args.learning_rate_te2 or args.learning_rate}) + param_names.append([n for n, _ in clip_g.named_parameters()]) + if train_t5xxl: + training_models.append(t5xxl) + params_to_optimize.append({"params": list(t5xxl.parameters()), "lr": args.learning_rate_te3 or args.learning_rate}) + param_names.append([n for n, _ in t5xxl.named_parameters()]) # calculate number of trainable parameters n_params = 0 @@ -414,47 +426,49 @@ def train(args): for p in group["params"]: n_params += p.numel() - accelerator.print(f"train mmdit: {train_mmdit}") # , clip_l: {train_clip_l}, clip_g: {train_clip_g}") + accelerator.print(f"train mmdit: {train_mmdit} , clip:{train_clip}, t5xxl:{train_t5xxl}") accelerator.print(f"number of models: {len(training_models)}") accelerator.print(f"number of trainable parameters: {n_params}") # 学習に必要なクラスを準備する accelerator.print("prepare optimizer, data loader etc.") - if args.fused_optimizer_groups: + if args.blockwise_fused_optimizers: # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html - # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters. + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters. # This balances memory usage and management complexity. - # calculate total number of parameters - n_total_params = sum(len(params["params"]) for params in params_to_optimize) - params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups) - - # split params into groups, keeping the learning rate the same for all params in a group - # this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders) + # split params into groups for mmdit. clip_l, clip_g, t5xxl are in each group grouped_params = [] - param_group = [] - param_group_lr = -1 - for group in params_to_optimize: - lr = group["lr"] - for p in group["params"]: - # if the learning rate is different for different params, start a new group - if lr != param_group_lr: - if param_group: - grouped_params.append({"params": param_group, "lr": param_group_lr}) - param_group = [] - param_group_lr = lr - - param_group.append(p) - - # if the group has enough parameters, start a new group - if len(param_group) == params_per_group: - grouped_params.append({"params": param_group, "lr": param_group_lr}) - param_group = [] - param_group_lr = -1 - - if param_group: - grouped_params.append({"params": param_group, "lr": param_group_lr}) + param_group = {} + group = params_to_optimize[0] + named_parameters = list(mmdit.named_parameters()) + assert len(named_parameters) == len(group["params"]), "number of parameters does not match" + for p, np in zip(group["params"], named_parameters): + # determine target layer and block index for each parameter + block_type = "other" # joint or other + if np[0].startswith("joint_blocks"): + block_idx = int(np[0].split(".")[1]) + block_type = "joint" + else: + block_idx = -1 + + param_group_key = (block_type, block_idx) + if param_group_key not in param_group: + param_group[param_group_key] = [] + param_group[param_group_key].append(p) + + block_types_and_indices = [] + for param_group_key, param_group in param_group.items(): + block_types_and_indices.append(param_group_key) + grouped_params.append({"params": param_group, "lr": args.learning_rate}) + + num_params = 0 + for p in param_group: + num_params += p.numel() + accelerator.print(f"block {param_group_key}: {num_params} parameters") + + grouped_params.extend(params_to_optimize[1:]) # add clip_l, clip_g, t5xxl if they are trained # prepare optimizers for each group optimizers = [] @@ -463,10 +477,15 @@ def train(args): optimizers.append(optimizer) optimizer = optimizers[0] # avoid error in the following code - logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups") + logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers") + if train_util.is_schedulefree_optimizer(optimizers[0], args): + raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers") + optimizer_train_fn = lambda: None # dummy function + optimizer_eval_fn = lambda: None # dummy function else: _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) # prepare dataloader # strategies are set here because they cannot be referenced in another process. Copy them with the dataset @@ -497,7 +516,7 @@ def train(args): train_dataset_group.set_max_train_steps(args.max_train_steps) # lr schedulerを用意する - if args.fused_optimizer_groups: + if args.blockwise_fused_optimizers: # prepare lr schedulers for each optimizer lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] lr_scheduler = lr_schedulers[0] # avoid error in the following code @@ -511,18 +530,22 @@ def train(args): ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" accelerator.print("enable full fp16 training.") mmdit.to(weight_dtype) - clip_l.to(weight_dtype) - clip_g.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + if clip_g is not None: + clip_g.to(weight_dtype) if t5xxl is not None: - t5xxl.to(weight_dtype) # TODO check works with fp16 or not + t5xxl.to(weight_dtype) elif args.full_bf16: assert ( args.mixed_precision == "bf16" ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" accelerator.print("enable full bf16 training.") mmdit.to(weight_dtype) - clip_l.to(weight_dtype) - clip_g.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + if clip_g is not None: + clip_g.to(weight_dtype) if t5xxl is not None: t5xxl.to(weight_dtype) @@ -533,14 +556,7 @@ def train(args): # clip_l.text_model.final_layer_norm.requires_grad_(False) # TextEncoderの出力をキャッシュするときには、すでに出力を取得済みなのでCPUへ移動する - if args.cache_text_encoder_outputs: - # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 - clip_l.to("cpu", dtype=torch.float32) - clip_g.to("cpu", dtype=torch.float32) - if t5xxl is not None: - t5xxl.to("cpu", dtype=torch.float32) - clean_memory_on_device(accelerator.device) - else: + if not args.cache_text_encoder_outputs: # make sure Text Encoders are on GPU # TODO support CPU for text encoders clip_l.to(accelerator.device) @@ -548,18 +564,11 @@ def train(args): if t5xxl is not None: t5xxl.to(accelerator.device) - # TODO cache sample prompt's embeddings to free text encoder's memory - if args.cache_text_encoder_outputs: - if not args.save_t5xxl: - t5xxl = None # free memory clean_memory_on_device(accelerator.device) if args.deepspeed: ds_model = deepspeed_utils.prepare_deepspeed_model( - args, - mmdit=mmdit, - clip_l=clip_l if train_clip_l else None, - clip_g=clip_g if train_clip_g else None, + args, mmdit=mmdit, clip_l=clip_l if train_clip else None, clip_g=clip_g if train_clip else None ) # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( @@ -571,10 +580,11 @@ def train(args): # acceleratorがなんかよろしくやってくれるらしい if train_mmdit: mmdit = accelerator.prepare(mmdit) - if train_clip_l: + if train_clip: clip_l = accelerator.prepare(clip_l) - if train_clip_g: clip_g = accelerator.prepare(clip_g) + if train_t5xxl: + t5xxl = accelerator.prepare(t5xxl) optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする @@ -586,24 +596,110 @@ def train(args): # resumeする train_util.resume_from_local_or_hf_if_specified(accelerator, args) + # memory efficient block swapping + + def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, blocks, device): + def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda, dvc): + # print(f"Backward: Move block {bidx_to_cpu} to CPU") + block_to_cpu = block_to_cpu.to("cpu", non_blocking=True) + torch.cuda.empty_cache() + + block_to_cuda = block_to_cuda.to(dvc, non_blocking=True) + torch.cuda.synchronize() + return bidx_to_cpu, bidx_to_cuda + + block_to_cpu = blocks[block_idx_to_cpu] + block_to_cuda = blocks[block_idx_to_cuda] + + futures[block_idx_to_cuda] = thread_pool.submit( + move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda, device + ) + + def wait_blocks_move(block_idx, futures): + if block_idx not in futures: + return + future = futures.pop(block_idx) + future.result() + if args.fused_backward_pass: # use fused optimizer for backward pass: other optimizers will be supported in the future import library.adafactor_fused library.adafactor_fused.patch_adafactor_fused(optimizer) - for param_group in optimizer.param_groups: - for parameter in param_group["params"]: + + blocks_to_swap = args.blocks_to_swap + num_blocks = len(accelerator.unwrap_model(mmdit).joint_blocks) + handled_block_indices = set() + + n = 1 # only asynchronous purpose, no need to increase this number + # n = 2 + # n = max(1, os.cpu_count() // 2) + thread_pool = ThreadPoolExecutor(max_workers=n) + futures = {} + + for param_group, param_name_group in zip(optimizer.param_groups, param_names): + for parameter, param_name in zip(param_group["params"], param_name_group): if parameter.requires_grad: + grad_hook = None + + if blocks_to_swap: + is_block = param_name.startswith("double_blocks") + if is_block: + block_idx = int(param_name.split(".")[1]) + if block_idx not in handled_block_indices: + # swap following (already backpropagated) block + handled_block_indices.add(block_idx) + + # if n blocks were already backpropagated + num_blocks_propagated = num_blocks - block_idx - 1 + swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap + waiting = block_idx > 0 and block_idx <= blocks_to_swap + if swapping or waiting: + block_idx_to_cpu = num_blocks - num_blocks_propagated + block_idx_to_cuda = blocks_to_swap - num_blocks_propagated + block_idx_to_wait = block_idx - 1 + + # create swap hook + def create_swap_grad_hook( + bidx_to_cpu, bidx_to_cuda, bidx_to_wait, bidx: int, swpng: bool, wtng: bool + ): + def __grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + if swpng: + submit_move_blocks( + futures, + thread_pool, + bidx_to_cpu, + bidx_to_cuda, + mmdit.joint_blocks, + accelerator.device, + ) + if wtng: + wait_blocks_move(bidx_to_wait, futures) + + return __grad_hook + + grad_hook = create_swap_grad_hook( + block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, block_idx, swapping, waiting + ) + + if grad_hook is None: + + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None - def __grad_hook(tensor: torch.Tensor, param_group=param_group): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None + grad_hook = __grad_hook - parameter.register_post_accumulate_grad_hook(__grad_hook) + parameter.register_post_accumulate_grad_hook(grad_hook) - elif args.fused_optimizer_groups: + elif args.blockwise_fused_optimizers: # prepare for additional optimizers and lr schedulers for i in range(1, len(optimizers)): optimizers[i] = accelerator.prepare(optimizers[i]) @@ -618,22 +714,59 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): num_parameters_per_group = [0] * len(optimizers) parameter_optimizer_map = {} + blocks_to_swap = args.blocks_to_swap + num_blocks = len(accelerator.unwrap_model(mmdit).joint_blocks) + + n = 1 # only asynchronous purpose, no need to increase this number + # n = max(1, os.cpu_count() // 2) + thread_pool = ThreadPoolExecutor(max_workers=n) + futures = {} + for opt_idx, optimizer in enumerate(optimizers): for param_group in optimizer.param_groups: for parameter in param_group["params"]: if parameter.requires_grad: - - def optimizer_hook(parameter: torch.Tensor): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(parameter, args.max_grad_norm) - - i = parameter_optimizer_map[parameter] - optimizer_hooked_count[i] += 1 - if optimizer_hooked_count[i] == num_parameters_per_group[i]: - optimizers[i].step() - optimizers[i].zero_grad(set_to_none=True) - - parameter.register_post_accumulate_grad_hook(optimizer_hook) + block_type, block_idx = block_types_and_indices[opt_idx] + + def create_optimizer_hook(btype, bidx): + def optimizer_hook(parameter: torch.Tensor): + # print(f"optimizer_hook: {btype}, {bidx}") + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + # swap blocks if necessary + if blocks_to_swap and btype == "joint": + num_blocks_propagated = num_blocks - bidx + + swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap + waiting = bidx > 0 and bidx <= blocks_to_swap + + if swapping: + block_idx_to_cpu = num_blocks - num_blocks_propagated + block_idx_to_cuda = blocks_to_swap - num_blocks_propagated + # print(f"Backward: Swap blocks {block_idx_to_cpu} and {block_idx_to_cuda}") + submit_move_blocks( + futures, + thread_pool, + block_idx_to_cpu, + block_idx_to_cuda, + mmdit.joint_blocks, + accelerator.device, + ) + + if waiting: + block_idx_to_wait = bidx - 1 + wait_blocks_move(block_idx_to_wait, futures) + + return optimizer_hook + + parameter.register_post_accumulate_grad_hook(create_optimizer_hook(block_type, block_idx)) parameter_optimizer_map[parameter] = opt_idx num_parameters_per_group[opt_idx] += 1 @@ -661,17 +794,9 @@ def optimizer_hook(parameter: torch.Tensor): progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 - # noise_scheduler = DDPMScheduler( - # beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False - # ) - noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) noise_scheduler_copy = copy.deepcopy(noise_scheduler) - # prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) - # if args.zero_terminal_snr: - # custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) - if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: @@ -685,60 +810,13 @@ def optimizer_hook(parameter: torch.Tensor): ) # For --sample_at_first + optimizer_eval_fn() sd3_train_utils.sample_images(accelerator, args, 0, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs) + optimizer_train_fn() if len(accelerator.trackers) > 0: # log empty object to commit the sample images to wandb accelerator.log({}, step=0) - # following function will be moved to sd3_train_utils - - def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): - sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) - schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) - timesteps = timesteps.to(accelerator.device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < n_dim: - sigma = sigma.unsqueeze(-1) - return sigma - - def compute_density_for_timestep_sampling( - weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None - ): - """Compute the density for sampling the timesteps when doing SD3 training. - - Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. - - SD3 paper reference: https://arxiv.org/abs/2403.03206v1. - """ - if weighting_scheme == "logit_normal": - # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). - u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") - u = torch.nn.functional.sigmoid(u) - elif weighting_scheme == "mode": - u = torch.rand(size=(batch_size,), device="cpu") - u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) - else: - u = torch.rand(size=(batch_size,), device="cpu") - return u - - def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): - """Computes loss weighting scheme for SD3 training. - - Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. - - SD3 paper reference: https://arxiv.org/abs/2403.03206v1. - """ - if weighting_scheme == "sigma_sqrt": - weighting = (sigmas**-2.0).float() - elif weighting_scheme == "cosmap": - bot = 1 - 2 * sigmas + 2 * sigmas**2 - weighting = 2 / (math.pi * bot) - else: - weighting = torch.ones_like(sigmas) - return weighting - loss_recorder = train_util.LossRecorder() epoch = 0 # avoid error when max_train_steps is 0 for epoch in range(num_train_epochs): @@ -751,16 +829,16 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): for step, batch in enumerate(train_dataloader): current_step.value = global_step - if args.fused_optimizer_groups: + if args.blockwise_fused_optimizers: optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step with accelerator.accumulate(*training_models): if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) else: with torch.no_grad(): # encode images to latents. images are [-1, 1] - latents = vae.encode(batch["images"].to(vae_dtype)).to(weight_dtype) + latents = vae.encode(batch["images"]) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): @@ -772,7 +850,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: - lg_out, t5_out, lg_pooled = text_encoder_outputs_list + lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_outputs_list if args.use_t5xxl_cache_only: lg_out = None lg_pooled = None @@ -781,7 +859,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): t5_out = None lg_pooled = None - if lg_out is None or (train_clip_l or train_clip_g): + if lg_out is None: # not cached or training, so get from text encoders input_ids_clip_l, input_ids_clip_g, _, l_attn_mask, g_attn_mask, _ = batch["input_ids_list"] with torch.set_grad_enabled(args.train_text_encoder): @@ -811,21 +889,10 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): noise = torch.randn_like(latents) bsz = latents.shape[0] - # Sample a random timestep for each image - # for weighting schemes where we sample timesteps non-uniformly - u = compute_density_for_timestep_sampling( - weighting_scheme=args.weighting_scheme, - batch_size=bsz, - logit_mean=args.logit_mean, - logit_std=args.logit_std, - mode_scale=args.mode_scale, + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = sd3_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype ) - indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() - timesteps = noise_scheduler_copy.timesteps[indices].to(device=accelerator.device) - - # Add noise according to flow matching. - sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype) - noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents # debug: NaN check for all inputs if torch.any(torch.isnan(noisy_model_input)): @@ -840,6 +907,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): # call model with accelerator.autocast(): + # TODO support attention mask model_pred = mmdit(noisy_model_input, timesteps, context=context, y=lg_pooled) # Follow: Section 5 of https://arxiv.org/abs/2206.00364. @@ -848,21 +916,34 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss - weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + weighting = sd3_train_utils.compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) # flow matching loss target = latents - # Compute regular loss. TODO simplify this - loss = torch.mean( - (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), - 1, + # # Compute regular loss. TODO simplify this + # loss = torch.mean( + # (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + # 1, + # ) + # calculate loss + loss = train_util.conditional_loss( + model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None ) + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + if weighting is not None: + loss = loss * weighting + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights loss = loss.mean() accelerator.backward(loss) - if not (args.fused_backward_pass or args.fused_optimizer_groups): + if not (args.fused_backward_pass or args.blockwise_fused_optimizers): if accelerator.sync_gradients and args.max_grad_norm != 0.0: params_to_clip = [] for m in training_models: @@ -875,7 +956,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): else: # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook lr_scheduler.step() - if args.fused_optimizer_groups: + if args.blockwise_fused_optimizers: for i in range(1, len(optimizers)): lr_schedulers[i].step() @@ -884,6 +965,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): progress_bar.update(1) global_step += 1 + optimizer_eval_fn() sd3_train_utils.sample_images( accelerator, args, None, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs ) @@ -900,12 +982,13 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): epoch, num_train_epochs, global_step, - accelerator.unwrap_model(clip_l) if args.save_clip else None, - accelerator.unwrap_model(clip_g) if args.save_clip else None, - accelerator.unwrap_model(t5xxl) if args.save_t5xxl else None, - accelerator.unwrap_model(mmdit), + accelerator.unwrap_model(clip_l) if train_clip else None, + accelerator.unwrap_model(clip_g) if train_clip else None, + accelerator.unwrap_model(t5xxl) if train_t5xxl else None, + accelerator.unwrap_model(mmdit) if train_mmdit else None, vae, ) + optimizer_train_fn() current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず if len(accelerator.trackers) > 0: @@ -928,6 +1011,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): accelerator.wait_for_everyone() + optimizer_eval_fn() if args.save_every_n_epochs is not None: if accelerator.is_main_process: sd3_train_utils.save_sd3_model_on_epoch_end_or_stepwise( @@ -938,10 +1022,10 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): epoch, num_train_epochs, global_step, - accelerator.unwrap_model(clip_l) if args.save_clip else None, - accelerator.unwrap_model(clip_g) if args.save_clip else None, - accelerator.unwrap_model(t5xxl) if args.save_t5xxl else None, - accelerator.unwrap_model(mmdit), + accelerator.unwrap_model(clip_l) if train_clip else None, + accelerator.unwrap_model(clip_g) if train_clip else None, + accelerator.unwrap_model(t5xxl) if train_t5xxl else None, + accelerator.unwrap_model(mmdit) if train_mmdit else None, vae, ) @@ -958,6 +1042,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): t5xxl = accelerator.unwrap_model(t5xxl) accelerator.end_training() + optimizer_eval_fn() if args.save_state or args.save_state_on_train_end: train_util.save_state_on_train_end(args, accelerator) @@ -970,10 +1055,10 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): save_dtype, epoch, global_step, - clip_l if args.save_clip else None, - clip_g if args.save_clip else None, - t5xxl if args.save_t5xxl else None, - mmdit, + accelerator.unwrap_model(clip_l) if train_clip else None, + accelerator.unwrap_model(clip_g) if train_clip else None, + accelerator.unwrap_model(t5xxl) if train_t5xxl else None, + accelerator.unwrap_model(mmdit) if train_mmdit else None, vae, ) logger.info("model saved.") @@ -991,13 +1076,13 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) - custom_train_functions.add_custom_train_arguments(parser) + add_custom_train_arguments(parser) sd3_train_utils.add_sd3_training_arguments(parser) parser.add_argument( "--train_text_encoder", action="store_true", help="train text encoder (CLIP-L and G) / text encoderも学習する" ) - # parser.add_argument("--train_t5xxl", action="store_true", help="train T5-XXL / T5-XXLも学習する") + parser.add_argument("--train_t5xxl", action="store_true", help="train T5-XXL / T5-XXLも学習する") parser.add_argument( "--use_t5xxl_cache_only", action="store_true", help="cache T5-XXL outputs only / T5-XXLの出力のみキャッシュする" ) @@ -1018,19 +1103,24 @@ def setup_parser() -> argparse.ArgumentParser: help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する", ) - # TE training is disabled temporarily - # parser.add_argument( - # "--learning_rate_te1", - # type=float, - # default=None, - # help="learning rate for text encoder 1 (ViT-L) / text encoder 1 (ViT-L)の学習率", - # ) - # parser.add_argument( - # "--learning_rate_te2", - # type=float, - # default=None, - # help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率", - # ) + parser.add_argument( + "--learning_rate_te1", + type=float, + default=None, + help="learning rate for text encoder 1 (ViT-L) / text encoder 1 (ViT-L)の学習率", + ) + parser.add_argument( + "--learning_rate_te2", + type=float, + default=None, + help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率", + ) + parser.add_argument( + "--learning_rate_te3", + type=float, + default=None, + help="learning rate for text encoder 3 (T5-XXL) / text encoder 3 (T5-XXL)の学習率", + ) # parser.add_argument( # "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する" @@ -1047,22 +1137,22 @@ def setup_parser() -> argparse.ArgumentParser: # help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / " # + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値", # ) + parser.add_argument( + "--blockwise_fused_optimizers", + action="store_true", + help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする", + ) parser.add_argument( "--fused_optimizer_groups", type=int, default=None, - help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", + help="[DOES NOT WORK] number of optimizer groups for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizerグループ数", ) parser.add_argument( "--skip_latents_validity_check", action="store_true", help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", ) - parser.add_argument( - "--skip_cache_check", - action="store_true", - help="skip cache (latents and text encoder outputs) check / キャッシュ(latentsとtext encoder outputs)のチェックをスキップする", - ) parser.add_argument( "--num_last_block_to_freeze", type=int, From e3c43bda49ec8c5a5cb784e29f8610f1ebff0a66 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 24 Oct 2024 20:35:47 +0900 Subject: [PATCH 190/348] reduce memory usage in sample image generation --- library/sd3_train_utils.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 9282482d9..af8ecf2c9 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -402,9 +402,6 @@ def sample_images( except Exception: pass - org_vae_device = vae.device # will be on cpu - vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device - if distributed_state.num_processes <= 1: # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. with torch.no_grad(): @@ -450,8 +447,6 @@ def sample_images( if cuda_rng_state is not None: torch.cuda.set_rng_state(cuda_rng_state) - vae.to(org_vae_device) - clean_memory_on_device(accelerator.device) @@ -531,12 +526,19 @@ def sample_image_inference( neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) # sample image - latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, mmdit.dtype, accelerator.device) - latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype)) + clean_memory_on_device(accelerator.device) + with accelerator.autocast(): + latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, mmdit.dtype, accelerator.device) # latent to image - with torch.no_grad(): - image = vae.decode(latents) + clean_memory_on_device(accelerator.device) + org_vae_device = vae.device # will be on cpu + vae.to(accelerator.device) + latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype)) + image = vae.decode(latents) + vae.to(org_vae_device) + clean_memory_on_device(accelerator.device) + image = image.float() image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2) From 0286114bd208717510b537d9acd940db48a158f3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 24 Oct 2024 21:28:42 +0900 Subject: [PATCH 191/348] support SD3.5L, fix final saving --- sd3_train.py | 35 ++++++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/sd3_train.py b/sd3_train.py index 6336b4cf9..d4ab13a34 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -321,7 +321,7 @@ def train(args): accelerator.wait_for_everyone() # now we can delete Text Encoders to free memory - if args.use_t5xxl_cache_only: + if not args.use_t5xxl_cache_only: clip_l = None clip_g = None t5xxl = None @@ -330,6 +330,7 @@ def train(args): # load VAE for caching latents if sd3_state_dict is None: + logger.info(f"load state dict for MMDiT and VAE from {args.pretrained_model_name_or_path}") sd3_state_dict = utils.load_safetensors( args.pretrained_model_name_or_path, "cpu", args.disable_mmap_load_safetensors, model_dtype ) @@ -360,11 +361,6 @@ def train(args): # attn_mode == "torch" # ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" - # SD3 state dict may contain multiple models, so we need to load it and extract one by one. annoying. - logger.info(f"Loading SD3 models from {args.pretrained_model_name_or_path}") - device_to_load = accelerator.device if args.lowram else "cpu" - sd3_state_dict = utils.load_safetensors(args.pretrained_model_name_or_path, device_to_load, args.disable_mmap_load_safetensors) - if args.gradient_checkpointing: mmdit.enable_gradient_checkpointing() @@ -555,7 +551,7 @@ def train(args): # clip_l.text_model.encoder.layers[-1].requires_grad_(False) # clip_l.text_model.final_layer_norm.requires_grad_(False) - # TextEncoderの出力をキャッシュするときには、すでに出力を取得済みなのでCPUへ移動する + # move Text Encoders to GPU if not caching outputs if not args.cache_text_encoder_outputs: # make sure Text Encoders are on GPU # TODO support CPU for text encoders @@ -817,6 +813,13 @@ def optimizer_hook(parameter: torch.Tensor): # log empty object to commit the sample images to wandb accelerator.log({}, step=0) + # show model device and dtype + logger.info(f"mmdit device: {mmdit.device}, dtype: {mmdit.dtype}" if mmdit else "mmdit is None") + logger.info(f"clip_l device: {clip_l.device}, dtype: {clip_l.dtype}" if clip_l else "clip_l is None") + logger.info(f"clip_g device: {clip_g.device}, dtype: {clip_g.dtype}" if clip_g else "clip_g is None") + logger.info(f"t5xxl device: {t5xxl.device}, dtype: {t5xxl.dtype}" if t5xxl else "t5xxl is None") + logger.info(f"vae device: {vae.device}, dtype: {vae.dtype}" if vae is not None else "vae is None") + loss_recorder = train_util.LossRecorder() epoch = 0 # avoid error when max_train_steps is 0 for epoch in range(num_train_epochs): @@ -1055,10 +1058,10 @@ def optimizer_hook(parameter: torch.Tensor): save_dtype, epoch, global_step, - accelerator.unwrap_model(clip_l) if train_clip else None, - accelerator.unwrap_model(clip_g) if train_clip else None, - accelerator.unwrap_model(t5xxl) if train_t5xxl else None, - accelerator.unwrap_model(mmdit) if train_mmdit else None, + clip_l if train_clip else None, + clip_g if train_clip else None, + t5xxl if train_t5xxl else None, + mmdit if train_mmdit else None, vae, ) logger.info("model saved.") @@ -1153,6 +1156,16 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", ) + parser.add_argument( + "--blocks_to_swap", + type=int, + default=None, + help="[EXPERIMENTAL] " + "Sets the number of blocks (~640MB) to swap during the forward and backward passes." + "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." + " / 順伝播および逆伝播中にスワップするブロック(約640MB)の数を設定します。" + "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", + ) parser.add_argument( "--num_last_block_to_freeze", type=int, From f8c5146d71b1c40b69d80b7ea18c21bbb66b84f3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 24 Oct 2024 22:02:05 +0900 Subject: [PATCH 192/348] support block swap with fused_optimizer_pass --- library/sd3_models.py | 79 +++++++++++++++++++++++++++++++++++++++++-- sd3_train.py | 19 +++++++++-- 2 files changed, 94 insertions(+), 4 deletions(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index c81aa4794..e5c5887a9 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -4,6 +4,7 @@ # and some module/classes are contributed from KohakuBlueleaf. Thanks for the contribution! from ast import Tuple +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from functools import partial import math @@ -17,6 +18,8 @@ from torch.utils.checkpoint import checkpoint from transformers import CLIPTokenizer, T5TokenizerFast +from library.device_utils import clean_memory_on_device + from .utils import setup_logging setup_logging() @@ -848,6 +851,35 @@ def cropped_pos_embed(self, h, w, device=None): spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) return spatial_pos_embed + def enable_block_swap(self, num_blocks: int): + self.blocks_to_swap = num_blocks + + n = 1 # async block swap. 1 is enough + self.thread_pool = ThreadPoolExecutor(max_workers=n) + + def move_to_device_except_swap_blocks(self, device: torch.device): + # assume model is on cpu + if self.blocks_to_swap: + save_blocks = self.joint_blocks + self.joint_blocks = None + + self.to(device) + + if self.blocks_to_swap: + self.joint_blocks = save_blocks + + def prepare_block_swap_before_forward(self): + # make: first n blocks are on cuda, and last n blocks are on cpu + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + # raise ValueError("Block swap is not enabled.") + return + num_blocks = len(self.joint_blocks) + for i in range(num_blocks - self.blocks_to_swap): + self.joint_blocks[i].to(self.device) + for i in range(num_blocks - self.blocks_to_swap, num_blocks): + self.joint_blocks[i].to("cpu") + clean_memory_on_device(self.device) + def forward( self, x: torch.Tensor, @@ -881,8 +913,51 @@ def forward( 1, ) - for block in self.joint_blocks: - context, x = block(context, x, c) + if not self.blocks_to_swap: + for block in self.joint_blocks: + context, x = block(context, x, c) + else: + futures = {} + + def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda): + def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): + # print(f"Moving {bidx_to_cpu} to cpu.") + block_to_cpu.to("cpu", non_blocking=True) + torch.cuda.empty_cache() + + # print(f"Moving {bidx_to_cuda} to cuda.") + block_to_cuda.to(self.device, non_blocking=True) + + torch.cuda.synchronize() + # print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.") + return block_idx_to_cpu, block_idx_to_cuda + + block_to_cpu = self.joint_blocks[block_idx_to_cpu] + block_to_cuda = self.joint_blocks[block_idx_to_cuda] + # print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.") + return self.thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda) + + def wait_for_blocks_move(block_idx, ftrs): + if block_idx not in ftrs: + return + # print(f"Waiting for move blocks: {block_idx}") + # start_time = time.perf_counter() + ftr = ftrs.pop(block_idx) + ftr.result() + # torch.cuda.synchronize() + # print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds") + + for block_idx, block in enumerate(self.joint_blocks): + wait_for_blocks_move(block_idx, futures) + + context, x = block(context, x, c) + + if block_idx < self.blocks_to_swap: + block_idx_to_cpu = block_idx + block_idx_to_cuda = len(self.joint_blocks) - self.blocks_to_swap + block_idx + future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda) + futures[block_idx_to_cuda] = future + x = self.final_layer(x, c, H, W) # Our final layer combined UnPatchify return x[:, :, :H, :W] diff --git a/sd3_train.py b/sd3_train.py index d4ab13a34..5e2efa6f8 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -369,6 +369,14 @@ def train(args): if not train_mmdit: mmdit.to(accelerator.device, dtype=weight_dtype) # because of mmdit will not be prepared + # block swap + is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + # This idea is based on 2kpr's great work. Thank you! + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + mmdit.enable_block_swap(args.blocks_to_swap) + if not cache_latents: # move to accelerator device vae.requires_grad_(False) @@ -575,7 +583,9 @@ def train(args): else: # acceleratorがなんかよろしくやってくれるらしい if train_mmdit: - mmdit = accelerator.prepare(mmdit) + mmdit = accelerator.prepare(mmdit, device_placement=[not is_swapping_blocks]) + if is_swapping_blocks: + accelerator.unwrap_model(mmdit).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage if train_clip: clip_l = accelerator.prepare(clip_l) clip_g = accelerator.prepare(clip_g) @@ -600,8 +610,10 @@ def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda, dvc): block_to_cpu = block_to_cpu.to("cpu", non_blocking=True) torch.cuda.empty_cache() + # print(f"Backward: Move block {bidx_to_cuda} to CUDA") block_to_cuda = block_to_cuda.to(dvc, non_blocking=True) torch.cuda.synchronize() + # print(f"Backward: Done moving blocks {bidx_to_cpu} and {bidx_to_cuda}") return bidx_to_cpu, bidx_to_cuda block_to_cpu = blocks[block_idx_to_cpu] @@ -639,7 +651,7 @@ def wait_blocks_move(block_idx, futures): grad_hook = None if blocks_to_swap: - is_block = param_name.startswith("double_blocks") + is_block = param_name.startswith("joint_blocks") if is_block: block_idx = int(param_name.split(".")[1]) if block_idx not in handled_block_indices: @@ -805,6 +817,9 @@ def optimizer_hook(parameter: torch.Tensor): init_kwargs=init_kwargs, ) + if is_swapping_blocks: + accelerator.unwrap_model(mmdit).prepare_block_swap_before_forward() + # For --sample_at_first optimizer_eval_fn() sd3_train_utils.sample_images(accelerator, args, 0, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs) From d2c549d7b2a9bb3e70b5af8539fd744b474a9607 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Fri, 25 Oct 2024 21:58:31 +0900 Subject: [PATCH 193/348] support SD3 LoRA --- library/sd3_models.py | 3 + library/sd3_train_utils.py | 113 +++-- library/sd3_utils.py | 2 +- networks/lora_sd3.py | 826 +++++++++++++++++++++++++++++++++++++ sd3_train.py | 30 +- sd3_train_network.py | 427 +++++++++++++++++++ train_network.py | 2 + 7 files changed, 1335 insertions(+), 68 deletions(-) create mode 100644 networks/lora_sd3.py create mode 100644 sd3_train_network.py diff --git a/library/sd3_models.py b/library/sd3_models.py index e5c5887a9..5d09f74e8 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -761,6 +761,9 @@ def __init__( self.final_layer = UnPatch(self.hidden_size, patch_size, self.out_channels) # self.initialize_weights() + self.blocks_to_swap = None + self.thread_pool: Optional[ThreadPoolExecutor] = None + @property def model_type(self): return self._model_type diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index af8ecf2c9..e3c649f73 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -198,6 +198,23 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): help="[DOES NOT WORK] not supported yet. T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtype(mixed precisionから)を使用", ) + parser.add_argument( + "--t5xxl_max_token_length", + type=int, + default=256, + help="maximum token length for T5-XXL. 256 is the default value / T5-XXLの最大トークン長。デフォルトは256", + ) + parser.add_argument( + "--apply_lg_attn_mask", + action="store_true", + help="apply attention mask (zero embs) to CLIP-L and G / CLIP-LとGにアテンションマスク(ゼロ埋め)を適用する", + ) + parser.add_argument( + "--apply_t5_attn_mask", + action="store_true", + help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する", + ) + # copy from Diffusers parser.add_argument( "--weighting_scheme", @@ -317,36 +334,36 @@ def do_sample( x = noise_scaled.to(device).to(dtype) # print(x.shape) - with torch.no_grad(): - for i in tqdm(range(len(sigmas) - 1)): - sigma_hat = sigmas[i] + # with torch.no_grad(): + for i in tqdm(range(len(sigmas) - 1)): + sigma_hat = sigmas[i] - timestep = model_sampling.timestep(sigma_hat).float() - timestep = torch.FloatTensor([timestep, timestep]).to(device) + timestep = model_sampling.timestep(sigma_hat).float() + timestep = torch.FloatTensor([timestep, timestep]).to(device) - x_c_nc = torch.cat([x, x], dim=0) - # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape) + x_c_nc = torch.cat([x, x], dim=0) + # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape) - model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) - model_output = model_output.float() - batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) + model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) + model_output = model_output.float() + batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) - pos_out, neg_out = batched.chunk(2) - denoised = neg_out + (pos_out - neg_out) * guidance_scale - # print(denoised.shape) + pos_out, neg_out = batched.chunk(2) + denoised = neg_out + (pos_out - neg_out) * guidance_scale + # print(denoised.shape) - # d = to_d(x, sigma_hat, denoised) - dims_to_append = x.ndim - sigma_hat.ndim - sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append] - # print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape) - """Converts a denoiser output to a Karras ODE derivative.""" - d = (x - denoised) / sigma_hat_dims + # d = to_d(x, sigma_hat, denoised) + dims_to_append = x.ndim - sigma_hat.ndim + sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append] + # print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape) + """Converts a denoiser output to a Karras ODE derivative.""" + d = (x - denoised) / sigma_hat_dims - dt = sigmas[i + 1] - sigma_hat + dt = sigmas[i + 1] - sigma_hat - # Euler method - x = x + d * dt - x = x.to(dtype) + # Euler method + x = x + d * dt + x = x.to(dtype) return x @@ -378,7 +395,7 @@ def sample_images( logger.info("") logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") - if not os.path.isfile(args.sample_prompts): + if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None: logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") return @@ -386,7 +403,7 @@ def sample_images( # unwrap unet and text_encoder(s) mmdit = accelerator.unwrap_model(mmdit) - text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] + text_encoders = None if text_encoders is None else [accelerator.unwrap_model(te) for te in text_encoders] # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) prompts = train_util.load_prompts(args.sample_prompts) @@ -404,7 +421,7 @@ def sample_images( if distributed_state.num_processes <= 1: # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. - with torch.no_grad(): + with torch.no_grad(), accelerator.autocast(): for prompt_dict in prompts: sample_image_inference( accelerator, @@ -506,29 +523,39 @@ def sample_image_inference( tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() - if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs: - te_outputs = sample_prompts_te_outputs[prompt] - else: - l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(prompt) - te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens]) - - lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = te_outputs + def encode_prompt(prpt): + text_encoder_conds = [] + if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs: + text_encoder_conds = sample_prompts_te_outputs[prpt] + print(f"Using cached text encoder outputs for prompt: {prpt}") + if text_encoders is not None: + print(f"Encoding prompt: {prpt}") + tokens_and_masks = tokenize_strategy.tokenize(prpt) + # strategy has apply_t5_attn_mask option + encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) + + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] + return text_encoder_conds + + lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(prompt) cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) # encode negative prompts - if sample_prompts_te_outputs and negative_prompt in sample_prompts_te_outputs: - neg_te_outputs = sample_prompts_te_outputs[negative_prompt] - else: - l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(negative_prompt) - neg_te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens]) - - lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = neg_te_outputs + lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(negative_prompt) neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) # sample image clean_memory_on_device(accelerator.device) - with accelerator.autocast(): - latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, mmdit.dtype, accelerator.device) + with accelerator.autocast(), torch.no_grad(): + # mmdit may be fp8, so we need weight_dtype here. vae is always in that dtype. + latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, vae.dtype, accelerator.device) # latent to image clean_memory_on_device(accelerator.device) @@ -538,7 +565,7 @@ def sample_image_inference( image = vae.decode(latents) vae.to(org_vae_device) clean_memory_on_device(accelerator.device) - + image = image.float() image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2) diff --git a/library/sd3_utils.py b/library/sd3_utils.py index 9ad995d81..71e50de36 100644 --- a/library/sd3_utils.py +++ b/library/sd3_utils.py @@ -91,7 +91,7 @@ def load_mmdit( mmdit = sd3_models.create_sd3_mmdit(params, attn_mode) logger.info("Loading state dict...") - info = sdxl_model_util._load_state_dict_on_device(mmdit, mmdit_sd, device, dtype) + info = mmdit.load_state_dict(mmdit_sd, strict=False, assign=True) logger.info(f"Loaded MMDiT: {info}") return mmdit diff --git a/networks/lora_sd3.py b/networks/lora_sd3.py new file mode 100644 index 000000000..cbabf8da0 --- /dev/null +++ b/networks/lora_sd3.py @@ -0,0 +1,826 @@ +# temporary minimum implementation of LoRA +# SD3 doesn't have Conv2d, so we ignore it +# TODO commonize with the original/SD3/FLUX implementation + +# LoRA network module +# reference: +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py + +import math +import os +from typing import Dict, List, Optional, Tuple, Type, Union +from transformers import CLIPTextModelWithProjection, T5EncoderModel +import numpy as np +import torch +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from networks.lora_flux import LoRAModule, LoRAInfModule +from library import sd3_models + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: sd3_models.SDVAE, + text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]], + mmdit, + neuron_dropout: Optional[float] = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + # extract dim/alpha for conv2d, and block dim + conv_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + if conv_dim is not None: + conv_dim = int(conv_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + + # attn dim, mlp dim: only for DoubleStreamBlock. SingleStreamBlock is not supported because of combined qkv + context_attn_dim = kwargs.get("context_attn_dim", None) + context_mlp_dim = kwargs.get("context_mlp_dim", None) + context_mod_dim = kwargs.get("context_mod_dim", None) + x_attn_dim = kwargs.get("x_attn_dim", None) + x_mlp_dim = kwargs.get("x_mlp_dim", None) + x_mod_dim = kwargs.get("x_mod_dim", None) + if context_attn_dim is not None: + context_attn_dim = int(context_attn_dim) + if context_mlp_dim is not None: + context_mlp_dim = int(context_mlp_dim) + if context_mod_dim is not None: + context_mod_dim = int(context_mod_dim) + if x_attn_dim is not None: + x_attn_dim = int(x_attn_dim) + if x_mlp_dim is not None: + x_mlp_dim = int(x_mlp_dim) + if x_mod_dim is not None: + x_mod_dim = int(x_mod_dim) + type_dims = [context_attn_dim, context_mlp_dim, context_mod_dim, x_attn_dim, x_mlp_dim, x_mod_dim] + if all([d is None for d in type_dims]): + type_dims = None + + # emb_dims [context_embedder, t_embedder, x_embedder, y_embedder, final_mod, final_linear] + emb_dims = kwargs.get("emb_dims", None) + if emb_dims is not None: + emb_dims = emb_dims.strip() + if emb_dims.startswith("[") and emb_dims.endswith("]"): + emb_dims = emb_dims[1:-1] + emb_dims = [int(d) for d in emb_dims.split(",")] # is it better to use ast.literal_eval? + assert len(emb_dims) == 6, f"invalid emb_dims: {emb_dims}, must be 6 dimensions (context, t, x, y, final_mod, final_linear)" + + # double/single train blocks + def parse_block_selection(selection: str, total_blocks: int) -> List[bool]: + """ + Parse a block selection string and return a list of booleans. + + Args: + selection (str): A string specifying which blocks to select. + total_blocks (int): The total number of blocks available. + + Returns: + List[bool]: A list of booleans indicating which blocks are selected. + """ + if selection == "all": + return [True] * total_blocks + if selection == "none" or selection == "": + return [False] * total_blocks + + selected = [False] * total_blocks + ranges = selection.split(",") + + for r in ranges: + if "-" in r: + start, end = map(str.strip, r.split("-")) + start = int(start) + end = int(end) + assert 0 <= start < total_blocks, f"invalid start index: {start}" + assert 0 <= end < total_blocks, f"invalid end index: {end}" + assert start <= end, f"invalid range: {start}-{end}" + for i in range(start, end + 1): + selected[i] = True + else: + index = int(r) + assert 0 <= index < total_blocks, f"invalid index: {index}" + selected[index] = True + + return selected + + train_block_indices = kwargs.get("train_block_indices", None) + if train_block_indices is not None: + train_block_indices = parse_block_selection(train_block_indices, 999) # 999 is a dummy number + + # rank/module dropout + rank_dropout = kwargs.get("rank_dropout", None) + if rank_dropout is not None: + rank_dropout = float(rank_dropout) + module_dropout = kwargs.get("module_dropout", None) + if module_dropout is not None: + module_dropout = float(module_dropout) + + # split qkv + split_qkv = kwargs.get("split_qkv", False) + if split_qkv is not None: + split_qkv = True if split_qkv == "True" else False + + # train T5XXL + train_t5xxl = kwargs.get("train_t5xxl", False) + if train_t5xxl is not None: + train_t5xxl = True if train_t5xxl == "True" else False + + # verbose + verbose = kwargs.get("verbose", False) + if verbose is not None: + verbose = True if verbose == "True" else False + + # すごく引数が多いな ( ^ω^)・・・ + network = LoRANetwork( + text_encoders, + mmdit, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + conv_lora_dim=conv_dim, + conv_alpha=conv_alpha, + split_qkv=split_qkv, + train_t5xxl=train_t5xxl, + type_dims=type_dims, + emb_dims=emb_dims, + train_block_indices=train_block_indices, + verbose=verbose, + ) + + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + + return network + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, ae, text_encoders, mmdit, weights_sd=None, for_inference=False, **kwargs): + # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # get dim/alpha mapping, and train t5xxl + modules_dim = {} + modules_alpha = {} + train_t5xxl = None + for key, value in weights_sd.items(): + if "." not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "lora_down" in key: + dim = value.size()[0] + modules_dim[lora_name] = dim + # logger.info(lora_name, value.size(), dim) + + if train_t5xxl is None or train_t5xxl is False: + train_t5xxl = "lora_te3" in lora_name + + if train_t5xxl is None: + train_t5xxl = False + + split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined + + module_class = LoRAInfModule if for_inference else LoRAModule + + network = LoRANetwork( + text_encoders, + mmdit, + multiplier=multiplier, + modules_dim=modules_dim, + modules_alpha=modules_alpha, + module_class=module_class, + split_qkv=split_qkv, + train_t5xxl=train_t5xxl, + ) + return network, weights_sd + + +class LoRANetwork(torch.nn.Module): + SD3_TARGET_REPLACE_MODULE = ["SingleDiTBlock"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP", "T5Attention", "T5DenseGatedActDense"] + LORA_PREFIX_SD3 = "lora_unet" # make ComfyUI compatible + LORA_PREFIX_TEXT_ENCODER_CLIP_L = "lora_te1" + LORA_PREFIX_TEXT_ENCODER_CLIP_G = "lora_te2" + LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3" # make ComfyUI compatible + + def __init__( + self, + text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]], + unet: sd3_models.MMDiT, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + module_class: Type[object] = LoRAModule, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, + split_qkv: bool = False, + train_t5xxl: bool = False, + type_dims: Optional[List[int]] = None, + emb_dims: Optional[List[int]] = None, + train_block_indices: Optional[List[bool]] = None, + verbose: Optional[bool] = False, + ) -> None: + super().__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + self.split_qkv = split_qkv + self.train_t5xxl = train_t5xxl + + self.type_dims = type_dims + self.emb_dims = emb_dims + self.train_block_indices = train_block_indices + + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + + if modules_dim is not None: + logger.info(f"create LoRA network from weights") + self.emb_dims = [0] * 6 # create emb_dims + # verbose = True + else: + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + # if self.conv_lora_dim is not None: + # logger.info( + # f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + # ) + + qkv_dim = 0 + if self.split_qkv: + logger.info(f"split qkv for LoRA") + qkv_dim = unet.joint_blocks[0].context_block.attn.qkv.weight.size(0) + if train_t5xxl: + logger.info(f"train T5XXL as well") + + # create module instances + def create_modules( + is_mmdit: bool, + text_encoder_idx: Optional[int], + root_module: torch.nn.Module, + target_replace_modules: List[str], + filter: Optional[str] = None, + default_dim: Optional[int] = None, + ) -> List[LoRAModule]: + prefix = ( + self.LORA_PREFIX_SD3 + if is_mmdit + else [self.LORA_PREFIX_TEXT_ENCODER_CLIP_L, self.LORA_PREFIX_TEXT_ENCODER_CLIP_G, self.LORA_PREFIX_TEXT_ENCODER_T5][ + text_encoder_idx + ] + ) + + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if target_replace_modules is None or module.__class__.__name__ in target_replace_modules: + if target_replace_modules is None: # dirty hack for all modules + module = root_module # search all modules + + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if is_linear or is_conv2d: + lora_name = prefix + "." + (name + "." if name else "") + child_name + lora_name = lora_name.replace(".", "_") + + if filter is not None and not filter in lora_name: + continue + + dim = None + alpha = None + + if modules_dim is not None: + # モジュール指定あり + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + else: + # 通常、すべて対象とする + if is_linear or is_conv2d_1x1: + dim = default_dim if default_dim is not None else self.lora_dim + alpha = self.alpha + + if is_mmdit and type_dims is not None: + # type_dims = [context_attn_dim, context_mlp_dim, context_mod_dim, x_attn_dim, x_mlp_dim, x_mod_dim] + identifier = [ + ("context_block", "attn"), + ("context_block", "mlp"), + ("context_block", "adaLN_modulation"), + ("x_block", "attn"), + ("x_block", "mlp"), + ("x_block", "adaLN_modulation"), + ] + for i, d in enumerate(type_dims): + if d is not None and all([id in lora_name for id in identifier[i]]): + dim = d # may be 0 for skip + break + + if is_mmdit and dim and self.train_block_indices is not None and "joint_blocks" in lora_name: + # "lora_unet_joint_blocks_0_x_block_attn_proj..." + block_index = int(lora_name.split("_")[4]) # bit dirty + if self.train_block_indices is not None and not self.train_block_indices[block_index]: + dim = 0 + + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha + + if dim is None or dim == 0: + # skipした情報を出力 + if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None): + skipped.append(lora_name) + continue + + # qkv split + split_dims = None + if is_mmdit and split_qkv: + if "joint_blocks" in lora_name and "qkv" in lora_name: + split_dims = [qkv_dim // 3] * 3 + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + split_dims=split_dims, + ) + loras.append(lora) + + if target_replace_modules is None: + break # all modules are searched + return loras, skipped + + # create LoRA for text encoder + # 毎回すべてのモジュールを作るのは無駄なので要検討 + self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = [] + skipped_te = [] + for i, text_encoder in enumerate(text_encoders): + index = i + if not train_t5xxl and index >= 2: # 0: CLIP-L, 1: CLIP-G, 2: T5XXL, so we skip T5XXL if train_t5xxl is False + break + + logger.info(f"create LoRA for Text Encoder {index+1}:") + + text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.") + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped + + # create LoRA for U-Net + self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] + self.unet_loras, skipped_un = create_modules(True, None, unet, LoRANetwork.SD3_TARGET_REPLACE_MODULE) + + # emb_dims [context_embedder, t_embedder, x_embedder, y_embedder, final_mod, final_linear] + if self.emb_dims: + for filter, in_dim in zip( + [ + "context_embedder", + "t_embedder", + "x_embedder", + "y_embedder", + "final_layer_adaLN_modulation", + "final_layer_linear", + ], + self.emb_dims, + ): + loras, _ = create_modules(True, None, unet, None, filter=filter, default_dim=in_dim) + self.unet_loras.extend(loras) + + logger.info(f"create LoRA for SD3 MMDiT: {len(self.unet_loras)} modules.") + if verbose: + for lora in self.unet_loras: + logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}") + + skipped = skipped_te + skipped_un + if verbose and len(skipped) > 0: + logger.warning( + f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" + ) + for name in skipped: + logger.info(f"\t{name}") + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def set_enabled(self, is_enabled): + for lora in self.text_encoder_loras + self.unet_loras: + lora.enabled = is_enabled + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def load_state_dict(self, state_dict, strict=True): + # override to convert original weight to split qkv + if not self.split_qkv: + return super().load_state_dict(state_dict, strict) + + # split qkv + for key in list(state_dict.keys()): + if not ("joint_blocks" in key and "qkv" in key): + continue + + weight = state_dict[key] + lora_name = key.split(".")[0] + if "lora_down" in key and "weight" in key: + # dense weight (rank*3, in_dim) + split_weight = torch.chunk(weight, 3, dim=0) + for i, split_w in enumerate(split_weight): + state_dict[f"{lora_name}.lora_down.{i}.weight"] = split_w + + del state_dict[key] + # print(f"split {key}: {weight.shape} to {[w.shape for w in split_weight]}") + elif "lora_up" in key and "weight" in key: + # sparse weight (out_dim=sum(split_dims), rank*3) + rank = weight.size(1) // 3 + i = 0 + split_dim = weight.shape[0] // 3 + for j in range(3): + state_dict[f"{lora_name}.lora_up.{j}.weight"] = weight[i : i + split_dim, j * rank : (j + 1) * rank] + i += split_dim + del state_dict[key] + + # alpha is unchanged + + return super().load_state_dict(state_dict, strict) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + if not self.split_qkv: + return super().state_dict(destination, prefix, keep_vars) + + # merge qkv + state_dict = super().state_dict(destination, prefix, keep_vars) + new_state_dict = {} + for key in list(state_dict.keys()): + if not ("joint_blocks" in key and "qkv" in key): + new_state_dict[key] = state_dict[key] + continue + + if key not in state_dict: + continue # already merged + + lora_name = key.split(".")[0] + + # (rank, in_dim) * 3 + down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(3)] + # (split dim, rank) * 3 + up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(3)] + + alpha = state_dict.pop(f"{lora_name}.alpha") + + # merge down weight + down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim) + + # merge up weight (sum of split_dim, rank*3) + qkv_dim, rank = up_weights[0].size() + split_dim = qkv_dim // 3 + up_weight = torch.zeros((qkv_dim, down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype) + i = 0 + for j in range(3): + up_weight[i : i + split_dim, j * rank : (j + 1) * rank] = up_weights[j] + i += split_dim + + new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight + new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight + new_state_dict[f"{lora_name}.alpha"] = alpha + + # print( + # f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}" + # ) + print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha") + + return new_state_dict + + def apply_to(self, text_encoders, mmdit, apply_text_encoder=True, apply_unet=True): + if apply_text_encoder: + logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + # マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoders, mmdit, weights_sd, dtype=None, device=None): + apply_text_encoder = apply_unet = False + for key in weights_sd.keys(): + if ( + key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP_L) + or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP_G) + or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5) + ): + apply_text_encoder = True + elif key.startswith(LoRANetwork.LORA_PREFIX_MMDIT): + apply_unet = True + + if apply_text_encoder: + logger.info("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(lora.lora_name): + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] + lora.merge_to(sd_for_lora, dtype, device) + + logger.info(f"weights are merged") + + def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio = loraplus_lr_ratio + self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio + self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + + logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") + logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") + + def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr): + # make sure text_encoder_lr as list of three elements + # if float, use the same value for all three + if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0): + text_encoder_lr = [default_lr, default_lr, default_lr] + elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int): + text_encoder_lr = [float(text_encoder_lr), float(text_encoder_lr), float(text_encoder_lr)] + elif len(text_encoder_lr) == 1: + text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0], text_encoder_lr[0]] + elif len(text_encoder_lr) == 2: + text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[1], text_encoder_lr[1]] + + self.requires_grad_(True) + + all_params = [] + lr_descriptions = [] + + def assemble_params(loras, lr, loraplus_ratio): + param_groups = {"lora": {}, "plus": {}} + for lora in loras: + for name, param in lora.named_parameters(): + if loraplus_ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + params = [] + descriptions = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + + if lr is not None: + if key == "plus": + param_data["lr"] = lr * loraplus_ratio + else: + param_data["lr"] = lr + + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + logger.info("NO LR skipping!") + continue + + params.append(param_data) + descriptions.append("plus" if key == "plus" else "") + + return params, descriptions + + if self.text_encoder_loras: + loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio + + # split text encoder loras for te1 and te3 + te1_loras = [ + lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP_L) + ] + te2_loras = [ + lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP_G) + ] + te3_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_T5)] + if len(te1_loras) > 0: + logger.info(f"Text Encoder 1 (CLIP-L): {len(te1_loras)} modules, LR {text_encoder_lr[0]}") + params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions]) + if len(te2_loras) > 0: + logger.info(f"Text Encoder 2 (CLIP-G): {len(te2_loras)} modules, LR {text_encoder_lr[1]}") + params, descriptions = assemble_params(te2_loras, text_encoder_lr[1], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions]) + if len(te3_loras) > 0: + logger.info(f"Text Encoder 3 (T5XXL): {len(te3_loras)} modules, LR {text_encoder_lr[2]}") + params, descriptions = assemble_params(te3_loras, text_encoder_lr[2], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 3 " + (" " + d if d else "") for d in descriptions]) + + if self.unet_loras: + params, descriptions = assemble_params( + self.unet_loras, + unet_lr if unet_lr is not None else default_lr, + self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions]) + + return all_params, lr_descriptions + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from library import train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def backup_weights(self): + # 重みのバックアップを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + sd = org_module.state_dict() + + org_weight = sd["weight"] + lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype) + sd["weight"] = org_weight + lora_weight + assert sd["weight"].shape == org_weight.shape + org_module.load_state_dict(sd) + + org_module._lora_restored = False + lora.enabled = False + + def apply_max_norm_regularization(self, max_norm_value, device): + downkeys = [] + upkeys = [] + alphakeys = [] + norms = [] + keys_scaled = 0 + + state_dict = self.state_dict() + for key in state_dict.keys(): + if "lora_down" in key and "weight" in key: + downkeys.append(key) + upkeys.append(key.replace("lora_down", "lora_up")) + alphakeys.append(key.replace("lora_down.weight", "alpha")) + + for i in range(len(downkeys)): + down = state_dict[downkeys[i]].to(device) + up = state_dict[upkeys[i]].to(device) + alpha = state_dict[alphakeys[i]].to(device) + dim = down.shape[0] + scale = alpha / dim + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): + updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + else: + updown = up @ down + + updown *= scale + + norm = updown.norm().clamp(min=max_norm_value / 2) + desired = torch.clamp(norm, max=max_norm_value) + ratio = desired.cpu() / norm.cpu() + sqrt_ratio = ratio**0.5 + if ratio != 1: + keys_scaled += 1 + state_dict[upkeys[i]] *= sqrt_ratio + state_dict[downkeys[i]] *= sqrt_ratio + scalednorm = updown.norm() * ratio + norms.append(scalednorm.item()) + + return keys_scaled, sum(norms) / len(norms), max(norms) diff --git a/sd3_train.py b/sd3_train.py index 5e2efa6f8..d12f7f56b 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -220,12 +220,7 @@ def train(args): sd3_state_dict = None # load tokenizer and prepare tokenize strategy - if args.t5xxl_max_token_length is None: - t5xxl_max_token_length = 256 # default value for T5XXL - else: - t5xxl_max_token_length = args.t5xxl_max_token_length - - sd3_tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(t5xxl_max_token_length) + sd3_tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length) strategy_base.TokenizeStrategy.set_strategy(sd3_tokenize_strategy) # load clip_l, clip_g, t5xxl for caching text encoder outputs @@ -876,6 +871,9 @@ def optimizer_hook(parameter: torch.Tensor): lg_out = None t5_out = None lg_pooled = None + l_attn_mask = None + g_attn_mask = None + t5_attn_mask = None if lg_out is None: # not cached or training, so get from text encoders @@ -885,7 +883,7 @@ def optimizer_hook(parameter: torch.Tensor): # text models in sd3_models require "cpu" for input_ids input_ids_clip_l = input_ids_clip_l.to("cpu") input_ids_clip_g = input_ids_clip_g.to("cpu") - lg_out, _, lg_pooled = text_encoding_strategy.encode_tokens( + lg_out, _, lg_pooled, l_attn_mask, g_attn_mask, _ = text_encoding_strategy.encode_tokens( sd3_tokenize_strategy, [clip_l, clip_g, None], [input_ids_clip_l, input_ids_clip_g, None, l_attn_mask, g_attn_mask, None], @@ -895,7 +893,7 @@ def optimizer_hook(parameter: torch.Tensor): _, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"] with torch.no_grad(): input_ids_t5xxl = input_ids_t5xxl.to("cpu") if t5_out is None else None - _, t5_out, _ = text_encoding_strategy.encode_tokens( + _, t5_out, _, _, _, t5_attn_mask = text_encoding_strategy.encode_tokens( sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask] ) @@ -1104,22 +1102,6 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--use_t5xxl_cache_only", action="store_true", help="cache T5-XXL outputs only / T5-XXLの出力のみキャッシュする" ) - parser.add_argument( - "--t5xxl_max_token_length", - type=int, - default=None, - help="maximum token length for T5-XXL. 256 if omitted / T5-XXLの最大トークン数。省略時は256", - ) - parser.add_argument( - "--apply_lg_attn_mask", - action="store_true", - help="apply attention mask (zero embs) to CLIP-L and G / CLIP-LとGにアテンションマスク(ゼロ埋め)を適用する", - ) - parser.add_argument( - "--apply_t5_attn_mask", - action="store_true", - help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する", - ) parser.add_argument( "--learning_rate_te1", diff --git a/sd3_train_network.py b/sd3_train_network.py new file mode 100644 index 000000000..0f4ca93ef --- /dev/null +++ b/sd3_train_network.py @@ -0,0 +1,427 @@ +import argparse +import copy +import math +import random +from typing import Any, Optional + +import torch +from accelerate import Accelerator +from library import strategy_sd3, utils +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, sd3_utils, strategy_base, strategy_sd3, train_util +import train_network +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class Sd3NetworkTrainer(train_network.NetworkTrainer): + def __init__(self): + super().__init__() + self.sample_prompts_te_outputs = None + self.is_schnell: Optional[bool] = None + + def assert_extra_args(self, args, train_dataset_group): + super().assert_extra_args(args, train_dataset_group) + # sdxl_train_util.verify_sdxl_training_args(args) + + if args.fp8_base_unet: + args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for SD3 + + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # prepare CLIP-L/CLIP-G/T5XXL training flags + self.train_clip = not args.network_train_unet_only + self.train_t5xxl = False # default is False even if args.network_train_unet_only is False + + if args.max_token_length is not None: + logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") + + train_dataset_group.verify_bucket_reso_steps(32) # TODO check this + + def load_target_model(self, args, weight_dtype, accelerator): + # currently offload to cpu for some models + + # if the file is fp8 and we are using fp8_base, we can load it as is (fp8) + loading_dtype = None if args.fp8_base else weight_dtype + + # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future + state_dict = utils.load_safetensors( + args.pretrained_model_name_or_path, "cpu", disable_mmap=args.disable_mmap_load_safetensors, dtype=loading_dtype + ) + mmdit = sd3_utils.load_mmdit(state_dict, loading_dtype, "cpu") + self.model_type = mmdit.model_type + + if args.fp8_base: + # check dtype of model + if mmdit.dtype == torch.float8_e4m3fnuz or mmdit.dtype == torch.float8_e5m2 or mmdit.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {mmdit.dtype}") + elif mmdit.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 SD3 model") + + clip_l = sd3_utils.load_clip_l( + args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict + ) + clip_l.eval() + clip_g = sd3_utils.load_clip_g( + args.clip_g, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict + ) + clip_g.eval() + + # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8) + if args.fp8_base and not args.fp8_base_unet: + loading_dtype = None # as is + else: + loading_dtype = weight_dtype + + # loading t5xxl to cpu takes a long time, so we should load to gpu in future + t5xxl = sd3_utils.load_t5xxl( + args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict + ) + t5xxl.eval() + if args.fp8_base and not args.fp8_base_unet: + # check dtype of model + if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") + elif t5xxl.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 T5XXL model") + + vae = sd3_utils.load_vae( + args.vae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict + ) + + return mmdit.model_type, [clip_l, clip_g, t5xxl], vae, mmdit + + def get_tokenize_strategy(self, args): + logger.info(f"t5xxl_max_token_length: {args.t5xxl_max_token_length}") + return strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_sd3.Sd3TokenizeStrategy): + return [tokenize_strategy.clip_l, tokenize_strategy.clip_g, tokenize_strategy.t5xxl] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_sd3.Sd3TextEncodingStrategy(args.apply_lg_attn_mask, args.apply_t5_attn_mask) + + def post_process_network(self, args, accelerator, network, text_encoders, unet): + # check t5xxl is trained or not + self.train_t5xxl = network.train_t5xxl + + if self.train_t5xxl and args.cache_text_encoder_outputs: + raise ValueError( + "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません" + ) + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + if args.cache_text_encoder_outputs: + if self.train_clip and not self.train_t5xxl: + return text_encoders[0:2] # only CLIP-L/CLIP-G is needed for encoding because T5XXL is cached + else: + return None # no text encoders are needed for encoding because both are cached + else: + return text_encoders # CLIP-L, CLIP-G and T5XXL are needed for encoding + + def get_text_encoders_train_flags(self, args, text_encoders): + return [self.train_clip, self.train_clip, self.train_t5xxl] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + # if the text encoders is trained, we need tokenization, so is_partial is True + return strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + is_partial=self.train_clip or self.train_t5xxl, + apply_lg_attn_mask=args.apply_lg_attn_mask, + apply_t5_attn_mask=args.apply_t5_attn_mask, + ) + else: + return None + + def cache_text_encoder_outputs_if_needed( + self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype + ): + if args.cache_text_encoder_outputs: + if not args.lowram: + # メモリ消費を減らす + logger.info("move vae and unet to cpu to save memory") + org_vae_device = vae.device + org_unet_device = unet.device + vae.to("cpu") + unet.to("cpu") + clean_memory_on_device(accelerator.device) + + # When TE is not be trained, it will not be prepared so we need to use explicit autocast + logger.info("move text encoders to gpu") + text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 + text_encoders[1].to(accelerator.device, dtype=weight_dtype) # always not fp8 + text_encoders[2].to(accelerator.device) # may be fp8 + + if text_encoders[2].dtype == torch.float8_e4m3fn: + # if we load fp8 weights, the model is already fp8, so we use it as is + self.prepare_text_encoder_fp8(2, text_encoders[2], text_encoders[2].dtype, weight_dtype) + else: + # otherwise, we need to convert it to target dtype + text_encoders[2].to(weight_dtype) + + with accelerator.autocast(): + dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) + + # cache sample prompts + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + tokenize_strategy: strategy_sd3.Sd3TokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + text_encoding_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + tokenize_strategy, + text_encoders, + tokens_and_masks, + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, + ) + self.sample_prompts_te_outputs = sample_prompts_te_outputs + + accelerator.wait_for_everyone() + + # move back to cpu + if not self.is_train_text_encoder(args): + logger.info("move CLIP-L back to cpu") + text_encoders[0].to("cpu") + logger.info("move CLIP-G back to cpu") + text_encoders[1].to("cpu") + logger.info("move t5XXL back to cpu") + text_encoders[2].to("cpu") + clean_memory_on_device(accelerator.device) + + if not args.lowram: + logger.info("move vae and unet back to original device") + vae.to(org_vae_device) + unet.to(org_unet_device) + else: + # Text Encoderから毎回出力を取得するので、GPUに乗せておく + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device, dtype=weight_dtype) + text_encoders[2].to(accelerator.device) + + # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + + # # get size embeddings + # orig_size = batch["original_sizes_hw"] + # crop_size = batch["crop_top_lefts"] + # target_size = batch["target_sizes_hw"] + # embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # # concat embeddings + # encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds + # vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + # text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) + # return noise_pred + + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, mmdit): + text_encoders = text_encoder # for compatibility + text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) + + sd3_train_utils.sample_images( + accelerator, args, epoch, global_step, mmdit, vae, text_encoders, self.sample_prompts_te_outputs + ) + + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + # shift 3.0 is the default value + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) + self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) + return noise_scheduler + + def encode_images_to_latents(self, args, accelerator, vae, images): + return vae.encode(images) + + def shift_scale_latents(self, args, latents): + return latents + + def get_noise_pred_and_target( + self, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet: flux_models.Flux, + network, + weight_dtype, + train_unet, + ): + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = sd3_train_utils.get_noisy_model_input_and_timesteps( + args, self.noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype + ) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + noisy_model_input.requires_grad_(True) + for t in text_encoder_conds: + if t.dtype.is_floating_point: + t.requires_grad_(True) + + # Predict the noise residual + lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_conds + text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + context, lg_pooled = text_encoding_strategy.concat_encodings(lg_out, t5_out, lg_pooled) + if not args.apply_lg_attn_mask: + l_attn_mask = None + g_attn_mask = None + if not args.apply_t5_attn_mask: + t5_attn_mask = None + + # call model + with accelerator.autocast(): + # TODO support attention mask + model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled) + + # Follow: Section 5 of https://arxiv.org/abs/2206.00364. + # Preconditioning of the model outputs. + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = sd3_train_utils.compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = latents + + # differential output preservation + if "custom_attributes" in batch: + diff_output_pr_indices = [] + for i, custom_attributes in enumerate(batch["custom_attributes"]): + if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: + diff_output_pr_indices.append(i) + + if len(diff_output_pr_indices) > 0: + network.set_multiplier(0.0) + with torch.no_grad(), accelerator.autocast(): + model_pred_prior = unet( + noisy_model_input[diff_output_pr_indices], + timesteps[diff_output_pr_indices], + context=context[diff_output_pr_indices], + y=lg_pooled[diff_output_pr_indices], + ) + network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step + + model_pred_prior = model_pred_prior * (-sigmas[diff_output_pr_indices]) + noisy_model_input[diff_output_pr_indices] + + # weighting for differential output preservation is not needed because it is already applied + + target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) + + return model_pred, target, timesteps, None, weighting + + def post_process_loss(self, loss, args, timesteps, noise_scheduler): + return loss + + def get_sai_model_spec(self, args): + return train_util.get_sai_model_spec(None, args, False, True, False, sd3=self.model_type) + + def update_metadata(self, metadata, args): + metadata["ss_apply_lg_attn_mask"] = args.apply_lg_attn_mask + metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask + metadata["ss_weighting_scheme"] = args.weighting_scheme + metadata["ss_logit_mean"] = args.logit_mean + metadata["ss_logit_std"] = args.logit_std + metadata["ss_mode_scale"] = args.mode_scale + + def is_text_encoder_not_needed_for_training(self, args): + return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) + + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + if index == 0 or index == 1: # CLIP-L/CLIP-G + return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder) + else: # T5XXL + text_encoder.encoder.embed_tokens.requires_grad_(True) + + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + if index == 0 or index == 1: # CLIP-L/CLIP-G + clip_type = "CLIP-L" if index == 0 else "CLIP-G" + logger.info(f"prepare CLIP-{clip_type} for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}") + text_encoder.to(te_weight_dtype) # fp8 + text_encoder.text_model.embeddings.to(dtype=weight_dtype) + else: # T5XXL + + def prepare_fp8(text_encoder, target_dtype): + def forward_hook(module): + def forward(hidden_states): + hidden_gelu = module.act(module.wi_0(hidden_states)) + hidden_linear = module.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = module.dropout(hidden_states) + + hidden_states = module.wo(hidden_states) + return hidden_states + + return forward + + for module in text_encoder.modules(): + if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["T5DenseGatedActDense"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype: + logger.info(f"T5XXL already prepared for fp8") + else: + logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks") + text_encoder.to(te_weight_dtype) # fp8 + prepare_fp8(text_encoder, weight_dtype) + + +def setup_parser() -> argparse.ArgumentParser: + parser = train_network.setup_parser() + sd3_train_utils.add_sd3_training_arguments(parser) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + trainer = Sd3NetworkTrainer() + trainer.train(args) diff --git a/train_network.py b/train_network.py index 9943b60bd..aab1d84be 100644 --- a/train_network.py +++ b/train_network.py @@ -129,6 +129,7 @@ def get_text_encoder_outputs_caching_strategy(self, args): def get_models_for_text_encoding(self, args, accelerator, text_encoders): """ Returns a list of models that will be used for text encoding. SDXL uses wrapped and unwrapped models. + FLUX.1 and SD3 may cache some outputs of the text encoder, so return the models that will be used for encoding (not cached). """ return text_encoders @@ -591,6 +592,7 @@ def train(self, args): # unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM # unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory + logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}") unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above unet.requires_grad_(False) From 0031d916f0fa035d5d48a25fcabadc149bfbb639 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Fri, 25 Oct 2024 23:20:38 +0900 Subject: [PATCH 194/348] add latent scaling/shifting --- sd3_train_network.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sd3_train_network.py b/sd3_train_network.py index 0f4ca93ef..ecacf16cc 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -6,7 +6,7 @@ import torch from accelerate import Accelerator -from library import strategy_sd3, utils +from library import sd3_models, strategy_sd3, utils from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -25,7 +25,6 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): def __init__(self): super().__init__() self.sample_prompts_te_outputs = None - self.is_schnell: Optional[bool] = None def assert_extra_args(self, args, train_dataset_group): super().assert_extra_args(args, train_dataset_group) @@ -268,7 +267,7 @@ def encode_images_to_latents(self, args, accelerator, vae, images): return vae.encode(images) def shift_scale_latents(self, args, latents): - return latents + return sd3_models.SDVAE.process_in(latents) def get_noise_pred_and_target( self, From 56bf7611644402996072bd8f909cf828ec7b27cc Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sat, 26 Oct 2024 17:29:24 +0900 Subject: [PATCH 195/348] fix errors in SD3 LoRA training with Text Encoders close #1724 --- library/strategy_sd3.py | 26 +++++++++++++------------- sd3_train_network.py | 2 +- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py index dd08cf004..a27e99e63 100644 --- a/library/strategy_sd3.py +++ b/library/strategy_sd3.py @@ -68,9 +68,9 @@ def encode_tokens( returned embeddings are not masked """ clip_l, clip_g, t5xxl = models - clip_l: CLIPTextModel - clip_g: CLIPTextModelWithProjection - t5xxl: T5EncoderModel + clip_l: Optional[CLIPTextModel] + clip_g: Optional[CLIPTextModelWithProjection] + t5xxl: Optional[T5EncoderModel] if apply_lg_attn_mask is None: apply_lg_attn_mask = self.apply_lg_attn_mask @@ -84,25 +84,23 @@ def encode_tokens( if not apply_lg_attn_mask: l_attn_mask = None g_attn_mask = None - else: - l_attn_mask = l_attn_mask.to(clip_l.device) - g_attn_mask = g_attn_mask.to(clip_g.device) if not apply_t5_attn_mask: t5_attn_mask = None - else: - t5_attn_mask = t5_attn_mask.to(t5xxl.device) else: l_attn_mask = None g_attn_mask = None t5_attn_mask = None - if l_tokens is None: + if l_tokens is None or clip_l is None: assert g_tokens is None, "g_tokens must be None if l_tokens is None" lg_out = None lg_pooled = None else: with torch.no_grad(): assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" + l_attn_mask = l_attn_mask.to(clip_l.device) if l_attn_mask is not None else None + g_attn_mask = g_attn_mask.to(clip_g.device) if g_attn_mask is not None else None + prompt_embeds = clip_l(l_tokens.to(clip_l.device), l_attn_mask, output_hidden_states=True) l_pooled = prompt_embeds[0] l_out = prompt_embeds.hidden_states[-2] @@ -114,13 +112,15 @@ def encode_tokens( lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None lg_out = torch.cat([l_out, g_out], dim=-1) - if t5xxl is not None and t5_tokens is not None: + if t5xxl is None or t5_tokens is None: + t5_out = None + else: + t5_attn_mask = t5_attn_mask.to(t5xxl.device) if t5_attn_mask is not None else None with torch.no_grad(): t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), t5_attn_mask, return_dict=False, output_hidden_states=True) - else: - t5_out = None - return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] # masks are used for attention masking in transformer + # masks are used for attention masking in transformer + return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] def concat_encodings( self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor diff --git a/sd3_train_network.py b/sd3_train_network.py index ecacf16cc..129afed54 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -134,7 +134,7 @@ def post_process_network(self, args, accelerator, network, text_encoders, unet): def get_models_for_text_encoding(self, args, accelerator, text_encoders): if args.cache_text_encoder_outputs: if self.train_clip and not self.train_t5xxl: - return text_encoders[0:2] # only CLIP-L/CLIP-G is needed for encoding because T5XXL is cached + return text_encoders[0:2] + [None] # only CLIP-L/CLIP-G is needed for encoding because T5XXL is cached else: return None # no text encoders are needed for encoding because both are cached else: From 014064fd8186420abf5dfc7c99ad0b39fee33f8a Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sat, 26 Oct 2024 18:59:45 +0900 Subject: [PATCH 196/348] fix sample image generation without seed failed close #1726 --- library/sd3_train_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index e3c649f73..b04b86fb3 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -316,6 +316,8 @@ def do_sample( # noise = get_noise(seed, latent).to(device) if seed is not None: generator = torch.manual_seed(seed) + else: + generator = None noise = ( torch.randn(latent.size(), dtype=torch.float32, layout=latent.layout, generator=generator, device="cpu") .to(latent.dtype) From db2b4d41b9637cffd40a694c8e25847446a57aad Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 27 Oct 2024 16:42:58 +0900 Subject: [PATCH 197/348] Add dropout rate arguments for CLIP-L, CLIP-G, and T5, fix Text Encoders LoRA not trained --- library/sd3_train_utils.py | 18 ++++++++ library/strategy_sd3.py | 93 ++++++++++++++++++++++++++++++++++---- sd3_train.py | 15 ++++-- sd3_train_network.py | 16 ++++++- train_network.py | 13 ++++-- 5 files changed, 138 insertions(+), 17 deletions(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index b04b86fb3..a0202ad40 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -214,6 +214,24 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): action="store_true", help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する", ) + parser.add_argument( + "--clip_l_dropout_rate", + type=float, + default=0.0, + help="Dropout rate for CLIP-L encoder, default is 0.0 / CLIP-Lエンコーダのドロップアウト率、デフォルトは0.0", + ) + parser.add_argument( + "--clip_g_dropout_rate", + type=float, + default=0.0, + help="Dropout rate for CLIP-G encoder, default is 0.0 / CLIP-Gエンコーダのドロップアウト率、デフォルトは0.0", + ) + parser.add_argument( + "--t5_dropout_rate", + type=float, + default=0.0, + help="Dropout rate for T5 encoder, default is 0.0 / T5エンコーダのドロップアウト率、デフォルトは0.0", + ) # copy from Diffusers parser.add_argument( diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py index a27e99e63..d87ad7d15 100644 --- a/library/strategy_sd3.py +++ b/library/strategy_sd3.py @@ -1,5 +1,6 @@ import os import glob +import random from typing import Any, List, Optional, Tuple, Union import torch import numpy as np @@ -48,13 +49,23 @@ def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: class Sd3TextEncodingStrategy(TextEncodingStrategy): - def __init__(self, apply_lg_attn_mask: Optional[bool] = None, apply_t5_attn_mask: Optional[bool] = None) -> None: + def __init__( + self, + apply_lg_attn_mask: Optional[bool] = None, + apply_t5_attn_mask: Optional[bool] = None, + l_dropout_rate: float = 0.0, + g_dropout_rate: float = 0.0, + t5_dropout_rate: float = 0.0, + ) -> None: """ Args: apply_t5_attn_mask: Default value for apply_t5_attn_mask. """ self.apply_lg_attn_mask = apply_lg_attn_mask self.apply_t5_attn_mask = apply_t5_attn_mask + self.l_dropout_rate = l_dropout_rate + self.g_dropout_rate = g_dropout_rate + self.t5_dropout_rate = t5_dropout_rate def encode_tokens( self, @@ -63,6 +74,7 @@ def encode_tokens( tokens: List[torch.Tensor], apply_lg_attn_mask: Optional[bool] = False, apply_t5_attn_mask: Optional[bool] = False, + enable_dropout: bool = True, ) -> List[torch.Tensor]: """ returned embeddings are not masked @@ -91,37 +103,92 @@ def encode_tokens( g_attn_mask = None t5_attn_mask = None + # dropout: if enable_dropout is False, dropout is not applied. dropout means zeroing out embeddings + if l_tokens is None or clip_l is None: assert g_tokens is None, "g_tokens must be None if l_tokens is None" lg_out = None lg_pooled = None else: - with torch.no_grad(): - assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" + assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" + + drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate) + if drop_l: + l_pooled = torch.zeros((l_tokens.shape[0], 768), device=l_tokens.device, dtype=l_tokens.dtype) + l_out = torch.zeros((l_tokens.shape[0], l_tokens.shape[1], 768), device=l_tokens.device, dtype=l_tokens.dtype) + if l_attn_mask is not None: + l_attn_mask = torch.zeros_like(l_attn_mask) + else: l_attn_mask = l_attn_mask.to(clip_l.device) if l_attn_mask is not None else None - g_attn_mask = g_attn_mask.to(clip_g.device) if g_attn_mask is not None else None - prompt_embeds = clip_l(l_tokens.to(clip_l.device), l_attn_mask, output_hidden_states=True) l_pooled = prompt_embeds[0] l_out = prompt_embeds.hidden_states[-2] + drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate) + if drop_g: + g_pooled = torch.zeros((g_tokens.shape[0], 1280), device=g_tokens.device, dtype=g_tokens.dtype) + g_out = torch.zeros((g_tokens.shape[0], g_tokens.shape[1], 1280), device=g_tokens.device, dtype=g_tokens.dtype) + if g_attn_mask is not None: + g_attn_mask = torch.zeros_like(g_attn_mask) + else: + g_attn_mask = g_attn_mask.to(clip_g.device) if g_attn_mask is not None else None prompt_embeds = clip_g(g_tokens.to(clip_g.device), g_attn_mask, output_hidden_states=True) g_pooled = prompt_embeds[0] g_out = prompt_embeds.hidden_states[-2] - lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None - lg_out = torch.cat([l_out, g_out], dim=-1) + lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None + lg_out = torch.cat([l_out, g_out], dim=-1) if t5xxl is None or t5_tokens is None: t5_out = None else: - t5_attn_mask = t5_attn_mask.to(t5xxl.device) if t5_attn_mask is not None else None - with torch.no_grad(): + drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate) + if drop_t5: + t5_out = torch.zeros((t5_tokens.shape[0], t5_tokens.shape[1], 4096), device=t5_tokens.device, dtype=t5_tokens.dtype) + if t5_attn_mask is not None: + t5_attn_mask = torch.zeros_like(t5_attn_mask) + else: + t5_attn_mask = t5_attn_mask.to(t5xxl.device) if t5_attn_mask is not None else None t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), t5_attn_mask, return_dict=False, output_hidden_states=True) # masks are used for attention masking in transformer return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] + def drop_cached_text_encoder_outputs( + self, + lg_out: torch.Tensor, + t5_out: torch.Tensor, + lg_pooled: torch.Tensor, + l_attn_mask: torch.Tensor, + g_attn_mask: torch.Tensor, + t5_attn_mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # dropout: if enable_dropout is True, dropout is not applied. dropout means zeroing out embeddings + if lg_out is not None: + for i in range(lg_out.shape[0]): + drop_l = self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate + if drop_l: + lg_out[i, :, :768] = torch.zeros_like(lg_out[i, :, :768]) + lg_pooled[i, :768] = torch.zeros_like(lg_pooled[i, :768]) + if l_attn_mask is not None: + l_attn_mask[i] = torch.zeros_like(l_attn_mask[i]) + drop_g = self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate + if drop_g: + lg_out[i, :, 768:] = torch.zeros_like(lg_out[i, :, 768:]) + lg_pooled[i, 768:] = torch.zeros_like(lg_pooled[i, 768:]) + if g_attn_mask is not None: + g_attn_mask[i] = torch.zeros_like(g_attn_mask[i]) + + if t5_out is not None: + for i in range(t5_out.shape[0]): + drop_t5 = self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate + if drop_t5: + t5_out[i] = torch.zeros_like(t5_out[i]) + if t5_attn_mask is not None: + t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i]) + + return lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask + def concat_encodings( self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -207,8 +274,14 @@ def cache_batch_outputs( tokens_and_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): + # always disable dropout during caching lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = sd3_text_encoding_strategy.encode_tokens( - tokenize_strategy, models, tokens_and_masks, self.apply_lg_attn_mask, self.apply_t5_attn_mask + tokenize_strategy, + models, + tokens_and_masks, + apply_lg_attn_mask=self.apply_lg_attn_mask, + apply_t5_attn_mask=self.apply_t5_attn_mask, + enable_dropout=False, ) if lg_out.dtype == torch.bfloat16: diff --git a/sd3_train.py b/sd3_train.py index d12f7f56b..cdac945e6 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -69,6 +69,11 @@ def train(args): # assert ( # not args.train_text_encoder or not args.cache_text_encoder_outputs # ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True assert not args.train_text_encoder or (args.use_t5xxl_cache_only or not args.cache_text_encoder_outputs), ( "when training text encoder, text encoder outputs must not be cached (except for T5XXL)" @@ -232,7 +237,9 @@ def train(args): assert clip_l is not None and clip_g is not None and t5xxl is not None, "clip_l, clip_g, t5xxl must be specified" # prepare text encoding strategy - text_encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy(args.apply_lg_attn_mask, args.apply_t5_attn_mask) + text_encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy( + args.apply_lg_attn_mask, args.apply_t5_attn_mask, args.clip_l_dropout_rate, args.clip_g_dropout_rate, args.t5_dropout_rate + ) strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) # 学習を準備する:モデルを適切な状態にする @@ -311,6 +318,7 @@ def train(args): tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask, + enable_dropout=False, ) accelerator.wait_for_everyone() @@ -863,6 +871,7 @@ def optimizer_hook(parameter: torch.Tensor): text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: + text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_outputs_list if args.use_t5xxl_cache_only: lg_out = None @@ -878,7 +887,7 @@ def optimizer_hook(parameter: torch.Tensor): if lg_out is None: # not cached or training, so get from text encoders input_ids_clip_l, input_ids_clip_g, _, l_attn_mask, g_attn_mask, _ = batch["input_ids_list"] - with torch.set_grad_enabled(args.train_text_encoder): + with torch.set_grad_enabled(train_clip): # TODO support weighted captions # text models in sd3_models require "cpu" for input_ids input_ids_clip_l = input_ids_clip_l.to("cpu") @@ -891,7 +900,7 @@ def optimizer_hook(parameter: torch.Tensor): if t5_out is None: _, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"] - with torch.no_grad(): + with torch.set_grad_enabled(train_t5xxl): input_ids_t5xxl = input_ids_t5xxl.to("cpu") if t5_out is None else None _, t5_out, _, _, _, t5_attn_mask = text_encoding_strategy.encode_tokens( sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask] diff --git a/sd3_train_network.py b/sd3_train_network.py index 129afed54..7b5471274 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -120,7 +120,13 @@ def get_latents_caching_strategy(self, args): return latents_caching_strategy def get_text_encoding_strategy(self, args): - return strategy_sd3.Sd3TextEncodingStrategy(args.apply_lg_attn_mask, args.apply_t5_attn_mask) + return strategy_sd3.Sd3TextEncodingStrategy( + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, + args.clip_l_dropout_rate, + args.clip_g_dropout_rate, + args.t5xxl_dropout_rate, + ) def post_process_network(self, args, accelerator, network, text_encoders, unet): # check t5xxl is trained or not @@ -408,6 +414,14 @@ def forward(hidden_states): text_encoder.to(te_weight_dtype) # fp8 prepare_fp8(text_encoder, weight_dtype) + def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + # drop cached text encoder outputs + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(text_encoder_outputs_list) + batch["text_encoder_outputs_list"] = text_encoder_outputs_list + def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() diff --git a/train_network.py b/train_network.py index aab1d84be..9d78a4ef2 100644 --- a/train_network.py +++ b/train_network.py @@ -272,6 +272,9 @@ def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): text_encoder.text_model.embeddings.to(dtype=weight_dtype) + def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + pass + # endregion def train(self, args): @@ -1030,9 +1033,9 @@ def load_model_hook(models, input_dir): # callback for step start if hasattr(accelerator.unwrap_model(network), "on_step_start"): - on_step_start = accelerator.unwrap_model(network).on_step_start + on_step_start_for_network = accelerator.unwrap_model(network).on_step_start else: - on_step_start = lambda *args, **kwargs: None + on_step_start_for_network = lambda *args, **kwargs: None # function for saving/removing def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False): @@ -1113,7 +1116,10 @@ def remove_model(old_ckpt_name): continue with accelerator.accumulate(training_model): - on_step_start(text_encoder, unet) + on_step_start_for_network(text_encoder, unet) + + # temporary, for batch processing + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) @@ -1146,6 +1152,7 @@ def remove_model(old_ckpt_name): if text_encoder_outputs_list is not None: text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder: + # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached' with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: From a1255d637f545b0d6defebf080ca31f2370bf311 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sun, 27 Oct 2024 17:03:36 +0900 Subject: [PATCH 198/348] Fix SD3 LoRA training to work (WIP) --- library/strategy_sd3.py | 20 ++++++++++---------- sd3_train_network.py | 15 ++++++++------- train_network.py | 20 ++++++++++++++++++++ 3 files changed, 38 insertions(+), 17 deletions(-) diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py index d87ad7d15..e57bb337e 100644 --- a/library/strategy_sd3.py +++ b/library/strategy_sd3.py @@ -111,13 +111,13 @@ def encode_tokens( lg_pooled = None else: assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" - + drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate) if drop_l: - l_pooled = torch.zeros((l_tokens.shape[0], 768), device=l_tokens.device, dtype=l_tokens.dtype) - l_out = torch.zeros((l_tokens.shape[0], l_tokens.shape[1], 768), device=l_tokens.device, dtype=l_tokens.dtype) + l_pooled = torch.zeros((l_tokens.shape[0], 768), device=clip_l.device, dtype=clip_l.dtype) + l_out = torch.zeros((l_tokens.shape[0], l_tokens.shape[1], 768), device=clip_l.device, dtype=clip_l.dtype) if l_attn_mask is not None: - l_attn_mask = torch.zeros_like(l_attn_mask) + l_attn_mask = torch.zeros_like(l_attn_mask, device=clip_l.device) else: l_attn_mask = l_attn_mask.to(clip_l.device) if l_attn_mask is not None else None prompt_embeds = clip_l(l_tokens.to(clip_l.device), l_attn_mask, output_hidden_states=True) @@ -126,10 +126,10 @@ def encode_tokens( drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate) if drop_g: - g_pooled = torch.zeros((g_tokens.shape[0], 1280), device=g_tokens.device, dtype=g_tokens.dtype) - g_out = torch.zeros((g_tokens.shape[0], g_tokens.shape[1], 1280), device=g_tokens.device, dtype=g_tokens.dtype) + g_pooled = torch.zeros((g_tokens.shape[0], 1280), device=clip_g.device, dtype=clip_g.dtype) + g_out = torch.zeros((g_tokens.shape[0], g_tokens.shape[1], 1280), device=clip_g.device, dtype=clip_g.dtype) if g_attn_mask is not None: - g_attn_mask = torch.zeros_like(g_attn_mask) + g_attn_mask = torch.zeros_like(g_attn_mask, device=clip_g.device) else: g_attn_mask = g_attn_mask.to(clip_g.device) if g_attn_mask is not None else None prompt_embeds = clip_g(g_tokens.to(clip_g.device), g_attn_mask, output_hidden_states=True) @@ -144,9 +144,9 @@ def encode_tokens( else: drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate) if drop_t5: - t5_out = torch.zeros((t5_tokens.shape[0], t5_tokens.shape[1], 4096), device=t5_tokens.device, dtype=t5_tokens.dtype) + t5_out = torch.zeros((t5_tokens.shape[0], t5_tokens.shape[1], 4096), device=t5xxl.device, dtype=t5xxl.dtype) if t5_attn_mask is not None: - t5_attn_mask = torch.zeros_like(t5_attn_mask) + t5_attn_mask = torch.zeros_like(t5_attn_mask, device=t5xxl.device) else: t5_attn_mask = t5_attn_mask.to(t5xxl.device) if t5_attn_mask is not None else None t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), t5_attn_mask, return_dict=False, output_hidden_states=True) @@ -187,7 +187,7 @@ def drop_cached_text_encoder_outputs( if t5_attn_mask is not None: t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i]) - return lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask + return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] def concat_encodings( self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor diff --git a/sd3_train_network.py b/sd3_train_network.py index 7b5471274..620a336fd 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -125,7 +125,7 @@ def get_text_encoding_strategy(self, args): args.apply_t5_attn_mask, args.clip_l_dropout_rate, args.clip_g_dropout_rate, - args.t5xxl_dropout_rate, + args.t5_dropout_rate, ) def post_process_network(self, args, accelerator, network, text_encoders, unet): @@ -415,12 +415,13 @@ def forward(hidden_states): prepare_fp8(text_encoder, weight_dtype) def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): - # drop cached text encoder outputs - text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) - if text_encoder_outputs_list is not None: - text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() - text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(text_encoder_outputs_list) - batch["text_encoder_outputs_list"] = text_encoder_outputs_list + # # drop cached text encoder outputs + # text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + # if text_encoder_outputs_list is not None: + # text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + # text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) + # batch["text_encoder_outputs_list"] = text_encoder_outputs_list + pass def setup_parser() -> argparse.ArgumentParser: diff --git a/train_network.py b/train_network.py index 9d78a4ef2..76936b2ed 100644 --- a/train_network.py +++ b/train_network.py @@ -1151,6 +1151,17 @@ def remove_model(old_ckpt_name): text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs + + # if text_encoder_outputs_list is not None: + # lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_outputs_list + # for i in range(len(lg_out)): + # print( + # f"[{i}] cached L: {lg_out[i,:,:768].max()}, {lg_pooled[i][:768].max()}, cached G: {lg_out[i,:,768:].max()}, {lg_pooled[i][768:].max()}, " + # f"cached T5: {t5_out[i].max()}, " + # f"attn mask: {l_attn_mask[i].max() if l_attn_mask is not None else 0}," + # f" {g_attn_mask[i].max() if g_attn_mask is not None else 0}, {t5_attn_mask[i].max() if t5_attn_mask is not None else 0}" + # ) + if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder: # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached' with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): @@ -1182,6 +1193,15 @@ def remove_model(old_ckpt_name): if encoded_text_encoder_conds[i] is not None: text_encoder_conds[i] = encoded_text_encoder_conds[i] + # lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_conds + # for i in range(len(lg_out)): + # print( + # f"[{i}] train L: {lg_out[i,:,:768].max()}, {lg_pooled[i][:768].max()}, train G: {lg_out[i,:,768:].max()}, {lg_pooled[i][768:].max()}, " + # f"train T5: {t5_out[i].max()}, " + # f"attn mask: {l_attn_mask[i].max() if l_attn_mask is not None else 0}," + # f" {g_attn_mask[i].max() if g_attn_mask is not None else 0}, {t5_attn_mask[i].max() if t5_attn_mask is not None else 0}" + # ) + # sample noise, call unet, get target noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target( args, From d4f7849592c78455ddd268423528830ec5e55f47 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sun, 27 Oct 2024 19:35:56 +0900 Subject: [PATCH 199/348] prevent unintended cast for disk cached TE outputs --- library/train_util.py | 1 - 1 file changed, 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index d3c59ef98..d568523ca 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1615,7 +1615,6 @@ def __getitem__(self, index): text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz( image_info.text_encoder_outputs_npz ) - text_encoder_outputs = [torch.FloatTensor(x) for x in text_encoder_outputs] else: tokenization_required = True text_encoder_outputs_list.append(text_encoder_outputs) From 1065dd1b56b4b18e211d3827fe22b459c81dd12c Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sun, 27 Oct 2024 19:36:36 +0900 Subject: [PATCH 200/348] Fix to work dropout_rate for TEs --- flux_train_network.py | 2 +- library/strategy_flux.py | 1 + library/strategy_sd3.py | 142 +++++++++++++++++++++++++++------------ sd3_train_network.py | 15 ++--- train_network.py | 19 ------ 5 files changed, 108 insertions(+), 71 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index cffeb3b19..2b71a8979 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -363,7 +363,7 @@ def get_noise_pred_and_target( if args.gradient_checkpointing: noisy_model_input.requires_grad_(True) for t in text_encoder_conds: - if t.dtype.is_floating_point: + if t is not None and t.dtype.is_floating_point: t.requires_grad_(True) img_ids.requires_grad_(True) guidance_vec.requires_grad_(True) diff --git a/library/strategy_flux.py b/library/strategy_flux.py index 0b0c34af7..f662b62e9 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -190,6 +190,7 @@ def cache_batch_outputs( apply_t5_attn_mask=apply_t5_attn_mask_i, ) else: + # it's fine that attn mask is not None. it's overwritten before calling the model if necessary info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i) diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py index e57bb337e..413169ecc 100644 --- a/library/strategy_sd3.py +++ b/library/strategy_sd3.py @@ -89,19 +89,7 @@ def encode_tokens( if apply_t5_attn_mask is None: apply_t5_attn_mask = self.apply_t5_attn_mask - l_tokens, g_tokens, t5_tokens = tokens[:3] - - if len(tokens) > 3: - l_attn_mask, g_attn_mask, t5_attn_mask = tokens[3:] - if not apply_lg_attn_mask: - l_attn_mask = None - g_attn_mask = None - if not apply_t5_attn_mask: - t5_attn_mask = None - else: - l_attn_mask = None - g_attn_mask = None - t5_attn_mask = None + l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask = tokens # dropout: if enable_dropout is False, dropout is not applied. dropout means zeroing out embeddings @@ -109,47 +97,114 @@ def encode_tokens( assert g_tokens is None, "g_tokens must be None if l_tokens is None" lg_out = None lg_pooled = None + l_attn_mask = None + g_attn_mask = None else: assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" - drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate) - if drop_l: - l_pooled = torch.zeros((l_tokens.shape[0], 768), device=clip_l.device, dtype=clip_l.dtype) - l_out = torch.zeros((l_tokens.shape[0], l_tokens.shape[1], 768), device=clip_l.device, dtype=clip_l.dtype) - if l_attn_mask is not None: - l_attn_mask = torch.zeros_like(l_attn_mask, device=clip_l.device) + # drop some members of the batch: we do not call clip_l and clip_g for dropped members + batch_size, l_seq_len = l_tokens.shape + g_seq_len = g_tokens.shape[1] + + non_drop_l_indices = [] + non_drop_g_indices = [] + for i in range(l_tokens.shape[0]): + drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate) + drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate) + if not drop_l: + non_drop_l_indices.append(i) + if not drop_g: + non_drop_g_indices.append(i) + + # filter out dropped members + if len(non_drop_l_indices) > 0 and len(non_drop_l_indices) < batch_size: + l_tokens = l_tokens[non_drop_l_indices] + l_attn_mask = l_attn_mask[non_drop_l_indices] + if len(non_drop_g_indices) > 0 and len(non_drop_g_indices) < batch_size: + g_tokens = g_tokens[non_drop_g_indices] + g_attn_mask = g_attn_mask[non_drop_g_indices] + + # call clip_l for non-dropped members + if len(non_drop_l_indices) > 0: + nd_l_attn_mask = l_attn_mask.to(clip_l.device) + prompt_embeds = clip_l( + l_tokens.to(clip_l.device), nd_l_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True + ) + nd_l_pooled = prompt_embeds[0] + nd_l_out = prompt_embeds.hidden_states[-2] + if len(non_drop_g_indices) > 0: + nd_g_attn_mask = g_attn_mask.to(clip_g.device) + prompt_embeds = clip_g( + g_tokens.to(clip_g.device), nd_g_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True + ) + nd_g_pooled = prompt_embeds[0] + nd_g_out = prompt_embeds.hidden_states[-2] + + # fill in the dropped members + if len(non_drop_l_indices) == batch_size: + l_pooled = nd_l_pooled + l_out = nd_l_out else: - l_attn_mask = l_attn_mask.to(clip_l.device) if l_attn_mask is not None else None - prompt_embeds = clip_l(l_tokens.to(clip_l.device), l_attn_mask, output_hidden_states=True) - l_pooled = prompt_embeds[0] - l_out = prompt_embeds.hidden_states[-2] - - drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate) - if drop_g: - g_pooled = torch.zeros((g_tokens.shape[0], 1280), device=clip_g.device, dtype=clip_g.dtype) - g_out = torch.zeros((g_tokens.shape[0], g_tokens.shape[1], 1280), device=clip_g.device, dtype=clip_g.dtype) - if g_attn_mask is not None: - g_attn_mask = torch.zeros_like(g_attn_mask, device=clip_g.device) + # model output is always float32 because of the models are wrapped with Accelerator + l_pooled = torch.zeros((batch_size, 768), device=clip_l.device, dtype=torch.float32) + l_out = torch.zeros((batch_size, l_seq_len, 768), device=clip_l.device, dtype=torch.float32) + l_attn_mask = torch.zeros((batch_size, l_seq_len), device=clip_l.device, dtype=l_attn_mask.dtype) + if len(non_drop_l_indices) > 0: + l_pooled[non_drop_l_indices] = nd_l_pooled + l_out[non_drop_l_indices] = nd_l_out + l_attn_mask[non_drop_l_indices] = nd_l_attn_mask + + if len(non_drop_g_indices) == batch_size: + g_pooled = nd_g_pooled + g_out = nd_g_out else: - g_attn_mask = g_attn_mask.to(clip_g.device) if g_attn_mask is not None else None - prompt_embeds = clip_g(g_tokens.to(clip_g.device), g_attn_mask, output_hidden_states=True) - g_pooled = prompt_embeds[0] - g_out = prompt_embeds.hidden_states[-2] - - lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None + g_pooled = torch.zeros((batch_size, 1280), device=clip_g.device, dtype=torch.float32) + g_out = torch.zeros((batch_size, g_seq_len, 1280), device=clip_g.device, dtype=torch.float32) + g_attn_mask = torch.zeros((batch_size, g_seq_len), device=clip_g.device, dtype=g_attn_mask.dtype) + if len(non_drop_g_indices) > 0: + g_pooled[non_drop_g_indices] = nd_g_pooled + g_out[non_drop_g_indices] = nd_g_out + g_attn_mask[non_drop_g_indices] = nd_g_attn_mask + + lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) lg_out = torch.cat([l_out, g_out], dim=-1) if t5xxl is None or t5_tokens is None: t5_out = None + t5_attn_mask = None else: - drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate) - if drop_t5: - t5_out = torch.zeros((t5_tokens.shape[0], t5_tokens.shape[1], 4096), device=t5xxl.device, dtype=t5xxl.dtype) - if t5_attn_mask is not None: - t5_attn_mask = torch.zeros_like(t5_attn_mask, device=t5xxl.device) + # drop some members of the batch: we do not call t5xxl for dropped members + batch_size, t5_seq_len = t5_tokens.shape + non_drop_t5_indices = [] + for i in range(t5_tokens.shape[0]): + drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate) + if not drop_t5: + non_drop_t5_indices.append(i) + + # filter out dropped members + if len(non_drop_t5_indices) > 0 and len(non_drop_t5_indices) < batch_size: + t5_tokens = t5_tokens[non_drop_t5_indices] + t5_attn_mask = t5_attn_mask[non_drop_t5_indices] + + # call t5xxl for non-dropped members + if len(non_drop_t5_indices) > 0: + nd_t5_attn_mask = t5_attn_mask.to(t5xxl.device) + nd_t5_out, _ = t5xxl( + t5_tokens.to(t5xxl.device), + nd_t5_attn_mask if apply_t5_attn_mask else None, + return_dict=False, + output_hidden_states=True, + ) + + # fill in the dropped members + if len(non_drop_t5_indices) == batch_size: + t5_out = nd_t5_out else: - t5_attn_mask = t5_attn_mask.to(t5xxl.device) if t5_attn_mask is not None else None - t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), t5_attn_mask, return_dict=False, output_hidden_states=True) + t5_out = torch.zeros((batch_size, t5_seq_len, 4096), device=t5xxl.device, dtype=torch.float32) + t5_attn_mask = torch.zeros((batch_size, t5_seq_len), device=t5xxl.device, dtype=t5_attn_mask.dtype) + if len(non_drop_t5_indices) > 0: + t5_out[non_drop_t5_indices] = nd_t5_out + t5_attn_mask[non_drop_t5_indices] = nd_t5_attn_mask # masks are used for attention masking in transformer return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] @@ -322,6 +377,7 @@ def cache_batch_outputs( apply_t5_attn_mask=apply_t5_attn_mask, ) else: + # it's fine that attn mask is not None. it's overwritten before calling the model if necessary info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i) diff --git a/sd3_train_network.py b/sd3_train_network.py index 620a336fd..3506404ae 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -300,7 +300,7 @@ def get_noise_pred_and_target( if args.gradient_checkpointing: noisy_model_input.requires_grad_(True) for t in text_encoder_conds: - if t.dtype.is_floating_point: + if t is not None and t.dtype.is_floating_point: t.requires_grad_(True) # Predict the noise residual @@ -415,13 +415,12 @@ def forward(hidden_states): prepare_fp8(text_encoder, weight_dtype) def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): - # # drop cached text encoder outputs - # text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) - # if text_encoder_outputs_list is not None: - # text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() - # text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) - # batch["text_encoder_outputs_list"] = text_encoder_outputs_list - pass + # drop cached text encoder outputs + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) + batch["text_encoder_outputs_list"] = text_encoder_outputs_list def setup_parser() -> argparse.ArgumentParser: diff --git a/train_network.py b/train_network.py index 76936b2ed..b90aa420e 100644 --- a/train_network.py +++ b/train_network.py @@ -1151,16 +1151,6 @@ def remove_model(old_ckpt_name): text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs - - # if text_encoder_outputs_list is not None: - # lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_outputs_list - # for i in range(len(lg_out)): - # print( - # f"[{i}] cached L: {lg_out[i,:,:768].max()}, {lg_pooled[i][:768].max()}, cached G: {lg_out[i,:,768:].max()}, {lg_pooled[i][768:].max()}, " - # f"cached T5: {t5_out[i].max()}, " - # f"attn mask: {l_attn_mask[i].max() if l_attn_mask is not None else 0}," - # f" {g_attn_mask[i].max() if g_attn_mask is not None else 0}, {t5_attn_mask[i].max() if t5_attn_mask is not None else 0}" - # ) if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder: # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached' @@ -1193,15 +1183,6 @@ def remove_model(old_ckpt_name): if encoded_text_encoder_conds[i] is not None: text_encoder_conds[i] = encoded_text_encoder_conds[i] - # lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_conds - # for i in range(len(lg_out)): - # print( - # f"[{i}] train L: {lg_out[i,:,:768].max()}, {lg_pooled[i][:768].max()}, train G: {lg_out[i,:,768:].max()}, {lg_pooled[i][768:].max()}, " - # f"train T5: {t5_out[i].max()}, " - # f"attn mask: {l_attn_mask[i].max() if l_attn_mask is not None else 0}," - # f" {g_attn_mask[i].max() if g_attn_mask is not None else 0}, {t5_attn_mask[i].max() if t5_attn_mask is not None else 0}" - # ) - # sample noise, call unet, get target noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target( args, From af8e216035128767234163a24debf2f4df5aa36d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 28 Oct 2024 22:08:57 +0900 Subject: [PATCH 201/348] Fix sample image gen to work with block swap --- library/sd3_train_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index a0202ad40..054d1b4a1 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -364,6 +364,7 @@ def do_sample( x_c_nc = torch.cat([x, x], dim=0) # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape) + mmdit.prepare_block_swap_before_forward() model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) model_output = model_output.float() batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) @@ -385,6 +386,7 @@ def do_sample( x = x + d * dt x = x.to(dtype) + mmdit.prepare_block_swap_before_forward() return x From 75554867ce390ec0957cc52a70c0695e19c71fe2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 29 Oct 2024 08:34:31 +0900 Subject: [PATCH 202/348] Fix error on saving T5XXL --- library/sd3_train_utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 054d1b4a1..1702e81c2 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -75,7 +75,14 @@ def update_sd(prefix, sd): save_file(clip_g.state_dict(), clip_g_path) if t5xxl is not None: t5xxl_path = ckpt_path.replace(".safetensors", "_t5xxl.safetensors") - save_file(t5xxl.state_dict(), t5xxl_path) + t5xxl_state_dict = t5xxl.state_dict() + + # replace "shared.weight" with copy of it to avoid annoying shared tensor error on safetensors.save_file + shared_weight = t5xxl_state_dict["shared.weight"] + shared_weight_copy = shared_weight.detach().clone() + t5xxl_state_dict["shared.weight"] = shared_weight_copy + + save_file(t5xxl_state_dict, t5xxl_path) def save_sd3_model_on_train_end( From 0af4edd8a63d7fcdf02bdcbd11b8770fd1cae162 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Tue, 29 Oct 2024 21:51:56 +0900 Subject: [PATCH 203/348] Fix split_qkv --- networks/lora_sd3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/networks/lora_sd3.py b/networks/lora_sd3.py index cbabf8da0..249298b39 100644 --- a/networks/lora_sd3.py +++ b/networks/lora_sd3.py @@ -540,8 +540,8 @@ def state_dict(self, destination=None, prefix="", keep_vars=False): down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim) # merge up weight (sum of split_dim, rank*3) - qkv_dim, rank = up_weights[0].size() - split_dim = qkv_dim // 3 + split_dim, rank = up_weights[0].size() + qkv_dim = split_dim * 3 up_weight = torch.zeros((qkv_dim, down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype) i = 0 for j in range(3): From d4e19fbd5e34e90347f189a8ba1f77e8878fe0ca Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Tue, 29 Oct 2024 21:52:04 +0900 Subject: [PATCH 204/348] Support Lora --- sd3_minimal_inference.py | 60 +++++++++++++++++++++++++++++----------- 1 file changed, 44 insertions(+), 16 deletions(-) diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py index d099fe18d..86dba246d 100644 --- a/sd3_minimal_inference.py +++ b/sd3_minimal_inference.py @@ -10,11 +10,13 @@ import torch from safetensors.torch import safe_open, load_file +import torch.amp from tqdm import tqdm from PIL import Image from transformers import CLIPTextModelWithProjection, T5EncoderModel from library.device_utils import init_ipex, get_preferred_device +from networks import lora_sd3 init_ipex() @@ -104,7 +106,8 @@ def do_sample( x_c_nc = torch.cat([x, x], dim=0) # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape) - model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) + with torch.autocast(device_type=device.type, dtype=dtype): + model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) model_output = model_output.float() batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) @@ -153,7 +156,7 @@ def generate_image( clip_g.to(device) t5xxl.to(device) - with torch.no_grad(): + with torch.autocast(device_type=device.type, dtype=mmdit.dtype), torch.no_grad(): tokens_and_masks = tokenize_strategy.tokenize(prompt) lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encoding_strategy.encode_tokens( tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask @@ -233,13 +236,14 @@ def generate_image( parser.add_argument("--bf16", action="store_true") parser.add_argument("--seed", type=int, default=1) parser.add_argument("--steps", type=int, default=50) - # parser.add_argument( - # "--lora_weights", - # type=str, - # nargs="*", - # default=[], - # help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)", - # ) + parser.add_argument( + "--lora_weights", + type=str, + nargs="*", + default=[], + help="LoRA weights, only supports networks.lora_sd3, each argument is a `path;multiplier` (semi-colon separated)", + ) + parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model") parser.add_argument("--width", type=int, default=target_width) parser.add_argument("--height", type=int, default=target_height) parser.add_argument("--interactive", action="store_true") @@ -294,6 +298,30 @@ def generate_image( tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length) encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy() + # LoRA + lora_models: list[lora_sd3.LoRANetwork] = [] + for weights_file in args.lora_weights: + if ";" in weights_file: + weights_file, multiplier = weights_file.split(";") + multiplier = float(multiplier) + else: + multiplier = 1.0 + + weights_sd = load_file(weights_file) + module = lora_sd3 + lora_model, _ = module.create_network_from_weights(multiplier, None, vae, [clip_l, clip_g, t5xxl], mmdit, weights_sd, True) + + if args.merge_lora_weights: + lora_model.merge_to([clip_l, clip_g, t5xxl], mmdit, weights_sd) + else: + lora_model.apply_to([clip_l, clip_g, t5xxl], mmdit) + info = lora_model.load_state_dict(weights_sd, strict=True) + logger.info(f"Loaded LoRA weights from {weights_file}: {info}") + lora_model.eval() + lora_model.to(device) + + lora_models.append(lora_model) + if not args.interactive: generate_image( mmdit, @@ -344,13 +372,13 @@ def generate_image( steps = int(opt[1:].strip()) elif opt.startswith("d"): seed = int(opt[1:].strip()) - # elif opt.startswith("m"): - # mutipliers = opt[1:].strip().split(",") - # if len(mutipliers) != len(lora_models): - # logger.error(f"Invalid number of multipliers, expected {len(lora_models)}") - # continue - # for i, lora_model in enumerate(lora_models): - # lora_model.set_multiplier(float(mutipliers[i])) + elif opt.startswith("m"): + mutipliers = opt[1:].strip().split(",") + if len(mutipliers) != len(lora_models): + logger.error(f"Invalid number of multipliers, expected {len(lora_models)}") + continue + for i, lora_model in enumerate(lora_models): + lora_model.set_multiplier(float(mutipliers[i])) elif opt.startswith("n"): negative_prompt = opt[1:].strip() if negative_prompt == "-": From 1e2f7b0e44ee656cd8d0ca8268aa1371618031ac Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 29 Oct 2024 22:11:04 +0900 Subject: [PATCH 205/348] Support for checkpoint files with a mysterious prefix "model.diffusion_model." --- library/flux_utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/library/flux_utils.py b/library/flux_utils.py index 7a1ec37b8..4403835f1 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -73,6 +73,10 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int with safe_open(ckpt_path, framework="pt") as f: keys.extend(f.keys()) + # if the key has annoying prefix, remove it + if keys[0].startswith("model.diffusion_model."): + keys = [key.replace("model.diffusion_model.", "") for key in keys] + is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys) @@ -141,6 +145,13 @@ def load_flow_model( sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks) logger.info("Converted Diffusers to BFL") + # if the key has annoying prefix, remove it + for key in list(sd.keys()): + new_key = key.replace("model.diffusion_model.", "") + if new_key == key: + break # the model doesn't have annoying prefix + sd[new_key] = sd.pop(key) + info = model.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded Flux: {info}") return is_schnell, model From ce5b5325829538c03ff9ce80a79fe2c84ca5283c Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Tue, 29 Oct 2024 22:29:24 +0900 Subject: [PATCH 206/348] Fix additional LoRA to work --- networks/lora_sd3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/networks/lora_sd3.py b/networks/lora_sd3.py index 249298b39..c1eb68b8a 100644 --- a/networks/lora_sd3.py +++ b/networks/lora_sd3.py @@ -428,7 +428,7 @@ def create_modules( for filter, in_dim in zip( [ "context_embedder", - "t_embedder", + "_t_embedder", # don't use "t_embedder" because it's used in "context_embedder" "x_embedder", "y_embedder", "final_layer_adaLN_modulation", From b502f584886fbf52f9a180981efe276ea8509de7 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Tue, 29 Oct 2024 23:29:50 +0900 Subject: [PATCH 207/348] Fix emb_dim to work. --- networks/lora_sd3.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/networks/lora_sd3.py b/networks/lora_sd3.py index c1eb68b8a..efe202451 100644 --- a/networks/lora_sd3.py +++ b/networks/lora_sd3.py @@ -307,6 +307,7 @@ def create_modules( target_replace_modules: List[str], filter: Optional[str] = None, default_dim: Optional[int] = None, + include_conv2d_if_filter: bool = False, ) -> List[LoRAModule]: prefix = ( self.LORA_PREFIX_SD3 @@ -332,8 +333,11 @@ def create_modules( lora_name = prefix + "." + (name + "." if name else "") + child_name lora_name = lora_name.replace(".", "_") - if filter is not None and not filter in lora_name: - continue + force_incl_conv2d = False + if filter is not None: + if not filter in lora_name: + continue + force_incl_conv2d = include_conv2d_if_filter dim = None alpha = None @@ -373,6 +377,10 @@ def create_modules( elif self.conv_lora_dim is not None: dim = self.conv_lora_dim alpha = self.conv_alpha + elif force_incl_conv2d: + # x_embedder + dim = default_dim if default_dim is not None else self.lora_dim + alpha = self.alpha if dim is None or dim == 0: # skipした情報を出力 @@ -428,7 +436,7 @@ def create_modules( for filter, in_dim in zip( [ "context_embedder", - "_t_embedder", # don't use "t_embedder" because it's used in "context_embedder" + "_t_embedder", # don't use "t_embedder" because it's used in "context_embedder" "x_embedder", "y_embedder", "final_layer_adaLN_modulation", @@ -436,7 +444,12 @@ def create_modules( ], self.emb_dims, ): - loras, _ = create_modules(True, None, unet, None, filter=filter, default_dim=in_dim) + # x_embedder is conv2d, so we need to include it + loras, _ = create_modules( + True, None, unet, None, filter=filter, default_dim=in_dim, include_conv2d_if_filter=filter == "x_embedder" + ) + # if len(loras) > 0: + # logger.info(f"create LoRA for {filter}: {len(loras)} modules.") self.unet_loras.extend(loras) logger.info(f"create LoRA for SD3 MMDiT: {len(self.unet_loras)} modules.") From bdddc20d68a7441cccfcf0009528fdd59403b94a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 30 Oct 2024 12:51:49 +0900 Subject: [PATCH 208/348] support SD3.5M --- library/sd3_models.py | 128 +++++++++++++++++++++++-------------- library/sd3_train_utils.py | 7 ++ library/sd3_utils.py | 13 ++-- sd3_train.py | 8 +-- sd3_train_network.py | 1 + 5 files changed, 99 insertions(+), 58 deletions(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index 5d09f74e8..840f91869 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -51,7 +51,7 @@ class SD3Params: pos_embed_max_size: int adm_in_channels: int qk_norm: Optional[str] - x_block_self_attn_layers: List[int] + x_block_self_attn_layers: list[int] context_embedder_in_features: int context_embedder_out_features: int model_type: str @@ -510,6 +510,7 @@ def __init__( scale_mod_only: bool = False, swiglu: bool = False, qk_norm: Optional[str] = None, + x_block_self_attn: bool = False, **block_kwargs, ): super().__init__() @@ -519,13 +520,14 @@ def __init__( self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) else: self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.attn = AttentionLinears( - dim=hidden_size, - num_heads=num_heads, - qkv_bias=qkv_bias, - pre_only=pre_only, - qk_norm=qk_norm, - ) + self.attn = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=pre_only, qk_norm=qk_norm) + + self.x_block_self_attn = x_block_self_attn + if self.x_block_self_attn: + assert not pre_only + assert not scale_mod_only + self.attn2 = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=False, qk_norm=qk_norm) + if not pre_only: if not rmsnorm: self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) @@ -546,7 +548,9 @@ def __init__( multiple_of=256, ) self.scale_mod_only = scale_mod_only - if not scale_mod_only: + if self.x_block_self_attn: + n_mods = 9 + elif not scale_mod_only: n_mods = 6 if not pre_only else 2 else: n_mods = 4 if not pre_only else 1 @@ -556,63 +560,64 @@ def __init__( def pre_attention(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: if not self.pre_only: if not self.scale_mod_only: - ( - shift_msa, - scale_msa, - gate_msa, - shift_mlp, - scale_mlp, - gate_mlp, - ) = self.adaLN_modulation( - c - ).chunk(6, dim=-1) + (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(6, dim=-1) else: shift_msa = None shift_mlp = None - ( - scale_msa, - gate_msa, - scale_mlp, - gate_mlp, - ) = self.adaLN_modulation( - c - ).chunk(4, dim=-1) + (scale_msa, gate_msa, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(4, dim=-1) qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) - return qkv, ( - x, - gate_msa, - shift_mlp, - scale_mlp, - gate_mlp, - ) + return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp) else: if not self.scale_mod_only: - ( - shift_msa, - scale_msa, - ) = self.adaLN_modulation( - c - ).chunk(2, dim=-1) + (shift_msa, scale_msa) = self.adaLN_modulation(c).chunk(2, dim=-1) else: shift_msa = None scale_msa = self.adaLN_modulation(c) qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) return qkv, None + def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + assert self.x_block_self_attn + (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2) = self.adaLN_modulation( + c + ).chunk(9, dim=1) + x_norm = self.norm1(x) + qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa)) + qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2)) + return qkv, qkv2, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2) + def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp): assert not self.pre_only x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn) x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) return x + def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2, attn1_dropout: float = 0.0): + assert not self.pre_only + if attn1_dropout > 0.0: + # Use torch.bernoulli to implement dropout, only dropout the batch dimension + attn1_dropout = torch.bernoulli(torch.full((attn.size(0), 1, 1), 1 - attn1_dropout, device=attn.device)) + attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn) * attn1_dropout + else: + attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn) + x = x + attn_ + attn2_ = gate_msa2.unsqueeze(1) * self.attn2.post_attention(attn2) + x = x + attn2_ + mlp_ = gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + x = x + mlp_ + return x + # JointBlock + block_mixing in mmdit.py class MMDiTBlock(nn.Module): def __init__(self, *args, **kwargs): super().__init__() pre_only = kwargs.pop("pre_only") + x_block_self_attn = kwargs.pop("x_block_self_attn") + self.context_block = SingleDiTBlock(*args, pre_only=pre_only, **kwargs) - self.x_block = SingleDiTBlock(*args, pre_only=False, **kwargs) + self.x_block = SingleDiTBlock(*args, pre_only=False, x_block_self_attn=x_block_self_attn, **kwargs) + self.head_dim = self.x_block.attn.head_dim self.mode = self.x_block.attn_mode self.gradient_checkpointing = False @@ -622,7 +627,11 @@ def enable_gradient_checkpointing(self): def _forward(self, context, x, c): ctx_qkv, ctx_intermediate = self.context_block.pre_attention(context, c) - x_qkv, x_intermediate = self.x_block.pre_attention(x, c) + + if self.x_block.x_block_self_attn: + x_qkv, x_qkv2, x_intermediates = self.x_block.pre_attention_x(x, c) + else: + x_qkv, x_intermediates = self.x_block.pre_attention(x, c) ctx_len = ctx_qkv[0].size(1) @@ -634,11 +643,18 @@ def _forward(self, context, x, c): ctx_attn_out = attn[:, :ctx_len] x_attn_out = attn[:, ctx_len:] - x = self.x_block.post_attention(x_attn_out, *x_intermediate) + if self.x_block.x_block_self_attn: + x_q2, x_k2, x_v2 = x_qkv2 + attn2 = attention(x_q2, x_k2, x_v2, self.x_block.attn2.num_heads) + x = self.x_block.post_attention_x(x_attn_out, attn2, *x_intermediates) + else: + x = self.x_block.post_attention(x_attn_out, *x_intermediates) + if not self.context_block.pre_only: context = self.context_block.post_attention(ctx_attn_out, *ctx_intermediate) else: context = None + return context, x def forward(self, *args, **kwargs): @@ -678,7 +694,9 @@ def __init__( pos_embed_max_size: Optional[int] = None, num_patches=None, qk_norm: Optional[str] = None, + x_block_self_attn_layers: Optional[list[int]] = [], qkv_bias: bool = True, + pos_emb_random_crop_rate: float = 0.0, model_type: str = "sd3m", ): super().__init__() @@ -691,6 +709,8 @@ def __init__( self.pos_embed_scaling_factor = pos_embed_scaling_factor self.pos_embed_offset = pos_embed_offset self.pos_embed_max_size = pos_embed_max_size + self.x_block_self_attn_layers = x_block_self_attn_layers + self.pos_emb_random_crop_rate = pos_emb_random_crop_rate self.gradient_checkpointing = use_checkpoint # hidden_size = default(hidden_size, 64 * depth) @@ -751,6 +771,7 @@ def __init__( scale_mod_only=scale_mod_only, swiglu=swiglu, qk_norm=qk_norm, + x_block_self_attn=(i in self.x_block_self_attn_layers), ) for i in range(depth) ] @@ -832,7 +853,10 @@ def _basic_init(module): nn.init.constant_(self.final_layer.linear.weight, 0) nn.init.constant_(self.final_layer.linear.bias, 0) - def cropped_pos_embed(self, h, w, device=None): + def set_pos_emb_random_crop_rate(self, rate: float): + self.pos_emb_random_crop_rate = rate + + def cropped_pos_embed(self, h, w, device=None, random_crop: bool = False): p = self.x_embedder.patch_size # patched size h = (h + 1) // p @@ -842,8 +866,14 @@ def cropped_pos_embed(self, h, w, device=None): assert self.pos_embed_max_size is not None assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size) assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size) - top = (self.pos_embed_max_size - h) // 2 - left = (self.pos_embed_max_size - w) // 2 + + if not random_crop: + top = (self.pos_embed_max_size - h) // 2 + left = (self.pos_embed_max_size - w) // 2 + else: + top = torch.randint(0, self.pos_embed_max_size - h + 1, (1,)).item() + left = torch.randint(0, self.pos_embed_max_size - w + 1, (1,)).item() + spatial_pos_embed = self.pos_embed.reshape( 1, self.pos_embed_max_size, @@ -896,9 +926,12 @@ def forward( t: (N,) tensor of diffusion timesteps y: (N, D) tensor of class labels """ + pos_emb_random_crop = ( + False if self.pos_emb_random_crop_rate == 0.0 else torch.rand(1).item() < self.pos_emb_random_crop_rate + ) B, C, H, W = x.shape - x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device).to(dtype=x.dtype) + x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype) c = self.t_embedder(t, dtype=x.dtype) # (N, D) if y is not None and self.y_embedder is not None: y = self.y_embedder(y) # (N, D) @@ -977,6 +1010,7 @@ def create_sd3_mmdit(params: SD3Params, attn_mode: str = "torch") -> MMDiT: depth=params.depth, mlp_ratio=4, qk_norm=params.qk_norm, + x_block_self_attn_layers=params.x_block_self_attn_layers, num_patches=params.num_patches, attn_mode=attn_mode, model_type=params.model_type, diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 1702e81c2..86f0c9c04 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -239,6 +239,13 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): default=0.0, help="Dropout rate for T5 encoder, default is 0.0 / T5エンコーダのドロップアウト率、デフォルトは0.0", ) + parser.add_argument( + "--pos_emb_random_crop_rate", + type=float, + default=0.0, + help="Random crop rate for positional embeddings, default is 0.0. Only for SD3.5M" + " / 位置埋め込みのランダムクロップ率、デフォルトは0.0。SD3.5M以外では予期しない動作になります", + ) # copy from Diffusers parser.add_argument( diff --git a/library/sd3_utils.py b/library/sd3_utils.py index 71e50de36..1861dfbc2 100644 --- a/library/sd3_utils.py +++ b/library/sd3_utils.py @@ -41,20 +41,21 @@ def analyze_state_dict_state(state_dict: Dict, prefix: str = ""): # x_block_self_attn_layers.append(int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1])) x_block_self_attn_layers = [] - re_attn = re.compile(r".(\d+).x_block.attn2.ln_k.weight") + re_attn = re.compile(r"\.(\d+)\.x_block\.attn2\.ln_k\.weight") for key in list(state_dict.keys()): - m = re_attn.match(key) + m = re_attn.search(key) if m: x_block_self_attn_layers.append(int(m.group(1))) - assert len(x_block_self_attn_layers) == 0, "x_block_self_attn_layers is not supported" - context_embedder_in_features = context_shape[1] context_embedder_out_features = context_shape[0] - # only supports 3-5-large and 3-medium + # only supports 3-5-large, medium or 3-medium if qk_norm is not None: - model_type = "3-5-large" + if len(x_block_self_attn_layers) == 0: + model_type = "3-5-large" + else: + model_type = "3-5-medium" else: model_type = "3-medium" diff --git a/sd3_train.py b/sd3_train.py index cdac945e6..df2736901 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -353,17 +353,15 @@ def train(args): accelerator.wait_for_everyone() # load MMDIT - mmdit = sd3_utils.load_mmdit( - sd3_state_dict, - model_dtype, - "cpu", - ) + mmdit = sd3_utils.load_mmdit(sd3_state_dict, model_dtype, "cpu") # attn_mode = "xformers" if args.xformers else "torch" # assert ( # attn_mode == "torch" # ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" + mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate) + if args.gradient_checkpointing: mmdit.enable_gradient_checkpointing() diff --git a/sd3_train_network.py b/sd3_train_network.py index 3506404ae..3d2a75710 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -65,6 +65,7 @@ def load_target_model(self, args, weight_dtype, accelerator): ) mmdit = sd3_utils.load_mmdit(state_dict, loading_dtype, "cpu") self.model_type = mmdit.model_type + mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate) if args.fp8_base: # check dtype of model From 70a179e446219b66f208e4fbb37b74c5d77d6086 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 30 Oct 2024 14:34:19 +0900 Subject: [PATCH 209/348] Fix to use SDPA instead of xformers --- library/sd3_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index 840f91869..60356e82c 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -645,7 +645,7 @@ def _forward(self, context, x, c): if self.x_block.x_block_self_attn: x_q2, x_k2, x_v2 = x_qkv2 - attn2 = attention(x_q2, x_k2, x_v2, self.x_block.attn2.num_heads) + attn2 = attention(x_q2, x_k2, x_v2, self.x_block.attn2.num_heads, mode=self.mode) x = self.x_block.post_attention_x(x_attn_out, attn2, *x_intermediates) else: x = self.x_block.post_attention(x_attn_out, *x_intermediates) From 1434d8506f3ccc4ae6cc005a19531dba3cbb9fb9 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 31 Oct 2024 19:58:22 +0900 Subject: [PATCH 210/348] Support SD3.5M multi resolutional training --- library/sd3_models.py | 177 ++++++++++++++++++++++++++++++++++++- library/sd3_train_utils.py | 6 ++ library/strategy_base.py | 2 +- library/strategy_flux.py | 4 +- library/strategy_sd3.py | 11 ++- library/train_util.py | 3 + sd3_train.py | 9 +- sd3_train_network.py | 13 ++- 8 files changed, 215 insertions(+), 10 deletions(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index 60356e82c..0eca94e2f 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -88,6 +88,78 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): return emb +def get_scaled_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, sample_size=64, base_size=16): + """ + This function is contributed by KohakuBlueleaf. Thanks for the contribution! + + Creates scaled 2D sinusoidal positional embeddings that maintain consistent relative positions + when the resolution differs from the training resolution. + + Args: + embed_dim (int): Dimension of the positional embedding. + grid_size (int or tuple): Size of the position grid (H, W). If int, assumes square grid. + cls_token (bool): Whether to include class token. Defaults to False. + extra_tokens (int): Number of extra tokens (e.g., cls_token). Defaults to 0. + sample_size (int): Reference resolution (typically training resolution). Defaults to 64. + base_size (int): Base grid size used during training. Defaults to 16. + + Returns: + numpy.ndarray: Positional embeddings of shape (H*W, embed_dim) or + (H*W + extra_tokens, embed_dim) if cls_token is True. + """ + # Convert grid_size to tuple if it's an integer + if isinstance(grid_size, int): + grid_size = (grid_size, grid_size) + + # Create normalized grid coordinates (0 to 1) + grid_h = np.arange(grid_size[0], dtype=np.float32) / grid_size[0] + grid_w = np.arange(grid_size[1], dtype=np.float32) / grid_size[1] + + # Calculate scaling factors for height and width + # This ensures that the central region matches the original resolution's embeddings + scale_h = base_size * grid_size[0] / (sample_size) + scale_w = base_size * grid_size[1] / (sample_size) + + # Calculate shift values to center the original resolution's embedding region + # This ensures that the central sample_size x sample_size region has similar + # positional embeddings to the original resolution + shift_h = 1 * scale_h * (grid_size[0] - sample_size) / (2 * grid_size[0]) + shift_w = 1 * scale_w * (grid_size[1] - sample_size) / (2 * grid_size[1]) + + # Apply scaling and shifting to create the final grid coordinates + grid_h = grid_h * scale_h - shift_h + grid_w = grid_w * scale_w - shift_w + + # Create 2D grid using meshgrid (note: w goes first) + grid = np.meshgrid(grid_w, grid_h) + grid = np.stack(grid, axis=0) + + # # Calculate the starting indices for the central region + # # This is used for debugging/visualization of the central region + # st_h = (grid_size[0] - sample_size) // 2 + # st_w = (grid_size[1] - sample_size) // 2 + # print(grid[:, st_h : st_h + sample_size, st_w : st_w + sample_size]) + + # Reshape grid for positional embedding calculation + grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + + # Generate the sinusoidal positional embeddings + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + + # Add zeros for extra tokens (e.g., [CLS] token) if required + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + + return pos_embed + + +# if __name__ == "__main__": +# # This is what you get when you load SD3.5 state dict +# pos_emb = torch.from_numpy(get_scaled_2d_sincos_pos_embed( +# 1536, [384, 384], sample_size=64, base_size=16 +# )).float().unsqueeze(0) + + def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): """ embed_dim: output dimension for each position @@ -617,7 +689,7 @@ def __init__(self, *args, **kwargs): self.context_block = SingleDiTBlock(*args, pre_only=pre_only, **kwargs) self.x_block = SingleDiTBlock(*args, pre_only=False, x_block_self_attn=x_block_self_attn, **kwargs) - + self.head_dim = self.x_block.attn.head_dim self.mode = self.x_block.attn_mode self.gradient_checkpointing = False @@ -669,6 +741,9 @@ class MMDiT(nn.Module): Diffusion model with a Transformer backbone. """ + # prepare pos_embed for latent size * 2 + POS_EMBED_MAX_RATIO = 1.5 + def __init__( self, input_size: int = 32, @@ -697,6 +772,8 @@ def __init__( x_block_self_attn_layers: Optional[list[int]] = [], qkv_bias: bool = True, pos_emb_random_crop_rate: float = 0.0, + use_scaled_pos_embed: bool = False, + pos_embed_latent_sizes: Optional[list[int]] = None, model_type: str = "sd3m", ): super().__init__() @@ -722,6 +799,8 @@ def __init__( self.num_heads = num_heads + self.enable_scaled_pos_embed(use_scaled_pos_embed, pos_embed_latent_sizes) + self.x_embedder = PatchEmbed( input_size, patch_size, @@ -785,6 +864,43 @@ def __init__( self.blocks_to_swap = None self.thread_pool: Optional[ThreadPoolExecutor] = None + def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Optional[list[int]]): + self.use_scaled_pos_embed = use_scaled_pos_embed + + if self.use_scaled_pos_embed: + # # remove pos_embed to free up memory up to 0.4 GB + self.pos_embed = None + + # sort latent sizes in ascending order + latent_sizes = sorted(latent_sizes) + + patched_sizes = [latent_size // self.patch_size for latent_size in latent_sizes] + + # calculate value range for each latent area: this is used to determine the pos_emb size from the latent shape + max_areas = [] + for i in range(1, len(patched_sizes)): + prev_area = patched_sizes[i - 1] ** 2 + area = patched_sizes[i] ** 2 + max_areas.append((prev_area + area) // 2) + + # area of the last latent size, if the latent size exceeds this, error will be raised + max_areas.append(int((patched_sizes[-1] * MMDiT.POS_EMBED_MAX_RATIO) ** 2)) + # print("max_areas", max_areas) + + self.resolution_area_to_latent_size = [(area, latent_size) for area, latent_size in zip(max_areas, patched_sizes)] + + self.resolution_pos_embeds = {} + for patched_size in patched_sizes: + grid_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO) + pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, grid_size, sample_size=patched_size) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0) + self.resolution_pos_embeds[patched_size] = pos_embed + # print(f"pos_embed for {patched_size}x{patched_size} latent size: {pos_embed.shape}") + + else: + self.resolution_area_to_latent_size = None + self.resolution_pos_embeds = None + @property def model_type(self): return self._model_type @@ -884,6 +1000,54 @@ def cropped_pos_embed(self, h, w, device=None, random_crop: bool = False): spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) return spatial_pos_embed + def cropped_scaled_pos_embed(self, h, w, device=None, dtype=None, random_crop: bool = False): + p = self.x_embedder.patch_size + # patched size + h = (h + 1) // p + w = (w + 1) // p + + # select pos_embed size based on area + area = h * w + patched_size = None + for area_, patched_size_ in self.resolution_area_to_latent_size: + if area <= area_: + patched_size = patched_size_ + break + if patched_size is None: + raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.") + + pos_embed_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO) + if h > pos_embed_size or w > pos_embed_size: + # fallback to normal pos_embed + logger.warning( + f"Using normal pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide." + ) + return self.cropped_pos_embed(h, w, device=device, random_crop=random_crop) + + if not random_crop: + top = (pos_embed_size - h) // 2 + left = (pos_embed_size - w) // 2 + else: + top = torch.randint(0, pos_embed_size - h + 1, (1,)).item() + left = torch.randint(0, pos_embed_size - w + 1, (1,)).item() + + pos_embed = self.resolution_pos_embeds[patched_size] + if pos_embed.device != device: + pos_embed = pos_embed.to(device) + # which is better to update device, or transfer every time to device? -> 64x64 emb is 96*96*1536*4=56MB. It's okay to update device. + self.resolution_pos_embeds[patched_size] = pos_embed # update device + if pos_embed.dtype != dtype: + pos_embed = pos_embed.to(dtype) + self.resolution_pos_embeds[patched_size] = pos_embed # update dtype + + spatial_pos_embed = pos_embed.reshape(1, pos_embed_size, pos_embed_size, pos_embed.shape[-1]) + spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :] + spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) + # print( + # f"patched size: {h}x{w}, pos_embed size: {pos_embed_size}, pos_embed shape: {pos_embed.shape}, top: {top}, left: {left}" + # ) + return spatial_pos_embed + def enable_block_swap(self, num_blocks: int): self.blocks_to_swap = num_blocks @@ -931,7 +1095,16 @@ def forward( ) B, C, H, W = x.shape - x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype) + + # x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype) + if not self.use_scaled_pos_embed: + pos_embed = self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype) + else: + # print(f"Using scaled pos_embed for size {H}x{W}") + pos_embed = self.cropped_scaled_pos_embed(H, W, device=x.device, dtype=x.dtype, random_crop=pos_emb_random_crop) + x = self.x_embedder(x) + pos_embed + del pos_embed + c = self.t_embedder(t, dtype=x.dtype) # (N, D) if y is not None and self.y_embedder is not None: y = self.y_embedder(y) # (N, D) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 86f0c9c04..69878750e 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -246,6 +246,12 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): help="Random crop rate for positional embeddings, default is 0.0. Only for SD3.5M" " / 位置埋め込みのランダムクロップ率、デフォルトは0.0。SD3.5M以外では予期しない動作になります", ) + parser.add_argument( + "--enable_scaled_pos_embed", + action="store_true", + help="Scale position embeddings for each resolution during multi-resolution training. Only for SD3.5M" + " / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります", + ) # copy from Diffusers parser.add_argument( diff --git a/library/strategy_base.py b/library/strategy_base.py index e390c5f35..358e42f1d 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -518,7 +518,7 @@ def load_latents_from_disk( self, npz_path: str, bucket_reso: Tuple[int, int] ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: """ - for SD/SDXL/SD3.0 + for SD/SDXL """ return self._default_load_latents_from_disk(None, npz_path, bucket_reso) diff --git a/library/strategy_flux.py b/library/strategy_flux.py index f662b62e9..5e65927f8 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -212,7 +212,7 @@ def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) ) def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): - return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, True) + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True) def load_latents_from_disk( self, npz_path: str, bucket_reso: Tuple[int, int] @@ -226,7 +226,7 @@ def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask vae_dtype = vae.dtype self._default_cache_batch_latents( - encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, True + encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True ) if not train_util.HIGH_VRAM: diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py index 413169ecc..1d55fe21d 100644 --- a/library/strategy_sd3.py +++ b/library/strategy_sd3.py @@ -399,7 +399,12 @@ def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) ) def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): - return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True) + + def load_latents_from_disk( + self, npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution # TODO remove circular dependency for ImageInfo def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): @@ -407,7 +412,9 @@ def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask vae_device = vae.device vae_dtype = vae.dtype - self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop) + self._default_cache_batch_latents( + encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True + ) if not train_util.HIGH_VRAM: train_util.clean_memory_on_device(vae.device) diff --git a/library/train_util.py b/library/train_util.py index d568523ca..bd2ff6ef4 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2510,6 +2510,9 @@ def verify_bucket_reso_steps(self, min_steps: int): for dataset in self.datasets: dataset.verify_bucket_reso_steps(min_steps) + def get_resolutions(self) -> List[Tuple[int, int]]: + return [(dataset.width, dataset.height) for dataset in self.datasets] + def is_latent_cacheable(self) -> bool: return all([dataset.is_latent_cacheable() for dataset in self.datasets]) diff --git a/sd3_train.py b/sd3_train.py index df2736901..40f8c7e1f 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -361,7 +361,14 @@ def train(args): # ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate) - + + # set resolutions for positional embeddings + if args.enable_scaled_pos_embed: + resolutions = train_dataset_group.get_resolutions() + latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in resolutions] # 8 is stride for latent + logger.info(f"Prepare scaled positional embeddings for resolutions: {resolutions}, sizes: {latent_sizes}") + mmdit.enable_scaled_pos_embed(True, latent_sizes) + if args.gradient_checkpointing: mmdit.enable_gradient_checkpointing() diff --git a/sd3_train_network.py b/sd3_train_network.py index 3d2a75710..9eeac05ca 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -26,8 +26,8 @@ def __init__(self): super().__init__() self.sample_prompts_te_outputs = None - def assert_extra_args(self, args, train_dataset_group): - super().assert_extra_args(args, train_dataset_group) + def assert_extra_args(self, args, train_dataset_group: train_util.DatasetGroup): + # super().assert_extra_args(args, train_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) if args.fp8_base_unet: @@ -53,6 +53,9 @@ def assert_extra_args(self, args, train_dataset_group): train_dataset_group.verify_bucket_reso_steps(32) # TODO check this + # enumerate resolutions from dataset for positional embeddings + self.resolutions = train_dataset_group.get_resolutions() + def load_target_model(self, args, weight_dtype, accelerator): # currently offload to cpu for some models @@ -67,6 +70,12 @@ def load_target_model(self, args, weight_dtype, accelerator): self.model_type = mmdit.model_type mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate) + # set resolutions for positional embeddings + if args.enable_scaled_pos_embed: + latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in self.resolutions] # 8 is stride for latent + logger.info(f"Prepare scaled positional embeddings for resolutions: {self.resolutions}, sizes: {latent_sizes}") + mmdit.enable_scaled_pos_embed(True, latent_sizes) + if args.fp8_base: # check dtype of model if mmdit.dtype == torch.float8_e4m3fnuz or mmdit.dtype == torch.float8_e5m2 or mmdit.dtype == torch.float8_e5m2fnuz: From 9e23368e3d6288e85c6fe34f4d5774bd4d948517 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 31 Oct 2024 19:58:41 +0900 Subject: [PATCH 211/348] Update SD3 training --- README.md | 195 +++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 163 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index ad2791e7f..aff78b2c6 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ This repository contains training, generation and utility scripts for Stable Diffusion. -## FLUX.1 training (WIP) +## FLUX.1 and SD3 training (WIP) This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training. @@ -9,8 +9,15 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +- [FLUX.1 training](#flux1-training) +- [SD3 training](#sd3-training) + ### Recent Updates +Oct 31, 2024: + +- Added support for SD3.5L/M training. See [SD3 training](#sd3-training) for details. + Oct 19, 2024: - Added an implementation of Differential Output Preservation (temporary name) for SDXL/FLUX.1 LoRA training. SD1/2 is not tested yet. This is an experimental feature. @@ -139,7 +146,7 @@ Sep 1, 2024: Aug 29, 2024: Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. `requirements.txt` is updated. -### Contents +## FLUX.1 training - [FLUX.1 LoRA training](#flux1-lora-training) - [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training) @@ -586,53 +593,177 @@ python tools/convert_diffusers_to_flux.py --diffusers_path path/to/diffusers_fol ## SD3 training -SD3 training is done with `sd3_train.py`. +SD3.5L/M training is now available. + +### SD3 LoRA training + +The script is `sd3_train_network.py`. See `--help` for options. + +SD3 model, CLIP-L, CLIP-G, and T5XXL models are recommended to be in float/fp16 format. If you specify `--fp8_base`, you can use fp8 models for SD3. The fp8 model is only compatible with `float8_e4m3fn` format. + +Sample command is below. It will work with 16GB VRAM GPUs (SD3.5L). + +``` +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 sd3_train_network.py +--pretrained_model_name_or_path path/to/sd3.5_large.safetensors --clip_l sd3/clip_l.safetensors --clip_g sd3/clip_g.safetensors --t5xxl sd3/t5xxl_fp16.safetensors +--cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers +--max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 +--network_module networks.lora_sd3 --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 +--cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base +--highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml +--output_dir path/to/output/dir --output_name sd3-lora-name +``` +(The command is multi-line for readability. Please combine it into one line.) + +The training can be done with 12GB VRAM GPUs with Adafactor optimizer. Please use settings like below: + +``` +--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 +``` + +`--cpu_offload_checkpointing` and `--split_mode` are not available for SD3 LoRA training. -__Sep 1, 2024__: -- `--num_last_block_to_freeze` is added to `sd3_train.py`. This option is to freeze the last n blocks of the MMDiT. See [#1417](https://github.com/kohya-ss/sd-scripts/pull/1417) for details. Thanks to sdbds! +We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. -__Jul 27, 2024__: -- Latents and text encoder outputs caching mechanism is refactored significantly. - - Existing cache files for SD3 need to be recreated. Please delete the previous cache files. - - With this change, dataset initialization is significantly faster, especially for large datasets. +The trained LoRA model can be used with ComfyUI. -- Architecture-dependent parts are extracted from the dataset (`train_util.py`). This is expected to make it easier to add future architectures. +#### Key Options for SD3 LoRA training -- Architecture-dependent parts including the cache mechanism for SD1/2/SDXL are also extracted. The basic operation of SD1/2/SDXL training on the sd3 branch has been confirmed, but there may be bugs. Please use the main or dev branch for SD1/2/SDXL training. +Here are the arguments. The arguments and sample settings are still experimental and may change in the future. Feedback on the settings is welcome. ---- +- `--network_module` is the module for LoRA training. Specify `networks.lora_sd3` for SD3 LoRA training. +- `--pretrained_model_name_or_path` is the path to the pretrained model (SD3/3.5). If you specify `--fp8_base`, you can use fp8 models for SD3/3.5. The fp8 model is only compatible with `float8_e4m3fn` format. +- `--clip_l` is the path to the CLIP-L model. +- `--clip_g` is the path to the CLIP-G model. +- `--t5xxl` is the path to the T5XXL model. If you specify `--fp8_base`, you can use fp8 (float8_e4m3fn) models for T5XXL. However, it is recommended to use fp16 models for caching. +- `--vae` is the path to the autoencoder model. __This option is not necessary for SD3.__ VAE is included in the standard SD3 model. +- `--disable_mmap_load_safetensors` is to disable memory mapping when loading safetensors. __This option significantly reduces the memory usage when loading models for Windows users.__ +- `--clip_l_dropout_rate`, `--clip_g_dropout_rate` and `--t5_dropout_rate` are the dropout rates for the embeddings of CLIP-L, CLIP-G, and T5XXL, described in [SAI research papre](http://arxiv.org/pdf/2403.03206). The default is 0.0. For LoRA training, it is seems to be better to set 0.0. +- `--pos_emb_random_crop_rate` is the rate of random cropping of positional embeddings, described in [SD3.5M model card](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium). The default is 0. It is seems to be better to set 0.0 for LoRA training. +- `--enable_scaled_pos_embed` is to enable the scaled positional embeddings. The default is False. This option is an experimental feature for SD3.5M. Details are described below. -`fp16` and `bf16` are available for mixed precision training. We are not sure which is better. +Other options are described below. -`optimizer_type = "adafactor"` is recommended for 24GB VRAM GPUs. `cache_text_encoder_outputs_to_disk` and `cache_latents_to_disk` are necessary currently. +#### Key Features for SD3 LoRA training -`clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them. +1. CLIP-L, G and T5XXL LoRA Support: + - SD3 LoRA training now supports CLIP-L, CLIP-G and T5XXL LoRA training. + - Remove `--network_train_unet_only` from your command. + - Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L and G is also trained at the same time. + - T5XXL output can be cached for CLIP-L and G LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available. + - The learning rates for CLIP-L, CLIP-G and T5XXL can be specified separately. Multiple numbers can be specified in `--text_encoder_lr`. For example, `--text_encoder_lr 1e-4 1e-5 5e-6`. The first value is the learning rate for CLIP-L, the second value is for CLIP-G, and the third value is for T5XXL. If you specify only one, the learning rates for CLIP-L, CLIP-G and T5XXL will be the same. If the third value is not specified, the second value is used for T5XXL. If `--text_encoder_lr` is not specified, the default learning rate `--learning_rate` is used for both CLIP-L and T5XXL. + - The trained LoRA can be used with ComfyUI. -t5xxl works with `fp16` now. + | trained LoRA|option|network_args|cache_text_encoder_outputs (*1)| + |---|---|---|---| + |MMDiT|`--network_train_unet_only`|-|o| + |MMDiT + CLIP-L + CLIP-G|-|-|o (*2)| + |MMDiT + CLIP-L + CLIP-G + T5XXL|-|`train_t5xxl=True`|-| + |CLIP-L + CLIP-G (*3)|`--network_train_text_encoder_only`|-|o (*2)| + |CLIP-L + CLIP-G + T5XXL (*3)|`--network_train_text_encoder_only`|`train_t5xxl=True`|-| -There are `t5xxl_device` and `t5xxl_dtype` options for `t5xxl` device and dtype. + - *1: `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available. + - *2: T5XXL output can be cached for CLIP-L and G LoRA training. + - *3: Not tested yet. + +2. Experimental FP8/FP16 mixed training: + - `--fp8_base_unet` enables training with fp8 for MMDiT and bf16/fp16 for CLIP-L/G/T5XXL. + - When specifying this option, the `--fp8_base` option is automatically enabled. -`text_encoder_batch_size` is added experimentally for caching faster. +3. Split Q/K/V Projection Layers (Experimental): + - Same as FLUX.1. + +4. CLIP-L/G and T5 Attention Mask Application: + - This function is planned to be implemented in the future. + +5. Multi-resolution Training Support: + - Only for SD3.5M. + - Same as FLUX.1 for data preparation. + - If you train with multiple resolutions, specify `--enable_scaled_pos_embed` to enable the scaled positional embeddings. The default is False. This option is an experimental feature for SD3.5M. -```toml -learning_rate = 1e-6 # seems to depend on the batch size -optimizer_type = "adafactor" -optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ] -cache_text_encoder_outputs = true -cache_text_encoder_outputs_to_disk = true -vae_batch_size = 1 -text_encoder_batch_size = 4 -cache_latents = true -cache_latents_to_disk = true + +Technical details of multi-resolution training for SD3.5M: + +The values of the positional embeddings must be the same for each resolution. That is, the same value must be in the same position for 512x512, 768x768, and 1024x1024. To achieve this, the positional embeddings for each resolution are calculated in advance and switched according to the resolution of the training data. This feature is enabled by `--enable_scaled_pos_embed`. + +This idea and the code for calculating scaled positional embeddings are contributed by KohakuBlueleaf. Thanks to KohakuBlueleaf! + + +#### Specify rank for each layer in SD3 LoRA + +You can specify the rank for each layer in SD3 by specifying the following network_args. If you specify `0`, LoRA will not be applied to that layer. + +When network_args is not specified, the default value (`network_dim`) is applied, same as before. + +|network_args|target layer| +|---|---| +|context_attn_dim|attn in context_block| +|context_mlp_dim|mlp in context_block| +|context_mod_dim|adaLN_modulation in context_block| +|x_attn_dim|attn in x_block| +|x_mlp_dim|mlp in x_block| +|x_mod_dim|adaLN_modulation in x_block| + +`"verbose=True"` is also available for debugging. It shows the rank of each layer. + +example: ``` +--network_args "context_attn_dim=2" "context_mlp_dim=3" "context_mod_dim=4" "x_attn_dim=5" "x_mlp_dim=6" "x_mod_dim=7" "verbose=True" +``` + +You can apply LoRA to the conditioning layers of SD3 by specifying `emb_dims` in network_args. When specifying, be sure to specify 6 numbers in `[]` as a comma-separated list. + +example: +``` +--network_args "emb_dims=[2,3,4,5,6,7]" +``` + +Each number corresponds to `context_embedder`, `t_embedder`, `x_embedder`, `y_embedder`, `final_layer_adaLN_modulation`, `final_layer_linear`. The above example applies LoRA to all conditioning layers, with rank 2 for `context_embedder`, 3 for `t_embedder`, 4 for `context_embedder`, 5 for `y_embedder`, 6 for `final_layer_adaLN_modulation`, and 7 for `final_layer_linear`. + +If you specify `0`, LoRA will not be applied to that layer. For example, `[4,0,0,4,0,0]` applies LoRA only to `context_embedder` and `y_embedder`. + +#### Specify blocks to train in SD3 LoRA training + +You can specify the blocks to train in SD3 LoRA training by specifying `train_block_indices` in network_args. The indices are 0-based. The default (when omitted) is to train all blocks. The indices are specified as a list of integers or a range of integers, like `0,1,5,8` or `0,1,4-5,7`. + +The number of blocks depends on the model. The valid range is 0-(the number of blocks - 1). `all` is also available to train all blocks, `none` is also available to train no blocks. + +example: +``` +--network_args "train_block_indices=1,2,6-8" +``` + +### Inference for SD3 with LoRA model + +The inference script is also available. The script is `sd3_minimal_inference.py`. See `--help` for options. + +### SD3 fine-tuning + +Documentation is not available yet. Please refer to the FLUX.1 fine-tuning guide for now. The major difference are following: + +- `--clip_g` is also available for SD3 fine-tuning. +- `--timestep_sampling` `--discrete_flow_shift``--model_prediction_type` --guidance_scale` are not necessary for SD3 fine-tuning. +- Use `--vae` instead of `--ae` if necessary. __This option is not necessary for SD3.__ VAE is included in the standard SD3 model. +- `--disable_mmap_load_safetensors` is available. __This option significantly reduces the memory usage when loading models for Windows users.__ +- `--cpu_offload_checkpointing` is not available for SD3 fine-tuning. +- `--clip_l_dropout_rate`, `--clip_g_dropout_rate` and `--t5_dropout_rate` are available same as LoRA training. +- `--pos_emb_random_crop_rate` and `--enable_scaled_pos_embed` are available for SD3.5M fine-tuning. +- Training text encoders is available with `--train_text_encoder` option, similar to SDXL training. + - CLIP-L and G can be trained with `--train_text_encoder` option. Training T5XXL needs `--train_t5xxl` option. + - If you use the cached text encoder outputs for T5XXL with training CLIP-L and G, specify `--use_t5xxl_cache_only`. This option enables to use the cached text encoder outputs for T5XXL only. + - The learning rates for CLIP-L, CLIP-G and T5XXL can be specified separately. `--text_encoder_lr1`, `--text_encoder_lr2` and `--text_encoder_lr3` are available. + +### Extract LoRA from SD3 Models + +Not available yet. -__2024/7/27:__ +### Convert SD3 LoRA -Latents およびテキストエンコーダ出力のキャッシュの仕組みを大きくリファクタリングしました。SD3 用の既存のキャッシュファイルの再作成が必要になりますが、ご了承ください(以前のキャッシュファイルは削除してください)。これにより、特にデータセットの規模が大きい場合のデータセット初期化が大幅に高速化されます。 +Not available yet. -データセット (`train_util.py`) からアーキテクチャ依存の部分を切り出しました。これにより将来的なアーキテクチャ追加が容易になると期待しています。 +### Merge LoRA to SD3 checkpoint -SD1/2/SDXL のキャッシュ機構を含むアーキテクチャ依存の部分も切り出しました。sd3 ブランチの SD1/2/SDXL 学習について、基本的な動作は確認していますが、不具合があるかもしれません。SD1/2/SDXL の学習には main または dev ブランチをお使いください。 +Not available yet. --- From 830df4abcc85ffdfe08b8f97f2c8351c86149af3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 31 Oct 2024 21:39:07 +0900 Subject: [PATCH 212/348] Fix crashing if image is too tall or wide. --- library/sd3_models.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index 0eca94e2f..15a5b1db4 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -868,7 +868,7 @@ def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Opti self.use_scaled_pos_embed = use_scaled_pos_embed if self.use_scaled_pos_embed: - # # remove pos_embed to free up memory up to 0.4 GB + # remove pos_embed to free up memory up to 0.4 GB self.pos_embed = None # sort latent sizes in ascending order @@ -977,7 +977,7 @@ def cropped_pos_embed(self, h, w, device=None, random_crop: bool = False): # patched size h = (h + 1) // p w = (w + 1) // p - if self.pos_embed is None: + if self.pos_embed is None: # should not happen return get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=device) assert self.pos_embed_max_size is not None assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size) @@ -1016,13 +1016,20 @@ def cropped_scaled_pos_embed(self, h, w, device=None, dtype=None, random_crop: b if patched_size is None: raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.") - pos_embed_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO) + pos_embed = self.resolution_pos_embeds[patched_size] + pos_embed_size = round(math.sqrt(pos_embed.shape[1])) if h > pos_embed_size or w > pos_embed_size: - # fallback to normal pos_embed + # # fallback to normal pos_embed + # return self.cropped_pos_embed(h * p, w * p, device=device, random_crop=random_crop) + # extend pos_embed size logger.warning( f"Using normal pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide." ) - return self.cropped_pos_embed(h, w, device=device, random_crop=random_crop) + pos_embed_size = max(h, w) + pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, pos_embed_size, sample_size=patched_size) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0) + self.resolution_pos_embeds[patched_size] = pos_embed + logger.info(f"Updated pos_embed for size {pos_embed_size}x{pos_embed_size}") if not random_crop: top = (pos_embed_size - h) // 2 @@ -1031,7 +1038,6 @@ def cropped_scaled_pos_embed(self, h, w, device=None, dtype=None, random_crop: b top = torch.randint(0, pos_embed_size - h + 1, (1,)).item() left = torch.randint(0, pos_embed_size - w + 1, (1,)).item() - pos_embed = self.resolution_pos_embeds[patched_size] if pos_embed.device != device: pos_embed = pos_embed.to(device) # which is better to update device, or transfer every time to device? -> 64x64 emb is 96*96*1536*4=56MB. It's okay to update device. From 9aa6f52ac3c1866d00675daf73c7560b8b76093f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 1 Nov 2024 21:43:21 +0900 Subject: [PATCH 213/348] Fix memory leak in latent caching. bmp failed to cache --- library/train_util.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index bd2ff6ef4..18d3cf6c2 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1082,6 +1082,10 @@ def submit_batch(batch, cond): info.image = info.image.result() # future to image caching_strategy.cache_batch_latents(model, batch, cond.flip_aug, cond.alpha_mask, cond.random_crop) + # remove image from memory + for info in batch: + info.image = None + # define ThreadPoolExecutor to load images in parallel max_workers = min(os.cpu_count(), len(image_infos)) max_workers = max(1, max_workers // num_processes) # consider multi-gpu @@ -1397,7 +1401,17 @@ def cache_text_encoder_outputs_common( ) def get_image_size(self, image_path): - return imagesize.get(image_path) + # return imagesize.get(image_path) + image_size = imagesize.get(image_path) + if image_size[0] <= 0: + # imagesize doesn't work for some images, so use cv2 + img = cv2.imread(image_path) + if img is not None: + image_size = (img.shape[1], img.shape[0]) + else: + logger.warning(f"failed to get image size: {image_path}") + image_size = (0, 0) + return image_size def load_image_with_face_info(self, subset: BaseSubset, image_path: str, alpha_mask=False): img = load_image(image_path, alpha_mask) From 82daa98fe865c30a34638acc145d6f4ea8c193db Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 1 Nov 2024 21:43:47 +0900 Subject: [PATCH 214/348] remove duplicate resolution for scaled pos embed --- library/sd3_models.py | 3 ++- sd3_train.py | 1 + sd3_train_network.py | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index 15a5b1db4..b09a57dbd 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -871,7 +871,8 @@ def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Opti # remove pos_embed to free up memory up to 0.4 GB self.pos_embed = None - # sort latent sizes in ascending order + # remove duplcates and sort latent sizes in ascending order + latent_sizes = list(set(latent_sizes)) latent_sizes = sorted(latent_sizes) patched_sizes = [latent_size // self.patch_size for latent_size in latent_sizes] diff --git a/sd3_train.py b/sd3_train.py index 40f8c7e1f..f64e2da2c 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -366,6 +366,7 @@ def train(args): if args.enable_scaled_pos_embed: resolutions = train_dataset_group.get_resolutions() latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in resolutions] # 8 is stride for latent + latent_sizes = list(set(latent_sizes)) # remove duplicates logger.info(f"Prepare scaled positional embeddings for resolutions: {resolutions}, sizes: {latent_sizes}") mmdit.enable_scaled_pos_embed(True, latent_sizes) diff --git a/sd3_train_network.py b/sd3_train_network.py index 9eeac05ca..0739e094d 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -73,6 +73,7 @@ def load_target_model(self, args, weight_dtype, accelerator): # set resolutions for positional embeddings if args.enable_scaled_pos_embed: latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in self.resolutions] # 8 is stride for latent + latent_sizes = list(set(latent_sizes)) # remove duplicates logger.info(f"Prepare scaled positional embeddings for resolutions: {self.resolutions}, sizes: {latent_sizes}") mmdit.enable_scaled_pos_embed(True, latent_sizes) From e0db59695fb56e6b7f42132b70e4f828820143ac Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 2 Nov 2024 11:13:04 +0900 Subject: [PATCH 215/348] update multi-res training in SD3.5M --- README.md | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index aff78b2c6..fb087c234 100644 --- a/README.md +++ b/README.md @@ -679,12 +679,16 @@ Other options are described below. 5. Multi-resolution Training Support: - Only for SD3.5M. - Same as FLUX.1 for data preparation. - - If you train with multiple resolutions, specify `--enable_scaled_pos_embed` to enable the scaled positional embeddings. The default is False. This option is an experimental feature for SD3.5M. + - If you train with multiple resolutions, you can enable the scaled positional embeddings with `--enable_scaled_pos_embed`. The default is False. __This option is an experimental feature.__ + + Technical details of multi-resolution training for SD3.5M: -The values of the positional embeddings must be the same for each resolution. That is, the same value must be in the same position for 512x512, 768x768, and 1024x1024. To achieve this, the positional embeddings for each resolution are calculated in advance and switched according to the resolution of the training data. This feature is enabled by `--enable_scaled_pos_embed`. +SD3.5M does not use scaled positional embeddings for multi-resolution training, and is trained with a single positional embedding. Therefore, this feature is very experimental. + +Generally, in multi-resolution training, the values of the positional embeddings must be the same for each resolution. That is, the same value must be in the same position for 512x512, 768x768, and 1024x1024. To achieve this, the positional embeddings for each resolution are calculated in advance and switched according to the resolution of the training data. This feature is enabled by `--enable_scaled_pos_embed`. This idea and the code for calculating scaled positional embeddings are contributed by KohakuBlueleaf. Thanks to KohakuBlueleaf! From 5e32ee26a13394fdee77149c4e96b78c58eabc5e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 2 Nov 2024 15:32:16 +0900 Subject: [PATCH 216/348] fix crashing in DDP training closes #1751 --- sd3_train.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/sd3_train.py b/sd3_train.py index f64e2da2c..e03d1708b 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -838,11 +838,31 @@ def optimizer_hook(parameter: torch.Tensor): accelerator.log({}, step=0) # show model device and dtype - logger.info(f"mmdit device: {mmdit.device}, dtype: {mmdit.dtype}" if mmdit else "mmdit is None") - logger.info(f"clip_l device: {clip_l.device}, dtype: {clip_l.dtype}" if clip_l else "clip_l is None") - logger.info(f"clip_g device: {clip_g.device}, dtype: {clip_g.dtype}" if clip_g else "clip_g is None") - logger.info(f"t5xxl device: {t5xxl.device}, dtype: {t5xxl.dtype}" if t5xxl else "t5xxl is None") - logger.info(f"vae device: {vae.device}, dtype: {vae.dtype}" if vae is not None else "vae is None") + logger.info( + f"mmdit device: {accelerator.unwrap_model(mmdit).device}, dtype: {accelerator.unwrap_model(mmdit).dtype}" + if mmdit + else "mmdit is None" + ) + logger.info( + f"clip_l device: {accelerator.unwrap_model(clip_l).device}, dtype: {accelerator.unwrap_model(clip_l).dtype}" + if clip_l + else "clip_l is None" + ) + logger.info( + f"clip_g device: {accelerator.unwrap_model(clip_g).device}, dtype: {accelerator.unwrap_model(clip_g).dtype}" + if clip_g + else "clip_g is None" + ) + logger.info( + f"t5xxl device: {accelerator.unwrap_model(t5xxl).device}, dtype: {accelerator.unwrap_model(t5xxl).dtype}" + if t5xxl + else "t5xxl is None" + ) + logger.info( + f"vae device: {accelerator.unwrap_model(vae).device}, dtype: {accelerator.unwrap_model(vae).dtype}" + if vae is not None + else "vae is None" + ) loss_recorder = train_util.LossRecorder() epoch = 0 # avoid error when max_train_steps is 0 From 81c0c965a24ce4f0f86dfa980f803d7616ca46d8 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 5 Nov 2024 21:22:42 +0900 Subject: [PATCH 217/348] faster block swap --- flux_train.py | 107 ++++++++++---------- library/flux_models.py | 138 ++++++++++++++----------- library/utils.py | 222 ++++++++++++++++++++++++++++++++++++++++- 3 files changed, 352 insertions(+), 115 deletions(-) diff --git a/flux_train.py b/flux_train.py index 79c44d7b4..afddc897f 100644 --- a/flux_train.py +++ b/flux_train.py @@ -17,12 +17,14 @@ import os from multiprocessing import Value import time -from typing import List +from typing import List, Optional, Tuple, Union import toml from tqdm import tqdm import torch +import torch.nn as nn +from library import utils from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -466,45 +468,28 @@ def train(args): # memory efficient block swapping - def get_block_unit(dbl_blocks, sgl_blocks, index: int): - if index < len(dbl_blocks): - return (dbl_blocks[index],) - else: - index -= len(dbl_blocks) - index *= 2 - return (sgl_blocks[index], sgl_blocks[index + 1]) - - def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, dbl_blocks, sgl_blocks, device): - def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda, dvc): - # print(f"Backward: Move block {bidx_to_cpu} to CPU") - for block in blocks_to_cpu: - block = block.to("cpu", non_blocking=True) - torch.cuda.empty_cache() - - # print(f"Backward: Move block {bidx_to_cuda} to CUDA") - for block in blocks_to_cuda: - block = block.to(dvc, non_blocking=True) - - torch.cuda.synchronize() - # print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda}") - return bidx_to_cpu, bidx_to_cuda - - blocks_to_cpu = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cpu) - blocks_to_cuda = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cuda) - - futures[block_idx_to_cuda] = thread_pool.submit( - move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda, device - ) + def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, blocks, block_id): + def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): + # start_time = time.perf_counter() + # print(f"Backward: Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to CUDA") + utils.swap_weight_devices(block_to_cpu, block_to_cuda) + # print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s") + return bidx_to_cpu, bidx_to_cuda # , event + + block_to_cpu = blocks[block_idx_to_cpu] + block_to_cuda = blocks[block_idx_to_cuda] + + futures[block_id] = thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda) - def wait_blocks_move(block_idx, futures): - if block_idx not in futures: + def wait_blocks_move(block_id, futures): + if block_id not in futures: return - # print(f"Backward: Wait for block {block_idx}") + # print(f"Backward: Wait for block {block_id}") # start_time = time.perf_counter() - future = futures.pop(block_idx) - future.result() - # print(f"Backward: Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s") - # torch.cuda.synchronize() + future = futures.pop(block_id) + _, bidx_to_cuda = future.result() + assert block_id[1] == bidx_to_cuda, f"Block index mismatch: {block_id[1]} != {bidx_to_cuda}" + # print(f"Backward: Waited for block {block_id}: {time.perf_counter()-start_time:.2f}s") # print(f"Backward: Synchronized: {time.perf_counter()-start_time:.2f}s") if args.fused_backward_pass: @@ -513,11 +498,11 @@ def wait_blocks_move(block_idx, futures): library.adafactor_fused.patch_adafactor_fused(optimizer) - blocks_to_swap = args.blocks_to_swap + double_blocks_to_swap = args.blocks_to_swap // 2 + single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2 num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks) num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks) - num_block_units = num_double_blocks + num_single_blocks // 2 - handled_unit_indices = set() + handled_block_ids = set() n = 1 # only asynchronous purpose, no need to increase this number # n = 2 @@ -530,28 +515,37 @@ def wait_blocks_move(block_idx, futures): if parameter.requires_grad: grad_hook = None - if blocks_to_swap: + if double_blocks_to_swap > 0 or single_blocks_to_swap > 0: is_double = param_name.startswith("double_blocks") is_single = param_name.startswith("single_blocks") - if is_double or is_single: + if is_double and double_blocks_to_swap > 0 or is_single and single_blocks_to_swap > 0: block_idx = int(param_name.split(".")[1]) - unit_idx = block_idx if is_double else num_double_blocks + block_idx // 2 - if unit_idx not in handled_unit_indices: + block_id = (is_double, block_idx) # double or single, block index + if block_id not in handled_block_ids: # swap following (already backpropagated) block - handled_unit_indices.add(unit_idx) + handled_block_ids.add(block_id) # if n blocks were already backpropagated - num_blocks_propagated = num_block_units - unit_idx - 1 + if is_double: + num_blocks = num_double_blocks + blocks_to_swap = double_blocks_to_swap + else: + num_blocks = num_single_blocks + blocks_to_swap = single_blocks_to_swap + + # -1 for 0-based index, -1 for current block is not fully backpropagated yet + num_blocks_propagated = num_blocks - block_idx - 2 swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap - waiting = unit_idx > 0 and unit_idx <= blocks_to_swap + waiting = block_idx > 0 and block_idx <= blocks_to_swap + if swapping or waiting: - block_idx_to_cpu = num_block_units - num_blocks_propagated + block_idx_to_cpu = num_blocks - num_blocks_propagated block_idx_to_cuda = blocks_to_swap - num_blocks_propagated - block_idx_to_wait = unit_idx - 1 + block_idx_to_wait = block_idx - 1 # create swap hook def create_swap_grad_hook( - bidx_to_cpu, bidx_to_cuda, bidx_to_wait, uidx: int, swpng: bool, wtng: bool + is_dbl, bidx_to_cpu, bidx_to_cuda, bidx_to_wait, swpng: bool, wtng: bool ): def __grad_hook(tensor: torch.Tensor): if accelerator.sync_gradients and args.max_grad_norm != 0.0: @@ -559,24 +553,25 @@ def __grad_hook(tensor: torch.Tensor): optimizer.step_param(tensor, param_group) tensor.grad = None - # print(f"Backward: {uidx}, {swpng}, {wtng}") + # print( + # f"Backward: Block {is_dbl}, {bidx_to_cpu}, {bidx_to_cuda}, {bidx_to_wait}, {swpng}, {wtng}" + # ) if swpng: submit_move_blocks( futures, thread_pool, bidx_to_cpu, bidx_to_cuda, - flux.double_blocks, - flux.single_blocks, - accelerator.device, + flux.double_blocks if is_dbl else flux.single_blocks, + (is_dbl, bidx_to_cuda), # wait for this block ) if wtng: - wait_blocks_move(bidx_to_wait, futures) + wait_blocks_move((is_dbl, bidx_to_wait), futures) return __grad_hook grad_hook = create_swap_grad_hook( - block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, unit_idx, swapping, waiting + is_double, block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, swapping, waiting ) if grad_hook is None: diff --git a/library/flux_models.py b/library/flux_models.py index 0bc1c02b9..48dea4fc9 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -7,8 +7,9 @@ import math import os import time -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union +from library import utils from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -923,7 +924,8 @@ def __init__(self, params: FluxParams): self.blocks_to_swap = None self.thread_pool: Optional[ThreadPoolExecutor] = None - self.num_block_units = len(self.double_blocks) + len(self.single_blocks) // 2 + self.num_double_blocks = len(self.double_blocks) + self.num_single_blocks = len(self.single_blocks) @property def device(self): @@ -963,14 +965,17 @@ def disable_gradient_checkpointing(self): def enable_block_swap(self, num_blocks: int): self.blocks_to_swap = num_blocks + self.double_blocks_to_swap = num_blocks // 2 + self.single_blocks_to_swap = (num_blocks - self.double_blocks_to_swap) * 2 + print( + f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {self.double_blocks_to_swap}, single blocks: {self.single_blocks_to_swap}." + ) n = 1 # async block swap. 1 is enough - # n = 2 - # n = max(1, os.cpu_count() // 2) self.thread_pool = ThreadPoolExecutor(max_workers=n) def move_to_device_except_swap_blocks(self, device: torch.device): - # assume model is on cpu + # assume model is on cpu. do not move blocks to device to reduce temporary memory usage if self.blocks_to_swap: save_double_blocks = self.double_blocks save_single_blocks = self.single_blocks @@ -983,31 +988,55 @@ def move_to_device_except_swap_blocks(self, device: torch.device): self.double_blocks = save_double_blocks self.single_blocks = save_single_blocks - def get_block_unit(self, index: int): - if index < len(self.double_blocks): - return (self.double_blocks[index],) - else: - index -= len(self.double_blocks) - index *= 2 - return self.single_blocks[index], self.single_blocks[index + 1] + # def get_block_unit(self, index: int): + # if index < len(self.double_blocks): + # return (self.double_blocks[index],) + # else: + # index -= len(self.double_blocks) + # index *= 2 + # return self.single_blocks[index], self.single_blocks[index + 1] - def get_unit_index(self, is_double: bool, index: int): - if is_double: - return index - else: - return len(self.double_blocks) + index // 2 + # def get_unit_index(self, is_double: bool, index: int): + # if is_double: + # return index + # else: + # return len(self.double_blocks) + index // 2 def prepare_block_swap_before_forward(self): - # make: first n blocks are on cuda, and last n blocks are on cpu + # # make: first n blocks are on cuda, and last n blocks are on cpu + # if self.blocks_to_swap is None or self.blocks_to_swap == 0: + # # raise ValueError("Block swap is not enabled.") + # return + # for i in range(self.num_block_units - self.blocks_to_swap): + # for b in self.get_block_unit(i): + # b.to(self.device) + # for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units): + # for b in self.get_block_unit(i): + # b.to("cpu") + # clean_memory_on_device(self.device) + + # all blocks are on device, but some weights are on cpu + # make first n blocks weights on device, and last n blocks weights on cpu if self.blocks_to_swap is None or self.blocks_to_swap == 0: # raise ValueError("Block swap is not enabled.") return - for i in range(self.num_block_units - self.blocks_to_swap): - for b in self.get_block_unit(i): - b.to(self.device) - for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units): - for b in self.get_block_unit(i): - b.to("cpu") + + for b in self.double_blocks[0 : self.num_double_blocks - self.double_blocks_to_swap]: + b.to(self.device) + utils.weighs_to_device(b, self.device) # make sure weights are on device + for b in self.double_blocks[self.num_double_blocks - self.double_blocks_to_swap :]: + b.to(self.device) # move block to device first + utils.weighs_to_device(b, "cpu") # make sure weights are on cpu + torch.cuda.synchronize() + clean_memory_on_device(self.device) + + for b in self.single_blocks[0 : self.num_single_blocks - self.single_blocks_to_swap]: + b.to(self.device) + utils.weighs_to_device(b, self.device) # make sure weights are on device + for b in self.single_blocks[self.num_single_blocks - self.single_blocks_to_swap :]: + b.to(self.device) # move block to device first + utils.weighs_to_device(b, "cpu") # make sure weights are on cpu + torch.cuda.synchronize() clean_memory_on_device(self.device) def forward( @@ -1044,27 +1073,22 @@ def forward( for block in self.single_blocks: img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) else: - futures = {} - - def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda): - def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda): - # print(f"Moving {bidx_to_cpu} to cpu.") - for block in blocks_to_cpu: - block.to("cpu", non_blocking=True) - torch.cuda.empty_cache() + # device = self.device - # print(f"Moving {bidx_to_cuda} to cuda.") - for block in blocks_to_cuda: - block.to(self.device, non_blocking=True) - - torch.cuda.synchronize() + def submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda): + def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): + start_time = time.perf_counter() + # print(f"Moving {bidx_to_cpu} to cpu and {bidx_to_cuda} to cuda.") + utils.swap_weight_devices(block_to_cpu, block_to_cuda) # print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.") - return block_idx_to_cpu, block_idx_to_cuda - blocks_to_cpu = self.get_block_unit(block_idx_to_cpu) - blocks_to_cuda = self.get_block_unit(block_idx_to_cuda) + # print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds") + return block_idx_to_cpu, block_idx_to_cuda # , event + + block_to_cpu = blocks[block_idx_to_cpu] + block_to_cuda = blocks[block_idx_to_cuda] # print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.") - return self.thread_pool.submit(move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda) + return self.thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda) def wait_for_blocks_move(block_idx, ftrs): if block_idx not in ftrs: @@ -1073,37 +1097,35 @@ def wait_for_blocks_move(block_idx, ftrs): # start_time = time.perf_counter() ftr = ftrs.pop(block_idx) ftr.result() - # torch.cuda.synchronize() - # print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds") + # print(f"{block_idx} move blocks took {time.perf_counter() - start_time:.2f} seconds") + double_futures = {} for block_idx, block in enumerate(self.double_blocks): # print(f"Double block {block_idx}") - unit_idx = self.get_unit_index(is_double=True, index=block_idx) - wait_for_blocks_move(unit_idx, futures) + wait_for_blocks_move(block_idx, double_futures) img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if unit_idx < self.blocks_to_swap: - block_idx_to_cpu = unit_idx - block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx - future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda) - futures[block_idx_to_cuda] = future + if block_idx < self.double_blocks_to_swap: + block_idx_to_cpu = block_idx + block_idx_to_cuda = self.num_double_blocks - self.double_blocks_to_swap + block_idx + future = submit_move_blocks(self.double_blocks, block_idx_to_cpu, block_idx_to_cuda) + double_futures[block_idx_to_cuda] = future img = torch.cat((txt, img), 1) + single_futures = {} for block_idx, block in enumerate(self.single_blocks): # print(f"Single block {block_idx}") - unit_idx = self.get_unit_index(is_double=False, index=block_idx) - if block_idx % 2 == 0: - wait_for_blocks_move(unit_idx, futures) + wait_for_blocks_move(block_idx, single_futures) img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if block_idx % 2 == 1 and unit_idx < self.blocks_to_swap: - block_idx_to_cpu = unit_idx - block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx - future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda) - futures[block_idx_to_cuda] = future + if block_idx < self.single_blocks_to_swap: + block_idx_to_cpu = block_idx + block_idx_to_cuda = self.num_single_blocks - self.blocks_to_swap + block_idx + future = submit_move_blocks(self.single_blocks, block_idx_to_cpu, block_idx_to_cuda) + single_futures[block_idx_to_cuda] = future img = img[:, txt.shape[1] :, ...] diff --git a/library/utils.py b/library/utils.py index ca0f904d2..aed510074 100644 --- a/library/utils.py +++ b/library/utils.py @@ -6,6 +6,7 @@ import struct import torch +import torch.nn as nn from torchvision import transforms from diffusers import EulerAncestralDiscreteScheduler import diffusers.schedulers.scheduling_euler_ancestral_discrete @@ -93,6 +94,225 @@ def setup_logging(args=None, log_level=None, reset=False): # region PyTorch utils +# def swap_weights(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): +# assert layer_to_cpu.__class__ == layer_to_cuda.__class__ +# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): +# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: +# # print(f"Swapping {layer_to_cpu.__class__.__name__}-{module_to_cpu.__class__.__name__}.") +# # cpu_tensor = module_to_cuda.weight.data +# # cuda_tensor = module_to_cpu.weight.data +# # assert cuda_tensor.device.type == "cuda" +# # temp_cpu_tensor = cuda_tensor.to("cpu", non_blocking=True) +# # torch.cuda.current_stream().synchronize() +# # cuda_tensor.copy_(cpu_tensor, non_blocking=True) +# # torch.cuda.current_stream().synchronize() +# # cpu_tensor.copy_(temp_cpu_tensor, non_blocking=True) +# # module_to_cpu.weight.data, module_to_cuda.weight.data = cpu_tensor, cuda_tensor +# cuda_tensor_view = module_to_cpu.weight.data +# cpu_tensor_view = module_to_cuda.weight.data +# module_to_cpu.weight.data = module_to_cpu.weight.to("cpu", non_blocking=True).detach().clone() +# module_to_cuda.weight.data = cuda_tensor_view +# module_to_cuda.weight.data.copy_(cpu_tensor_view) + + +def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + # cuda to cpu + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.record_stream(stream) + module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) + + stream.synchronize() + + # cpu to cuda + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + stream.synchronize() + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + +def swap_weight_devices_2st(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + stream_to_cpu = torch.cuda.Stream() + stream_to_cuda = torch.cuda.Stream() + + events = [] + with torch.cuda.stream(stream_to_cpu): + # cuda to offload + offloaded_weights = [] + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + offloaded_weights.append(cuda_data_view.to("cpu", non_blocking=True)) + event = torch.cuda.Event() + event.record(stream=stream_to_cpu) + events.append(event) + + with torch.cuda.stream(stream_to_cuda): + # cpu to cuda + for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), event in zip(weight_swap_jobs, events): + event.synchronize() + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + # offload to cpu + for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), offloaded_weight in zip( + weight_swap_jobs, offloaded_weights + ): + module_to_cpu.weight.data = offloaded_weight + + stream_to_cuda.synchronize() + + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + +def swap_weight_devices_failed(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + stream_to_cpu = torch.cuda.Stream() + stream_to_cuda = torch.cuda.Stream() + + # cuda to offload + events = [] + with torch.cuda.stream(stream_to_cpu): + offloaded_weights = [] + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.record_stream(stream_to_cpu) + offloaded_weights.append(cuda_data_view.to("cpu", non_blocking=True)) + + event = torch.cuda.Event() + event.record(stream=stream_to_cpu) + events.append(event) + + # cpu to cuda + with torch.cuda.stream(stream_to_cuda): + for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), event, offloaded_weight in zip( + weight_swap_jobs, events, offloaded_weights + ): + event.synchronize() + cuda_data_view.record_stream(stream_to_cuda) + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + module_to_cpu.weight.data = offloaded_weight + + stream_to_cuda.synchronize() + + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + # torch.cuda.current_stream().wait_stream(stream_to_cuda) + # for job in weight_swap_jobs: + # job[2].record_stream(torch.cuda.current_stream()) # record the ownership of the tensor + + +def swap_weight_devices_works_2(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + if not (hasattr(module_to_cpu, "offloaded_weight") or hasattr(module_to_cuda, "offloaded_weight")): + # one of the modules must have the tensor to offload + module_to_cpu.offloaded_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu") + module_to_cpu.offloaded_weight.pin_memory() + offloaded_weight = ( + module_to_cpu.offloaded_weight if hasattr(module_to_cpu, "offloaded_weight") else module_to_cuda.offloaded_weight + ) + assert module_to_cpu.weight.device.type == "cuda" and module_to_cuda.weight.device.type == "cpu" + weight_swap_jobs.append( + (module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data, offloaded_weight) + ) + + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + # cuda to offload + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs: + cuda_data_view.record_stream(stream) + offloaded_weight.copy_(module_to_cpu.weight.data, non_blocking=True) + + stream.synchronize() + + # cpu to cuda + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs: + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + # offload to cpu + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs: + module_to_cpu.weight.data = offloaded_weight + offloaded_weight = cpu_data_view + module_to_cpu.offloaded_weight = offloaded_weight + module_to_cuda.offloaded_weight = offloaded_weight + + stream.synchronize() + + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + +def swap_weight_devices_safe_works(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + if not (hasattr(module_to_cpu, "__cached_cpu_weight") or hasattr(module_to_cuda, "__cached_cuda_weight")): + # one of the modules must have the tensor to cache + module_to_cpu.__cached_cpu_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu") + module_to_cpu.__cached_cpu_weight.pin_memory() + + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + for module_to_cpu, module_to_cuda, cuda_tensor_view, cpu_tensor_view in weight_swap_jobs: + module_to_cpu.weight.data = cuda_tensor_view.to("cpu", non_blocking=True) + module_to_cuda.weight.data = cpu_tensor_view.to("cuda", non_blocking=True) + + torch.cuda.current_stream().synchronize() # wait for the copy from cache to cpu to finish + torch.cuda.empty_cache() + + +# def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): +# assert layer_to_cpu.__class__ == layer_to_cuda.__class__ +# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): +# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: +# assert module_to_cuda.weight.device.type == "cpu" and module_to_cpu.weight.device.type == "cuda" +# weight_on_cuda = module_to_cpu.weight +# weight_on_cpu = module_to_cuda.weight +# cuda_to_cpu_data = weight_on_cuda.data.to("cpu", non_blocking=True) +# event = torch.cuda.current_stream().record_event() +# event.synchronize() +# weight_on_cuda.data.copy_(weight_on_cpu.data, non_blocking=True) +# weight_on_cpu.data = cuda_to_cpu_data +# weight_on_cpu.grad, weight_on_cuda.grad = weight_on_cuda.grad, weight_on_cpu.grad + +# module_to_cpu.weight = weight_on_cpu +# module_to_cuda.weight = weight_on_cuda + + +def weighs_to_device(layer: nn.Module, device: torch.device): + for module in layer.modules(): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data = module.weight.data.to(device, non_blocking=True) + def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype: """ @@ -313,6 +533,7 @@ def _convert_float8(byte_tensor, dtype_str, shape): # return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape) raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)") + def load_safetensors( path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32 ) -> dict[str, torch.Tensor]: @@ -336,7 +557,6 @@ def load_safetensors( return state_dict - # endregion # region Image utils From aab943cea3eb8a91041c857771f1642581133608 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 5 Nov 2024 23:27:41 +0900 Subject: [PATCH 218/348] remove unused weight swapping functions from utils.py --- library/utils.py | 185 ----------------------------------------------- 1 file changed, 185 deletions(-) diff --git a/library/utils.py b/library/utils.py index aed510074..07079c6d9 100644 --- a/library/utils.py +++ b/library/utils.py @@ -94,26 +94,6 @@ def setup_logging(args=None, log_level=None, reset=False): # region PyTorch utils -# def swap_weights(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): -# assert layer_to_cpu.__class__ == layer_to_cuda.__class__ -# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): -# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: -# # print(f"Swapping {layer_to_cpu.__class__.__name__}-{module_to_cpu.__class__.__name__}.") -# # cpu_tensor = module_to_cuda.weight.data -# # cuda_tensor = module_to_cpu.weight.data -# # assert cuda_tensor.device.type == "cuda" -# # temp_cpu_tensor = cuda_tensor.to("cpu", non_blocking=True) -# # torch.cuda.current_stream().synchronize() -# # cuda_tensor.copy_(cpu_tensor, non_blocking=True) -# # torch.cuda.current_stream().synchronize() -# # cpu_tensor.copy_(temp_cpu_tensor, non_blocking=True) -# # module_to_cpu.weight.data, module_to_cuda.weight.data = cpu_tensor, cuda_tensor -# cuda_tensor_view = module_to_cpu.weight.data -# cpu_tensor_view = module_to_cuda.weight.data -# module_to_cpu.weight.data = module_to_cpu.weight.to("cpu", non_blocking=True).detach().clone() -# module_to_cuda.weight.data = cuda_tensor_view -# module_to_cuda.weight.data.copy_(cpu_tensor_view) - def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): assert layer_to_cpu.__class__ == layer_to_cuda.__class__ @@ -143,171 +123,6 @@ def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): torch.cuda.current_stream().synchronize() # this prevents the illegal loss value -def swap_weight_devices_2st(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): - assert layer_to_cpu.__class__ == layer_to_cuda.__class__ - - weight_swap_jobs = [] - for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): - if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: - weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) - - stream_to_cpu = torch.cuda.Stream() - stream_to_cuda = torch.cuda.Stream() - - events = [] - with torch.cuda.stream(stream_to_cpu): - # cuda to offload - offloaded_weights = [] - for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: - offloaded_weights.append(cuda_data_view.to("cpu", non_blocking=True)) - event = torch.cuda.Event() - event.record(stream=stream_to_cpu) - events.append(event) - - with torch.cuda.stream(stream_to_cuda): - # cpu to cuda - for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), event in zip(weight_swap_jobs, events): - event.synchronize() - cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) - module_to_cuda.weight.data = cuda_data_view - - # offload to cpu - for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), offloaded_weight in zip( - weight_swap_jobs, offloaded_weights - ): - module_to_cpu.weight.data = offloaded_weight - - stream_to_cuda.synchronize() - - torch.cuda.current_stream().synchronize() # this prevents the illegal loss value - - -def swap_weight_devices_failed(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): - assert layer_to_cpu.__class__ == layer_to_cuda.__class__ - - weight_swap_jobs = [] - for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): - if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: - weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) - - stream_to_cpu = torch.cuda.Stream() - stream_to_cuda = torch.cuda.Stream() - - # cuda to offload - events = [] - with torch.cuda.stream(stream_to_cpu): - offloaded_weights = [] - for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: - cuda_data_view.record_stream(stream_to_cpu) - offloaded_weights.append(cuda_data_view.to("cpu", non_blocking=True)) - - event = torch.cuda.Event() - event.record(stream=stream_to_cpu) - events.append(event) - - # cpu to cuda - with torch.cuda.stream(stream_to_cuda): - for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), event, offloaded_weight in zip( - weight_swap_jobs, events, offloaded_weights - ): - event.synchronize() - cuda_data_view.record_stream(stream_to_cuda) - cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) - module_to_cuda.weight.data = cuda_data_view - - module_to_cpu.weight.data = offloaded_weight - - stream_to_cuda.synchronize() - - torch.cuda.current_stream().synchronize() # this prevents the illegal loss value - # torch.cuda.current_stream().wait_stream(stream_to_cuda) - # for job in weight_swap_jobs: - # job[2].record_stream(torch.cuda.current_stream()) # record the ownership of the tensor - - -def swap_weight_devices_works_2(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): - assert layer_to_cpu.__class__ == layer_to_cuda.__class__ - - weight_swap_jobs = [] - for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): - if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: - if not (hasattr(module_to_cpu, "offloaded_weight") or hasattr(module_to_cuda, "offloaded_weight")): - # one of the modules must have the tensor to offload - module_to_cpu.offloaded_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu") - module_to_cpu.offloaded_weight.pin_memory() - offloaded_weight = ( - module_to_cpu.offloaded_weight if hasattr(module_to_cpu, "offloaded_weight") else module_to_cuda.offloaded_weight - ) - assert module_to_cpu.weight.device.type == "cuda" and module_to_cuda.weight.device.type == "cpu" - weight_swap_jobs.append( - (module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data, offloaded_weight) - ) - - stream = torch.cuda.Stream() - with torch.cuda.stream(stream): - # cuda to offload - for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs: - cuda_data_view.record_stream(stream) - offloaded_weight.copy_(module_to_cpu.weight.data, non_blocking=True) - - stream.synchronize() - - # cpu to cuda - for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs: - cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) - module_to_cuda.weight.data = cuda_data_view - - # offload to cpu - for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs: - module_to_cpu.weight.data = offloaded_weight - offloaded_weight = cpu_data_view - module_to_cpu.offloaded_weight = offloaded_weight - module_to_cuda.offloaded_weight = offloaded_weight - - stream.synchronize() - - torch.cuda.current_stream().synchronize() # this prevents the illegal loss value - - -def swap_weight_devices_safe_works(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): - assert layer_to_cpu.__class__ == layer_to_cuda.__class__ - - weight_swap_jobs = [] - for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): - if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: - if not (hasattr(module_to_cpu, "__cached_cpu_weight") or hasattr(module_to_cuda, "__cached_cuda_weight")): - # one of the modules must have the tensor to cache - module_to_cpu.__cached_cpu_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu") - module_to_cpu.__cached_cpu_weight.pin_memory() - - weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) - - for module_to_cpu, module_to_cuda, cuda_tensor_view, cpu_tensor_view in weight_swap_jobs: - module_to_cpu.weight.data = cuda_tensor_view.to("cpu", non_blocking=True) - module_to_cuda.weight.data = cpu_tensor_view.to("cuda", non_blocking=True) - - torch.cuda.current_stream().synchronize() # wait for the copy from cache to cpu to finish - torch.cuda.empty_cache() - - -# def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): -# assert layer_to_cpu.__class__ == layer_to_cuda.__class__ -# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): -# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: -# assert module_to_cuda.weight.device.type == "cpu" and module_to_cpu.weight.device.type == "cuda" -# weight_on_cuda = module_to_cpu.weight -# weight_on_cpu = module_to_cuda.weight -# cuda_to_cpu_data = weight_on_cuda.data.to("cpu", non_blocking=True) -# event = torch.cuda.current_stream().record_event() -# event.synchronize() -# weight_on_cuda.data.copy_(weight_on_cpu.data, non_blocking=True) -# weight_on_cpu.data = cuda_to_cpu_data -# weight_on_cpu.grad, weight_on_cuda.grad = weight_on_cuda.grad, weight_on_cpu.grad - -# module_to_cpu.weight = weight_on_cpu -# module_to_cuda.weight = weight_on_cuda - - def weighs_to_device(layer: nn.Module, device: torch.device): for module in layer.modules(): if hasattr(module, "weight") and module.weight is not None: From 43849030cf35a7c854311e0bee9cb8a92b77dd83 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 6 Nov 2024 21:33:28 +0900 Subject: [PATCH 219/348] Fix to work without latent cache #1758 --- sd3_train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sd3_train.py b/sd3_train.py index e03d1708b..b8a0d04fa 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -885,7 +885,9 @@ def optimizer_hook(parameter: torch.Tensor): else: with torch.no_grad(): # encode images to latents. images are [-1, 1] - latents = vae.encode(batch["images"]) + latents = vae.encode(batch["images"].to(vae.device, dtype=vae.dtype)).to( + accelerator.device, dtype=weight_dtype + ) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): @@ -927,7 +929,7 @@ def optimizer_hook(parameter: torch.Tensor): if t5_out is None: _, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"] with torch.set_grad_enabled(train_t5xxl): - input_ids_t5xxl = input_ids_t5xxl.to("cpu") if t5_out is None else None + input_ids_t5xxl = input_ids_t5xxl.to("cpu") _, t5_out, _, _, _, t5_attn_mask = text_encoding_strategy.encode_tokens( sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask] ) From 40ed54bfc0ca666c45a4a5d4b7a3064612371005 Mon Sep 17 00:00:00 2001 From: Dango233 Date: Thu, 7 Nov 2024 09:53:54 +0000 Subject: [PATCH 220/348] Simplify Timestep weighting * Remove diffusers dependency in ts & sigma calc * support Shift setting * Add uniform distribution * Default to Uniform distribution and shift 1 --- library/sd3_train_utils.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 69878750e..bfe752d5e 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -253,12 +253,12 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): " / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります", ) - # copy from Diffusers + # Dependencies of Diffusers noise sampler has been removed for clearity. parser.add_argument( "--weighting_scheme", type=str, - default="logit_normal", - choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"], + default="uniform", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "uniform"], help="weighting scheme for timestep distribution and loss / タイムステップ分布と損失のための重み付けスキーム", ) parser.add_argument( @@ -279,8 +279,13 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): default=1.29, help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`. / モード重み付けスキームのスケール。`'mode'`を`weighting_scheme`として使用する場合のみ有効", ) - - + parser.add_argument( + "--training_shift", + type=float, + default=1.0, + help="Discrete flow shift for training timestep distribution adjustment, applied in addition to the weighting scheme, default is 1.0. /タイムステップ分布のための離散フローシフト、重み付けスキームの上に適用される、デフォルトは1.0。", + ) + def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" if args.v_parameterization: @@ -965,14 +970,20 @@ def get_noisy_model_input_and_timesteps( logit_std=args.logit_std, mode_scale=args.mode_scale, ) - indices = (u * noise_scheduler.config.num_train_timesteps).long() - timesteps = noise_scheduler.timesteps[indices].to(device=device) + t_min = args.min_timestep if args.min_timestep is not None else 0 + t_max = args.max_timestep if args.max_timestep is not None else 1000 + shift = args.training_shift + + # weighting shift, value >1 will shift distribution to noisy side (focus more on overall structure), value <1 will shift towards less-noisy side (focus more on details) + u = (u * shift) / (1 + (shift - 1) * u) - # Add noise according to flow matching. - sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) + indices = (u * (t_max-t_min) + t_min).long() + timesteps = indices.to(device=device, dtype=dtype) + + # sigmas according to dlowmatching + sigmas = timesteps / 1000 + sigmas = sigmas.view(-1,1,1,1) noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents return noisy_model_input, timesteps, sigmas - -# endregion From e54462a4a9cb3d01c5635f8c191d28cbccfba6e0 Mon Sep 17 00:00:00 2001 From: Dango233 Date: Thu, 7 Nov 2024 09:54:12 +0000 Subject: [PATCH 221/348] Fix SD3 trained lora loading and merging --- networks/lora_sd3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/networks/lora_sd3.py b/networks/lora_sd3.py index efe202451..ce6d1a16f 100644 --- a/networks/lora_sd3.py +++ b/networks/lora_sd3.py @@ -601,7 +601,7 @@ def merge_to(self, text_encoders, mmdit, weights_sd, dtype=None, device=None): or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5) ): apply_text_encoder = True - elif key.startswith(LoRANetwork.LORA_PREFIX_MMDIT): + elif key.startswith(LoRANetwork.LORA_PREFIX_SD3): apply_unet = True if apply_text_encoder: From bafd10d558bf318ccd7059c2b4dce2775b5758da Mon Sep 17 00:00:00 2001 From: Dango233 Date: Thu, 7 Nov 2024 18:21:04 +0800 Subject: [PATCH 222/348] Fix typo --- library/sd3_train_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index bfe752d5e..afbe34cf5 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -980,7 +980,7 @@ def get_noisy_model_input_and_timesteps( indices = (u * (t_max-t_min) + t_min).long() timesteps = indices.to(device=device, dtype=dtype) - # sigmas according to dlowmatching + # sigmas according to flowmatching sigmas = timesteps / 1000 sigmas = sigmas.view(-1,1,1,1) noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents From 5e86323f12178605c0b99bc914b4bd970900ce75 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 7 Nov 2024 21:27:12 +0900 Subject: [PATCH 223/348] Update README and clean-up the code for SD3 timesteps --- README.md | 13 ++++++++++++- library/config_util.py | 2 +- library/sd3_models.py | 2 +- library/sd3_train_utils.py | 17 +++++++++-------- sd3_train.py | 8 ++++---- sd3_train_network.py | 7 +++---- 6 files changed, 30 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index fb087c234..dba76a3c5 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,13 @@ The command to install PyTorch is as follows: ### Recent Updates +Nov 7, 2024: + +- The distribution of timesteps during SD3/3.5 training has been adjusted. This applies to both fine-tuning and LoRA training. PR [#1768](https://github.com/kohya-ss/sd-scripts/pull/1768) Thanks to Dango233! + - Previously, the side closer to noise was more sampled, but now it is uniform by default. This may improve the problem of difficulty in learning details. + - Specifically, the problem of double shifting has been fixed. The default for `--weighting_scheme` has been changed to `uniform` (the previous default was `logit_normal`). + - A new option `--training_shift` has been added. The default is 1.0, and all timesteps are sampled uniformly. If less than 1.0, the side closer to the image is more sampled, and if more than 1.0, the side closer to noise is more sampled. + Oct 31, 2024: - Added support for SD3.5L/M training. See [SD3 training](#sd3-training) for details. @@ -641,6 +648,7 @@ Here are the arguments. The arguments and sample settings are still experimental - `--clip_l_dropout_rate`, `--clip_g_dropout_rate` and `--t5_dropout_rate` are the dropout rates for the embeddings of CLIP-L, CLIP-G, and T5XXL, described in [SAI research papre](http://arxiv.org/pdf/2403.03206). The default is 0.0. For LoRA training, it is seems to be better to set 0.0. - `--pos_emb_random_crop_rate` is the rate of random cropping of positional embeddings, described in [SD3.5M model card](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium). The default is 0. It is seems to be better to set 0.0 for LoRA training. - `--enable_scaled_pos_embed` is to enable the scaled positional embeddings. The default is False. This option is an experimental feature for SD3.5M. Details are described below. +- `--training_shift` is the shift value for the training distribution of timesteps. The default is 1.0 (uniform distribution, no shift). If less than 1.0, the side closer to the image is more sampled, and if more than 1.0, the side closer to noise is more sampled. Other options are described below. @@ -681,7 +689,10 @@ Other options are described below. - Same as FLUX.1 for data preparation. - If you train with multiple resolutions, you can enable the scaled positional embeddings with `--enable_scaled_pos_embed`. The default is False. __This option is an experimental feature.__ - +6. Weighting scheme and training shift: + - The weighting scheme is described in the section 3.1 of the [SD3 paper](https://arxiv.org/abs/2403.03206v1). + - The uniform distribution is the default. If you want to change the distribution, see `--help` for options. + - `--training_shift` is the shift value for the training distribution of timesteps. Technical details of multi-resolution training for SD3.5M: diff --git a/library/config_util.py b/library/config_util.py index fc1fbf46d..12d0be173 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -526,7 +526,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu secondary_separator: {subset.secondary_separator} enable_wildcard: {subset.enable_wildcard} caption_dropout_rate: {subset.caption_dropout_rate} - caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} + caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs} caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} caption_prefix: {subset.caption_prefix} caption_suffix: {subset.caption_suffix} diff --git a/library/sd3_models.py b/library/sd3_models.py index b09a57dbd..89225fe4d 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -871,7 +871,7 @@ def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Opti # remove pos_embed to free up memory up to 0.4 GB self.pos_embed = None - # remove duplcates and sort latent sizes in ascending order + # remove duplicates and sort latent sizes in ascending order latent_sizes = list(set(latent_sizes)) latent_sizes = sorted(latent_sizes) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index afbe34cf5..38f3c25f4 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -253,7 +253,7 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): " / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります", ) - # Dependencies of Diffusers noise sampler has been removed for clearity. + # Dependencies of Diffusers noise sampler has been removed for clarity. parser.add_argument( "--weighting_scheme", type=str, @@ -285,7 +285,8 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): default=1.0, help="Discrete flow shift for training timestep distribution adjustment, applied in addition to the weighting scheme, default is 1.0. /タイムステップ分布のための離散フローシフト、重み付けスキームの上に適用される、デフォルトは1.0。", ) - + + def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" if args.v_parameterization: @@ -956,9 +957,10 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): return weighting -def get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents, noise, device, dtype -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +# endregion + + +def get_noisy_model_input_and_timesteps(args, latents, noise, device, dtype) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: bsz = latents.shape[0] # Sample a random timestep for each image @@ -977,13 +979,12 @@ def get_noisy_model_input_and_timesteps( # weighting shift, value >1 will shift distribution to noisy side (focus more on overall structure), value <1 will shift towards less-noisy side (focus more on details) u = (u * shift) / (1 + (shift - 1) * u) - indices = (u * (t_max-t_min) + t_min).long() + indices = (u * (t_max - t_min) + t_min).long() timesteps = indices.to(device=device, dtype=dtype) # sigmas according to flowmatching sigmas = timesteps / 1000 - sigmas = sigmas.view(-1,1,1,1) + sigmas = sigmas.view(-1, 1, 1, 1) noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents return noisy_model_input, timesteps, sigmas - diff --git a/sd3_train.py b/sd3_train.py index b8a0d04fa..24ecbfb7d 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -811,8 +811,8 @@ def optimizer_hook(parameter: torch.Tensor): progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 - noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) - noise_scheduler_copy = copy.deepcopy(noise_scheduler) + # noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) + # noise_scheduler_copy = copy.deepcopy(noise_scheduler) if accelerator.is_main_process: init_kwargs = {} @@ -940,11 +940,11 @@ def optimizer_hook(parameter: torch.Tensor): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) - bsz = latents.shape[0] + # bsz = latents.shape[0] # get noisy model input and timesteps noisy_model_input, timesteps, sigmas = sd3_train_utils.get_noisy_model_input_and_timesteps( - args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype + args, latents, noise, accelerator.device, weight_dtype ) # debug: NaN check for all inputs diff --git a/sd3_train_network.py b/sd3_train_network.py index 0739e094d..bb02c7ac7 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -275,9 +275,8 @@ def sample_images(self, accelerator, args, epoch, global_step, device, vae, toke ) def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: - # shift 3.0 is the default value - noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) - self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) + # this scheduler is not used in training, but used to get num_train_timesteps etc. + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.training_shift) return noise_scheduler def encode_images_to_latents(self, args, accelerator, vae, images): @@ -304,7 +303,7 @@ def get_noise_pred_and_target( # get noisy model input and timesteps noisy_model_input, timesteps, sigmas = sd3_train_utils.get_noisy_model_input_and_timesteps( - args, self.noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype + args, latents, noise, accelerator.device, weight_dtype ) # ensure the hidden state will require grad From f264f4091f734b4e4011257b8571ef97315a1343 Mon Sep 17 00:00:00 2001 From: "Kohya S." <52813779+kohya-ss@users.noreply.github.com> Date: Thu, 7 Nov 2024 21:30:31 +0900 Subject: [PATCH 224/348] Update README.md --- README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index dba76a3c5..9273fc8fb 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,8 @@ Nov 7, 2024: - Previously, the side closer to noise was more sampled, but now it is uniform by default. This may improve the problem of difficulty in learning details. - Specifically, the problem of double shifting has been fixed. The default for `--weighting_scheme` has been changed to `uniform` (the previous default was `logit_normal`). - A new option `--training_shift` has been added. The default is 1.0, and all timesteps are sampled uniformly. If less than 1.0, the side closer to the image is more sampled, and if more than 1.0, the side closer to noise is more sampled. + - The effect of a shift in uniform distribution is shown in the figure below. + - ![Figure_1](https://github.com/user-attachments/assets/99a72c67-adfb-4440-81d4-a718985ff350) Oct 31, 2024: @@ -693,7 +695,8 @@ Other options are described below. - The weighting scheme is described in the section 3.1 of the [SD3 paper](https://arxiv.org/abs/2403.03206v1). - The uniform distribution is the default. If you want to change the distribution, see `--help` for options. - `--training_shift` is the shift value for the training distribution of timesteps. - + - The effect of a shift in uniform distribution is shown in the figure below. + - ![Figure_1](https://github.com/user-attachments/assets/99a72c67-adfb-4440-81d4-a718985ff350) Technical details of multi-resolution training for SD3.5M: From 5eb6d209d5b28d43bf611e0934297703eb041d07 Mon Sep 17 00:00:00 2001 From: Dango233 Date: Thu, 7 Nov 2024 20:33:31 +0800 Subject: [PATCH 225/348] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9273fc8fb..fe7c506cb 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ Nov 7, 2024: - The distribution of timesteps during SD3/3.5 training has been adjusted. This applies to both fine-tuning and LoRA training. PR [#1768](https://github.com/kohya-ss/sd-scripts/pull/1768) Thanks to Dango233! - Previously, the side closer to noise was more sampled, but now it is uniform by default. This may improve the problem of difficulty in learning details. - Specifically, the problem of double shifting has been fixed. The default for `--weighting_scheme` has been changed to `uniform` (the previous default was `logit_normal`). - - A new option `--training_shift` has been added. The default is 1.0, and all timesteps are sampled uniformly. If less than 1.0, the side closer to the image is more sampled, and if more than 1.0, the side closer to noise is more sampled. + - A new option `--training_shift` has been added. The default is 1.0, and all timesteps are sampled uniformly. If less than 1.0, the side closer to the image is more sampled (training more on image details), and if more than 1.0, the side closer to noise is more sampled (training more on overall structure). - The effect of a shift in uniform distribution is shown in the figure below. - ![Figure_1](https://github.com/user-attachments/assets/99a72c67-adfb-4440-81d4-a718985ff350) From 186aa5b97d43700706bd8e986e2d5ac3f5d4c9b7 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 7 Nov 2024 22:16:05 +0900 Subject: [PATCH 226/348] fix illeagal block is swapped #1764 --- library/flux_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/flux_models.py b/library/flux_models.py index 48dea4fc9..4721fa02e 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1077,7 +1077,7 @@ def forward( def submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda): def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): - start_time = time.perf_counter() + # start_time = time.perf_counter() # print(f"Moving {bidx_to_cpu} to cpu and {bidx_to_cuda} to cuda.") utils.swap_weight_devices(block_to_cpu, block_to_cuda) # print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.") @@ -1123,7 +1123,7 @@ def wait_for_blocks_move(block_idx, ftrs): if block_idx < self.single_blocks_to_swap: block_idx_to_cpu = block_idx - block_idx_to_cuda = self.num_single_blocks - self.blocks_to_swap + block_idx + block_idx_to_cuda = self.num_single_blocks - self.single_blocks_to_swap + block_idx future = submit_move_blocks(self.single_blocks, block_idx_to_cpu, block_idx_to_cuda) single_futures[block_idx_to_cuda] = future From b3248a8eefe066e6502b535a19501363ec352974 Mon Sep 17 00:00:00 2001 From: feffy380 <114889020+feffy380@users.noreply.github.com> Date: Thu, 7 Nov 2024 14:31:05 +0100 Subject: [PATCH 227/348] fix: sort order when getting image size from cache file --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 18d3cf6c2..8b5cf214e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1887,7 +1887,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset): # make image path to npz path mapping npz_paths = glob.glob(os.path.join(subset.image_dir, "*" + strategy.cache_suffix)) - npz_paths.sort() + npz_paths.sort(key=lambda item: item.rsplit("_", maxsplit=2)[0]) # sort by name excluding resolution and cache_suffix npz_path_index = 0 size_set_count = 0 From 8fac3c3b088699f607392694beee76bc0036c8d9 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 9 Nov 2024 19:56:02 +0900 Subject: [PATCH 228/348] update README --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 87c810012..14328607e 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,10 @@ The command to install PyTorch is as follows: ### Recent Updates +Nov 9, 2024: + +- Fixed an issue where the image size could not be obtained when caching latents was enabled and a specific file name existed, causing the latent size to be incorrect. See PR [#1770](https://github.com/kohya-ss/sd-scripts/pull/1770) for details. Thanks to feffy380! + Nov 7, 2024: - The distribution of timesteps during SD3/3.5 training has been adjusted. This applies to both fine-tuning and LoRA training. PR [#1768](https://github.com/kohya-ss/sd-scripts/pull/1768) Thanks to Dango233! From 26bd4540a6cc7e62100f4901507d8fa0c5a7f78b Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Mon, 11 Nov 2024 09:25:28 +0800 Subject: [PATCH 229/348] init --- library/train_util.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 8b5cf214e..7f396d36e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1405,11 +1405,11 @@ def get_image_size(self, image_path): image_size = imagesize.get(image_path) if image_size[0] <= 0: # imagesize doesn't work for some images, so use cv2 - img = cv2.imread(image_path) - if img is not None: - image_size = (img.shape[1], img.shape[0]) - else: - logger.warning(f"failed to get image size: {image_path}") + try: + with Image.open(image_path) as img: + image_size = img.size + except Exception as e: + logger.warning(f"failed to get image size: {image_path}, error: {e}") image_size = (0, 0) return image_size From 02bd76e6c719ad85c108a177405846c5c958bd78 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 11 Nov 2024 21:15:36 +0900 Subject: [PATCH 230/348] Refactor block swapping to utilize custom offloading utilities --- flux_train.py | 228 ++++++++--------------------- library/custom_offloading_utils.py | 216 +++++++++++++++++++++++++++ library/flux_models.py | 113 ++------------ 3 files changed, 295 insertions(+), 262 deletions(-) create mode 100644 library/custom_offloading_utils.py diff --git a/flux_train.py b/flux_train.py index afddc897f..02dede45e 100644 --- a/flux_train.py +++ b/flux_train.py @@ -295,7 +295,7 @@ def train(args): # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. # This idea is based on 2kpr's great work. Thank you! logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") - flux.enable_block_swap(args.blocks_to_swap) + flux.enable_block_swap(args.blocks_to_swap, accelerator.device) if not cache_latents: # load VAE here if not cached @@ -338,15 +338,15 @@ def train(args): # determine target layer and block index for each parameter block_type = "other" # double, single or other if np[0].startswith("double_blocks"): - block_idx = int(np[0].split(".")[1]) + block_index = int(np[0].split(".")[1]) block_type = "double" elif np[0].startswith("single_blocks"): - block_idx = int(np[0].split(".")[1]) + block_index = int(np[0].split(".")[1]) block_type = "single" else: - block_idx = -1 + block_index = -1 - param_group_key = (block_type, block_idx) + param_group_key = (block_type, block_index) if param_group_key not in param_group: param_group[param_group_key] = [] param_group[param_group_key].append(p) @@ -466,123 +466,21 @@ def train(args): # resumeする train_util.resume_from_local_or_hf_if_specified(accelerator, args) - # memory efficient block swapping - - def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, blocks, block_id): - def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): - # start_time = time.perf_counter() - # print(f"Backward: Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to CUDA") - utils.swap_weight_devices(block_to_cpu, block_to_cuda) - # print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s") - return bidx_to_cpu, bidx_to_cuda # , event - - block_to_cpu = blocks[block_idx_to_cpu] - block_to_cuda = blocks[block_idx_to_cuda] - - futures[block_id] = thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda) - - def wait_blocks_move(block_id, futures): - if block_id not in futures: - return - # print(f"Backward: Wait for block {block_id}") - # start_time = time.perf_counter() - future = futures.pop(block_id) - _, bidx_to_cuda = future.result() - assert block_id[1] == bidx_to_cuda, f"Block index mismatch: {block_id[1]} != {bidx_to_cuda}" - # print(f"Backward: Waited for block {block_id}: {time.perf_counter()-start_time:.2f}s") - # print(f"Backward: Synchronized: {time.perf_counter()-start_time:.2f}s") - if args.fused_backward_pass: # use fused optimizer for backward pass: other optimizers will be supported in the future import library.adafactor_fused library.adafactor_fused.patch_adafactor_fused(optimizer) - double_blocks_to_swap = args.blocks_to_swap // 2 - single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2 - num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks) - num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks) - handled_block_ids = set() - - n = 1 # only asynchronous purpose, no need to increase this number - # n = 2 - # n = max(1, os.cpu_count() // 2) - thread_pool = ThreadPoolExecutor(max_workers=n) - futures = {} - for param_group, param_name_group in zip(optimizer.param_groups, param_names): for parameter, param_name in zip(param_group["params"], param_name_group): if parameter.requires_grad: - grad_hook = None - - if double_blocks_to_swap > 0 or single_blocks_to_swap > 0: - is_double = param_name.startswith("double_blocks") - is_single = param_name.startswith("single_blocks") - if is_double and double_blocks_to_swap > 0 or is_single and single_blocks_to_swap > 0: - block_idx = int(param_name.split(".")[1]) - block_id = (is_double, block_idx) # double or single, block index - if block_id not in handled_block_ids: - # swap following (already backpropagated) block - handled_block_ids.add(block_id) - - # if n blocks were already backpropagated - if is_double: - num_blocks = num_double_blocks - blocks_to_swap = double_blocks_to_swap - else: - num_blocks = num_single_blocks - blocks_to_swap = single_blocks_to_swap - - # -1 for 0-based index, -1 for current block is not fully backpropagated yet - num_blocks_propagated = num_blocks - block_idx - 2 - swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap - waiting = block_idx > 0 and block_idx <= blocks_to_swap - - if swapping or waiting: - block_idx_to_cpu = num_blocks - num_blocks_propagated - block_idx_to_cuda = blocks_to_swap - num_blocks_propagated - block_idx_to_wait = block_idx - 1 - - # create swap hook - def create_swap_grad_hook( - is_dbl, bidx_to_cpu, bidx_to_cuda, bidx_to_wait, swpng: bool, wtng: bool - ): - def __grad_hook(tensor: torch.Tensor): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None - - # print( - # f"Backward: Block {is_dbl}, {bidx_to_cpu}, {bidx_to_cuda}, {bidx_to_wait}, {swpng}, {wtng}" - # ) - if swpng: - submit_move_blocks( - futures, - thread_pool, - bidx_to_cpu, - bidx_to_cuda, - flux.double_blocks if is_dbl else flux.single_blocks, - (is_dbl, bidx_to_cuda), # wait for this block - ) - if wtng: - wait_blocks_move((is_dbl, bidx_to_wait), futures) - - return __grad_hook - - grad_hook = create_swap_grad_hook( - is_double, block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, swapping, waiting - ) - - if grad_hook is None: - - def __grad_hook(tensor: torch.Tensor, param_group=param_group): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None - grad_hook = __grad_hook + def grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None parameter.register_post_accumulate_grad_hook(grad_hook) @@ -601,66 +499,66 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): num_parameters_per_group = [0] * len(optimizers) parameter_optimizer_map = {} - blocks_to_swap = args.blocks_to_swap - num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks) - num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks) - num_block_units = num_double_blocks + num_single_blocks // 2 - - n = 1 # only asynchronous purpose, no need to increase this number - # n = max(1, os.cpu_count() // 2) - thread_pool = ThreadPoolExecutor(max_workers=n) - futures = {} - for opt_idx, optimizer in enumerate(optimizers): for param_group in optimizer.param_groups: for parameter in param_group["params"]: if parameter.requires_grad: - block_type, block_idx = block_types_and_indices[opt_idx] - - def create_optimizer_hook(btype, bidx): - def optimizer_hook(parameter: torch.Tensor): - # print(f"optimizer_hook: {btype}, {bidx}") - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(parameter, args.max_grad_norm) - - i = parameter_optimizer_map[parameter] - optimizer_hooked_count[i] += 1 - if optimizer_hooked_count[i] == num_parameters_per_group[i]: - optimizers[i].step() - optimizers[i].zero_grad(set_to_none=True) - - # swap blocks if necessary - if blocks_to_swap and (btype == "double" or (btype == "single" and bidx % 2 == 0)): - unit_idx = bidx if btype == "double" else num_double_blocks + bidx // 2 - num_blocks_propagated = num_block_units - unit_idx - - swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap - waiting = unit_idx > 0 and unit_idx <= blocks_to_swap - - if swapping: - block_idx_to_cpu = num_block_units - num_blocks_propagated - block_idx_to_cuda = blocks_to_swap - num_blocks_propagated - # print(f"Backward: Swap blocks {block_idx_to_cpu} and {block_idx_to_cuda}") - submit_move_blocks( - futures, - thread_pool, - block_idx_to_cpu, - block_idx_to_cuda, - flux.double_blocks, - flux.single_blocks, - accelerator.device, - ) - - if waiting: - block_idx_to_wait = unit_idx - 1 - wait_blocks_move(block_idx_to_wait, futures) - - return optimizer_hook - - parameter.register_post_accumulate_grad_hook(create_optimizer_hook(block_type, block_idx)) + + def grad_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(grad_hook) parameter_optimizer_map[parameter] = opt_idx num_parameters_per_group[opt_idx] += 1 + # add hooks for block swapping: this hook is called after fused_backward_pass hook or blockwise_fused_optimizers hook + if is_swapping_blocks: + import library.custom_offloading_utils as custom_offloading_utils + + num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks) + num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks) + double_blocks_to_swap = args.blocks_to_swap // 2 + single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2 + + offloader_double = custom_offloading_utils.TrainOffloader(num_double_blocks, double_blocks_to_swap, accelerator.device) + offloader_single = custom_offloading_utils.TrainOffloader(num_single_blocks, single_blocks_to_swap, accelerator.device) + + param_name_pairs = [] + if not args.blockwise_fused_optimizers: + for param_group, param_name_group in zip(optimizer.param_groups, param_names): + param_name_pairs.extend(zip(param_group["params"], param_name_group)) + else: + # named_parameters is a list of (name, parameter) pairs + param_name_pairs.extend([(p, n) for n, p in flux.named_parameters()]) + + for parameter, param_name in param_name_pairs: + if not parameter.requires_grad: + continue + + is_double = param_name.startswith("double_blocks") + is_single = param_name.startswith("single_blocks") + if not is_double and not is_single: + continue + + block_index = int(param_name.split(".")[1]) + if is_double: + blocks = flux.double_blocks + offloader = offloader_double + else: + blocks = flux.single_blocks + offloader = offloader_single + + grad_hook = offloader.create_grad_hook(blocks, block_index) + if grad_hook is not None: + parameter.register_post_accumulate_grad_hook(grad_hook) + # epoch数を計算する num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) diff --git a/library/custom_offloading_utils.py b/library/custom_offloading_utils.py new file mode 100644 index 000000000..33a413004 --- /dev/null +++ b/library/custom_offloading_utils.py @@ -0,0 +1,216 @@ +from concurrent.futures import ThreadPoolExecutor +import time +from typing import Optional +import torch +import torch.nn as nn + +from library.device_utils import clean_memory_on_device + + +def synchronize_device(device: torch.device): + if device.type == "cuda": + torch.cuda.synchronize() + elif device.type == "xpu": + torch.xpu.synchronize() + elif device.type == "mps": + torch.mps.synchronize() + + +def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + # cuda to cpu + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.record_stream(stream) + module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) + + stream.synchronize() + + # cpu to cuda + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + stream.synchronize() + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + +def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + """ + not tested + """ + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + # device to cpu + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) + + synchronize_device() + + # cpu to device + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + synchronize_device() + + +def weighs_to_device(layer: nn.Module, device: torch.device): + for module in layer.modules(): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data = module.weight.data.to(device, non_blocking=True) + + +class Offloader: + """ + common offloading class + """ + + def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): + self.num_blocks = num_blocks + self.blocks_to_swap = blocks_to_swap + self.device = device + self.debug = debug + + self.thread_pool = ThreadPoolExecutor(max_workers=1) + self.futures = {} + self.cuda_available = device.type == "cuda" + + def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module): + if self.cuda_available: + swap_weight_devices(block_to_cpu, block_to_cuda) + else: + swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda) + + def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda): + def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): + if self.debug: + start_time = time.perf_counter() + print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}") + + self.swap_weight_devices(block_to_cpu, block_to_cuda) + + if self.debug: + print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s") + return bidx_to_cpu, bidx_to_cuda # , event + + block_to_cpu = blocks[block_idx_to_cpu] + block_to_cuda = blocks[block_idx_to_cuda] + + self.futures[block_idx_to_cuda] = self.thread_pool.submit( + move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda + ) + + def _wait_blocks_move(self, block_idx): + if block_idx not in self.futures: + return + + if self.debug: + print(f"Wait for block {block_idx}") + start_time = time.perf_counter() + + future = self.futures.pop(block_idx) + _, bidx_to_cuda = future.result() + + assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}" + + if self.debug: + print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s") + + +class TrainOffloader(Offloader): + """ + supports backward offloading + """ + + def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): + super().__init__(num_blocks, blocks_to_swap, device, debug) + self.hook_added = set() + + def create_grad_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]: + if block_index in self.hook_added: + return None + self.hook_added.add(block_index) + + # -1 for 0-based index, -1 for current block is not fully backpropagated yet + num_blocks_propagated = self.num_blocks - block_index - 2 + swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap + waiting = block_index > 0 and block_index <= self.blocks_to_swap + + if not swapping and not waiting: + return None + + # create hook + block_idx_to_cpu = self.num_blocks - num_blocks_propagated + block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated + block_idx_to_wait = block_index - 1 + + if self.debug: + print( + f"Backward: Created grad hook for block {block_index} with {block_idx_to_cpu}, {block_idx_to_cuda}, {block_idx_to_wait}" + ) + if swapping: + + def grad_hook(tensor: torch.Tensor): + self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda) + + return grad_hook + + else: + + def grad_hook(tensor: torch.Tensor): + self._wait_blocks_move(block_idx_to_wait) + + return grad_hook + + +class ModelOffloader(Offloader): + """ + supports forward offloading + """ + + def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): + super().__init__(num_blocks, blocks_to_swap, device, debug) + + def prepare_block_devices_before_forward(self, blocks: list[nn.Module]): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + + for b in blocks[0 : self.num_blocks - self.blocks_to_swap]: + b.to(self.device) + weighs_to_device(b, self.device) # make sure weights are on device + + for b in blocks[self.num_blocks - self.blocks_to_swap :]: + b.to(self.device) # move block to device first + weighs_to_device(b, "cpu") # make sure weights are on cpu + + synchronize_device(self.device) + clean_memory_on_device(self.device) + + def wait_for_block(self, block_idx: int): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + self._wait_blocks_move(block_idx) + + def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + if block_idx >= self.blocks_to_swap: + return + block_idx_to_cpu = block_idx + block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx + self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda) diff --git a/library/flux_models.py b/library/flux_models.py index 4721fa02e..e0bee160f 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -18,6 +18,7 @@ from einops import rearrange from torch import Tensor, nn from torch.utils.checkpoint import checkpoint +from library import custom_offloading_utils # USE_REENTRANT = True @@ -923,7 +924,8 @@ def __init__(self, params: FluxParams): self.cpu_offload_checkpointing = False self.blocks_to_swap = None - self.thread_pool: Optional[ThreadPoolExecutor] = None + self.offloader_double = None + self.offloader_single = None self.num_double_blocks = len(self.double_blocks) self.num_single_blocks = len(self.single_blocks) @@ -963,17 +965,17 @@ def disable_gradient_checkpointing(self): print("FLUX: Gradient checkpointing disabled.") - def enable_block_swap(self, num_blocks: int): + def enable_block_swap(self, num_blocks: int, device: torch.device): self.blocks_to_swap = num_blocks - self.double_blocks_to_swap = num_blocks // 2 - self.single_blocks_to_swap = (num_blocks - self.double_blocks_to_swap) * 2 + double_blocks_to_swap = num_blocks // 2 + single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + + self.offloader_double = custom_offloading_utils.ModelOffloader(self.num_double_blocks, double_blocks_to_swap, device) + self.offloader_single = custom_offloading_utils.ModelOffloader(self.num_single_blocks, single_blocks_to_swap, device) print( - f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {self.double_blocks_to_swap}, single blocks: {self.single_blocks_to_swap}." + f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." ) - n = 1 # async block swap. 1 is enough - self.thread_pool = ThreadPoolExecutor(max_workers=n) - def move_to_device_except_swap_blocks(self, device: torch.device): # assume model is on cpu. do not move blocks to device to reduce temporary memory usage if self.blocks_to_swap: @@ -988,56 +990,11 @@ def move_to_device_except_swap_blocks(self, device: torch.device): self.double_blocks = save_double_blocks self.single_blocks = save_single_blocks - # def get_block_unit(self, index: int): - # if index < len(self.double_blocks): - # return (self.double_blocks[index],) - # else: - # index -= len(self.double_blocks) - # index *= 2 - # return self.single_blocks[index], self.single_blocks[index + 1] - - # def get_unit_index(self, is_double: bool, index: int): - # if is_double: - # return index - # else: - # return len(self.double_blocks) + index // 2 - def prepare_block_swap_before_forward(self): - # # make: first n blocks are on cuda, and last n blocks are on cpu - # if self.blocks_to_swap is None or self.blocks_to_swap == 0: - # # raise ValueError("Block swap is not enabled.") - # return - # for i in range(self.num_block_units - self.blocks_to_swap): - # for b in self.get_block_unit(i): - # b.to(self.device) - # for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units): - # for b in self.get_block_unit(i): - # b.to("cpu") - # clean_memory_on_device(self.device) - - # all blocks are on device, but some weights are on cpu - # make first n blocks weights on device, and last n blocks weights on cpu if self.blocks_to_swap is None or self.blocks_to_swap == 0: - # raise ValueError("Block swap is not enabled.") return - - for b in self.double_blocks[0 : self.num_double_blocks - self.double_blocks_to_swap]: - b.to(self.device) - utils.weighs_to_device(b, self.device) # make sure weights are on device - for b in self.double_blocks[self.num_double_blocks - self.double_blocks_to_swap :]: - b.to(self.device) # move block to device first - utils.weighs_to_device(b, "cpu") # make sure weights are on cpu - torch.cuda.synchronize() - clean_memory_on_device(self.device) - - for b in self.single_blocks[0 : self.num_single_blocks - self.single_blocks_to_swap]: - b.to(self.device) - utils.weighs_to_device(b, self.device) # make sure weights are on device - for b in self.single_blocks[self.num_single_blocks - self.single_blocks_to_swap :]: - b.to(self.device) # move block to device first - utils.weighs_to_device(b, "cpu") # make sure weights are on cpu - torch.cuda.synchronize() - clean_memory_on_device(self.device) + self.offloader_double.prepare_block_devices_before_forward(self.double_blocks) + self.offloader_single.prepare_block_devices_before_forward(self.single_blocks) def forward( self, @@ -1073,59 +1030,21 @@ def forward( for block in self.single_blocks: img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) else: - # device = self.device - - def submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda): - def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): - # start_time = time.perf_counter() - # print(f"Moving {bidx_to_cpu} to cpu and {bidx_to_cuda} to cuda.") - utils.swap_weight_devices(block_to_cpu, block_to_cuda) - # print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.") - - # print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds") - return block_idx_to_cpu, block_idx_to_cuda # , event - - block_to_cpu = blocks[block_idx_to_cpu] - block_to_cuda = blocks[block_idx_to_cuda] - # print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.") - return self.thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda) - - def wait_for_blocks_move(block_idx, ftrs): - if block_idx not in ftrs: - return - # print(f"Waiting for move blocks: {block_idx}") - # start_time = time.perf_counter() - ftr = ftrs.pop(block_idx) - ftr.result() - # print(f"{block_idx} move blocks took {time.perf_counter() - start_time:.2f} seconds") - - double_futures = {} for block_idx, block in enumerate(self.double_blocks): - # print(f"Double block {block_idx}") - wait_for_blocks_move(block_idx, double_futures) + self.offloader_double.wait_for_block(block_idx) img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if block_idx < self.double_blocks_to_swap: - block_idx_to_cpu = block_idx - block_idx_to_cuda = self.num_double_blocks - self.double_blocks_to_swap + block_idx - future = submit_move_blocks(self.double_blocks, block_idx_to_cpu, block_idx_to_cuda) - double_futures[block_idx_to_cuda] = future + self.offloader_double.submit_move_blocks(self.double_blocks, block_idx) img = torch.cat((txt, img), 1) - single_futures = {} for block_idx, block in enumerate(self.single_blocks): - # print(f"Single block {block_idx}") - wait_for_blocks_move(block_idx, single_futures) + self.offloader_single.wait_for_block(block_idx) img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if block_idx < self.single_blocks_to_swap: - block_idx_to_cpu = block_idx - block_idx_to_cuda = self.num_single_blocks - self.single_blocks_to_swap + block_idx - future = submit_move_blocks(self.single_blocks, block_idx_to_cpu, block_idx_to_cuda) - single_futures[block_idx_to_cuda] = future + self.offloader_single.submit_move_blocks(self.single_blocks, block_idx) img = img[:, txt.shape[1] :, ...] From 3fe94b058a039b69b6b178bc086e200e40bfa887 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 12 Nov 2024 08:09:07 +0900 Subject: [PATCH 231/348] update comment --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 7f396d36e..a5d6fdd21 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1404,7 +1404,7 @@ def get_image_size(self, image_path): # return imagesize.get(image_path) image_size = imagesize.get(image_path) if image_size[0] <= 0: - # imagesize doesn't work for some images, so use cv2 + # imagesize doesn't work for some images, so use PIL as a fallback try: with Image.open(image_path) as img: image_size = img.size From cde90b8903870b6b28dae274d07ed27978055e3c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 12 Nov 2024 08:49:05 +0900 Subject: [PATCH 232/348] feat: implement block swapping for FLUX.1 LoRA (WIP) --- flux_train.py | 2 +- flux_train_network.py | 33 ++++++++++++++++++++++++ library/custom_offloading_utils.py | 40 +++++++++++++++++++++++++++++- library/flux_models.py | 8 ++++-- train_network.py | 9 ++++++- 5 files changed, 87 insertions(+), 5 deletions(-) diff --git a/flux_train.py b/flux_train.py index 02dede45e..346fe8fbd 100644 --- a/flux_train.py +++ b/flux_train.py @@ -519,7 +519,7 @@ def grad_hook(parameter: torch.Tensor): num_parameters_per_group[opt_idx] += 1 # add hooks for block swapping: this hook is called after fused_backward_pass hook or blockwise_fused_optimizers hook - if is_swapping_blocks: + if False: # is_swapping_blocks: import library.custom_offloading_utils as custom_offloading_utils num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks) diff --git a/flux_train_network.py b/flux_train_network.py index 2b71a8979..376cc1597 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -25,6 +25,7 @@ def __init__(self): super().__init__() self.sample_prompts_te_outputs = None self.is_schnell: Optional[bool] = None + self.is_swapping_blocks: bool = False def assert_extra_args(self, args, train_dataset_group): super().assert_extra_args(args, train_dataset_group) @@ -78,6 +79,12 @@ def load_target_model(self, args, weight_dtype, accelerator): if args.split_mode: model = self.prepare_split_model(model, weight_dtype, accelerator) + self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if self.is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + model.enable_block_swap(args.blocks_to_swap, accelerator.device) + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) clip_l.eval() @@ -285,6 +292,8 @@ def sample_images(self, accelerator, args, epoch, global_step, device, ae, token text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) if not args.split_mode: + if self.is_swapping_blocks: + accelerator.unwrap_model(flux).prepare_block_swap_before_forward() flux_train_utils.sample_images( accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs ) @@ -539,6 +548,19 @@ def forward(hidden_states): text_encoder.to(te_weight_dtype) # fp8 prepare_fp8(text_encoder, weight_dtype) + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + if not self.is_swapping_blocks: + return super().prepare_unet_with_accelerator(args, accelerator, unet) + + # if we doesn't swap blocks, we can move the model to device + flux: flux_models.Flux = unet + flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks]) + accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + + return flux + def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() @@ -550,6 +572,17 @@ def setup_parser() -> argparse.ArgumentParser: help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required" + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要", ) + + parser.add_argument( + "--blocks_to_swap", + type=int, + default=None, + help="[EXPERIMENTAL] " + "Sets the number of blocks to swap during the forward and backward passes." + "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." + " / 順伝播および逆伝播中にスワップするブロックの数を設定します。" + "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", + ) return parser diff --git a/library/custom_offloading_utils.py b/library/custom_offloading_utils.py index 33a413004..70da93902 100644 --- a/library/custom_offloading_utils.py +++ b/library/custom_offloading_utils.py @@ -183,9 +183,47 @@ class ModelOffloader(Offloader): supports forward offloading """ - def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): + def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): super().__init__(num_blocks, blocks_to_swap, device, debug) + # register backward hooks + self.remove_handles = [] + for i, block in enumerate(blocks): + hook = self.create_backward_hook(blocks, i) + if hook is not None: + handle = block.register_full_backward_hook(hook) + self.remove_handles.append(handle) + + def __del__(self): + for handle in self.remove_handles: + handle.remove() + + def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]: + # -1 for 0-based index + num_blocks_propagated = self.num_blocks - block_index - 1 + swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap + waiting = block_index > 0 and block_index <= self.blocks_to_swap + + if not swapping and not waiting: + return None + + # create hook + block_idx_to_cpu = self.num_blocks - num_blocks_propagated + block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated + block_idx_to_wait = block_index - 1 + + def backward_hook(module, grad_input, grad_output): + if self.debug: + print(f"Backward hook for block {block_index}") + + if swapping: + self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda) + if waiting: + self._wait_blocks_move(block_idx_to_wait) + return None + + return backward_hook + def prepare_block_devices_before_forward(self, blocks: list[nn.Module]): if self.blocks_to_swap is None or self.blocks_to_swap == 0: return diff --git a/library/flux_models.py b/library/flux_models.py index e0bee160f..4fa272522 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -970,8 +970,12 @@ def enable_block_swap(self, num_blocks: int, device: torch.device): double_blocks_to_swap = num_blocks // 2 single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 - self.offloader_double = custom_offloading_utils.ModelOffloader(self.num_double_blocks, double_blocks_to_swap, device) - self.offloader_single = custom_offloading_utils.ModelOffloader(self.num_single_blocks, single_blocks_to_swap, device) + self.offloader_double = custom_offloading_utils.ModelOffloader( + self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device #, debug=True + ) + self.offloader_single = custom_offloading_utils.ModelOffloader( + self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device #, debug=True + ) print( f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." ) diff --git a/train_network.py b/train_network.py index b90aa420e..d70f14ad3 100644 --- a/train_network.py +++ b/train_network.py @@ -18,6 +18,7 @@ init_ipex() from accelerate.utils import set_seed +from accelerate import Accelerator from diffusers import DDPMScheduler from library import deepspeed_utils, model_util, strategy_base, strategy_sd @@ -272,6 +273,11 @@ def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): text_encoder.text_model.embeddings.to(dtype=weight_dtype) + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + return accelerator.prepare(unet) + def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): pass @@ -627,7 +633,8 @@ def train(self, args): training_model = ds_model else: if train_unet: - unet = accelerator.prepare(unet) + # default implementation is: unet = accelerator.prepare(unet) + unet = self.prepare_unet_with_accelerator(args, accelerator, unet) # accelerator does some magic here else: unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator if train_text_encoder: From 2cb7a6db02ae001355f4830581b9fc2ffffe01c6 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 12 Nov 2024 21:39:13 +0900 Subject: [PATCH 233/348] feat: add block swap for FLUX.1/SD3 LoRA training --- README.md | 212 ++++++---------------------- flux_train.py | 56 +------- flux_train_network.py | 95 +++++++------ library/custom_offloading_utils.py | 75 ++++------ library/flux_models.py | 19 ++- library/flux_train_utils.py | 48 +------ library/sd3_models.py | 71 +++------- library/sd3_train_utils.py | 49 +------ library/train_util.py | 74 +++++++++- sd3_train.py | 186 +++--------------------- sd3_train_network.py | 30 ++++ tools/cache_latents.py | 1 + tools/cache_text_encoder_outputs.py | 1 + train_network.py | 6 +- 14 files changed, 291 insertions(+), 632 deletions(-) diff --git a/README.md b/README.md index 14328607e..1e63b5830 100644 --- a/README.md +++ b/README.md @@ -14,150 +14,11 @@ The command to install PyTorch is as follows: ### Recent Updates -Nov 9, 2024: +Nov 12, 2024: -- Fixed an issue where the image size could not be obtained when caching latents was enabled and a specific file name existed, causing the latent size to be incorrect. See PR [#1770](https://github.com/kohya-ss/sd-scripts/pull/1770) for details. Thanks to feffy380! - -Nov 7, 2024: - -- The distribution of timesteps during SD3/3.5 training has been adjusted. This applies to both fine-tuning and LoRA training. PR [#1768](https://github.com/kohya-ss/sd-scripts/pull/1768) Thanks to Dango233! - - Previously, the side closer to noise was more sampled, but now it is uniform by default. This may improve the problem of difficulty in learning details. - - Specifically, the problem of double shifting has been fixed. The default for `--weighting_scheme` has been changed to `uniform` (the previous default was `logit_normal`). - - A new option `--training_shift` has been added. The default is 1.0, and all timesteps are sampled uniformly. If less than 1.0, the side closer to the image is more sampled (training more on image details), and if more than 1.0, the side closer to noise is more sampled (training more on overall structure). - - The effect of a shift in uniform distribution is shown in the figure below. - - ![Figure_1](https://github.com/user-attachments/assets/99a72c67-adfb-4440-81d4-a718985ff350) - -Oct 31, 2024: - -- Added support for SD3.5L/M training. See [SD3 training](#sd3-training) for details. - -Oct 19, 2024: - -- Added an implementation of Differential Output Preservation (temporary name) for SDXL/FLUX.1 LoRA training. SD1/2 is not tested yet. This is an experimental feature. - - A method to make the output of LoRA closer to the output when LoRA is not applied, with captions that do not contain trigger words. - - Define a Dataset subset for the regularization image (`is_reg = true`) with `.toml`. Add `custom_attributes.diff_output_preservation = true`. - - See [dataset configuration](docs/config_README-en.md) for the regularization dataset. - - Specify "number of training images x number of repeats >= number of regularization images x number of repeats". - - The weights of DOP is specified by `--prior_loss_weight` option (not dataset config). - - The appropriate value is still unknown. For FLUX, according to the comments in the [PR](https://github.com/kohya-ss/sd-scripts/pull/1710), the value may be 1 (thanks to dxqbYD!). For SDXL, a larger value may be needed (10-100 may be good starting points). - - It may be good to adjust the value so that the loss is about half to three-quarters of the loss when DOP is not applied. -``` -[[datasets.subsets]] -image_dir = "path/to/image/dir" -num_repeats = 1 -is_reg = true -custom_attributes.diff_output_preservation = true # Add this -``` - - -Oct 13, 2024: - -- Fixed an issue where it took a long time to load the image size when initializing the dataset, especially when the number of images in the dataset was large. - -- During multi-GPU training, caching of latents and Text Encoder outputs is now done in multi-GPU. - - Please make sure that `--highvram` and `--vae_batch_size` are specified correctly. If you have enough VRAM, you can increase the batch size to speed up the caching. - - `--text_encoder_batch_size` option is enabled for FLUX.1 LoRA training and fine tuning. This option specifies the batch size for caching Text Encoder outputs (not for training). The default is same as the dataset batch size. If you have enough VRAM, you can increase the batch size to speed up the caching. - - Multi-threading is also implemented for caching of latents. This may speed up the caching process about 5% (depends on the environment). - - `tools/cache_latents.py` and `tools/cache_text_encoder_outputs.py` also have been updated to support multi-GPU caching. -- `--skip_cache_check` option is added to each training script. - - When specified, the consistency check of the cache file `*.npz` contents (e.g., image size and flip for latents, mask for Text Encoder outputs) is skipped. - - Specify this option if you have a large number of cache files and the consistency check takes time. - - Even if this option is specified, the cache will be created if the file does not exist. - - `--skip_latents_validity_check` in SD3/FLUX.1 is deprecated. Please use `--skip_cache_check` instead. - -Oct 12, 2024 (update 1): - -- [Experimental] FLUX.1 fine-tuning and LoRA training now support "FLUX.1 __compact__" models. - - A compact model is a model that retains the FLUX.1 architecture but reduces the number of double/single blocks from the default 19/38. - - The model is automatically determined based on the keys in *.safetensors. - - Specifications for compact model safetensors: - - Please specify the block indices as consecutive numbers. An error will occur if there are missing numbers. For example, if you reduce the double blocks to 15, the maximum key will be `double_blocks.14.*`. The same applies to single blocks. - - LoRA training is unverified. - - The trained model can be used for inference with `flux_minimal_inference.py`. Other inference environments are unverified. - -Oct 12, 2024: - -- Multi-GPU training now works on Windows. Thanks to Akegarasu for PR [#1686](https://github.com/kohya-ss/sd-scripts/pull/1686)! - - In simple tests, SDXL and FLUX.1 LoRA training worked. FLUX.1 fine-tuning did not work, probably due to a PyTorch-related error. Other scripts are unverified. - - Set up multi-GPU training with `accelerate config`. - - Specify `--rdzv_backend=c10d` when launching `accelerate launch`. You can also edit `config.yaml` directly. - ``` - accelerate launch --rdzv_backend=c10d sdxl_train_network.py ... - ``` - - In multi-GPU training, the memory of multiple GPUs is not integrated. In other words, even if you have two 12GB VRAM GPUs, you cannot train the model that requires 24GB VRAM. Training that can be done with 12GB VRAM is executed at (up to) twice the speed. - -Oct 11, 2024: -- ControlNet training for SDXL has been implemented in this branch. Please use `sdxl_train_control_net.py`. - - For details on defining the dataset, see [here](docs/train_lllite_README.md#creating-a-dataset-configuration-file). - - The learning rate for the copy part of the U-Net is specified by `--learning_rate`. The learning rate for the added modules in ControlNet is specified by `--control_net_lr`. The optimal value is still unknown, but try around U-Net `1e-5` and ControlNet `1e-4`. - - If you want to generate sample images, specify the control image as `--cn path/to/control/image`. - - The trained weights are automatically converted and saved in Diffusers format. It should be available in ComfyUI. -- Weighting of prompts (captions) during training in SDXL is now supported (e.g., `(some text)`, `[some text]`, `(some text:1.4)`, etc.). The function is enabled by specifying `--weighted_captions`. - - The default is `False`. It is same as before, and the parentheses are used as normal text. - - If `--weighted_captions` is specified, please use `\` to escape the parentheses in the prompt. For example, `\(some text:1.4\)`. - -Oct 6, 2024: -- In FLUX.1 LoRA training and fine-tuning, the specified weight file (*.safetensors) is automatically determined to be dev or schnell. This allows schnell models to be loaded correctly. Note that LoRA training with schnell models and fine-tuning with schnell models are unverified. -- FLUX.1 LoRA training and fine-tuning can now load weights in Diffusers format in addition to BFL format (a single *.safetensors file). Please specify the parent directory of `transformer` or `diffusion_pytorch_model-00001-of-00003.safetensors` with the full path. However, Diffusers format CLIP/T5XXL is not supported. Saving is supported only in BFL format. - -Sep 26, 2024: -The implementation of block swap during FLUX.1 fine-tuning has been changed to improve speed about 10% (depends on the environment). A new `--blocks_to_swap` option has been added, and `--double_blocks_to_swap` and `--single_blocks_to_swap` are deprecated. `--double_blocks_to_swap` and `--single_blocks_to_swap` are working as before, but they will be removed in the future. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. - - -Sep 18, 2024 (update 1): -Fixed an issue where train()/eval() was not called properly with the schedule-free optimizer. The schedule-free optimizer can be used in FLUX.1 LoRA training and fine-tuning for now. - -Sep 18, 2024: - -- Schedule-free optimizer is added. Thanks to sdbds! See PR [#1600](https://github.com/kohya-ss/sd-scripts/pull/1600) for details. - - Details of the schedule-free optimizer can be found in [facebookresearch/schedule_free](https://github.com/facebookresearch/schedule_free). - - `schedulefree` is added to the dependencies. Please update the library if necessary. - - AdamWScheduleFree or SGDScheduleFree can be used. Specify `adamwschedulefree` or `sgdschedulefree` in `--optimizer_type`. - - Wrapper classes are not available for now. - - These can be used not only for FLUX.1 training but also for other training scripts after merging to the dev/main branch. - -Sep 16, 2024: - - Added `train_double_block_indices` and `train_double_block_indices` to the LoRA training script to specify the indices of the blocks to train. See [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training) for details. - -Sep 15, 2024: - -Added a script `convert_diffusers_to_flux.py` to convert Diffusers format FLUX.1 models (checkpoints) to BFL format. See `--help` for usage. Only Flux models are supported. AE/CLIP/T5XXL are not supported. - -The implementation is based on 2kpr's code. Thanks to 2kpr! - -Sep 14, 2024: -- You can now specify the rank for each layer in FLUX.1. See [Specify rank for each layer in FLUX.1](#specify-rank-for-each-layer-in-flux1) for details. -- OFT is now supported with FLUX.1. See [FLUX.1 OFT training](#flux1-oft-training) for details. - -Sep 11, 2024: -Logging to wandb is improved. See PR [#1576](https://github.com/kohya-ss/sd-scripts/pull/1576) for details. Thanks to p1atdev! - -Sep 10, 2024: -In FLUX.1 LoRA training, individual learning rates can be specified for CLIP-L and T5XXL. By specifying multiple numbers in `--text_encoder_lr`, you can set the learning rates for CLIP-L and T5XXL separately. Specify like `--text_encoder_lr 1e-4 1e-5`. The first value is the learning rate for CLIP-L, and the second value is for T5XXL. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. - -Sep 9, 2024: -Added `--negative_prompt` and `--cfg_scale` to `flux_minimal_inference.py`. Negative prompts can be used. - -Sep 5, 2024 (update 1): - -Added `--cpu_offload_checkpointing` option to LoRA training script. Offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--split_mode`. - -Sep 5, 2024: - -The LoRA merge script now supports CLIP-L and T5XXL LoRA. Please specify `--clip_l` and `--t5xxl`. `--clip_l_save_to` and `--t5xxl_save_to` specify the save destination for CLIP-L and T5XXL. See [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) for details. - -Sep 4, 2024: -- T5XXL LoRA is supported in LoRA training. Remove `--network_train_unet_only` and add `train_t5xxl=True` to `--network_args`. CLIP-L is also trained at the same time (T5XXL only cannot be trained). The trained model can be used with ComfyUI. See [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) for details. -- In LoRA training, when `--fp8_base` is specified, you can specify `t5xxl_fp8_e4m3fn.safetensors` as the T5XXL weights. However, it is recommended to use fp16 weights for caching. -- Fixed an issue where the training CLIP-L LoRA was not used in sample image generation during LoRA training. - -Sep 1, 2024: -- `--timestamp_sampling` has `flux_shift` option. Thanks to sdbds! - - This is the same shift as FLUX.1 dev inference, adjusting the timestep sampling depending on the resolution. `--discrete_flow_shift` is ignored when `flux_shift` is specified. It is not verified which is better, `shift` or `flux_shift`. - -Aug 29, 2024: -Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. `requirements.txt` is updated. +- Improved the implementation of block swap and made it available for both FLUX.1 and SD3 LoRA training. See [FLUX.1 LoRA training](#flux1-lora-training) etc. for how to use the new options. Training is possible with about 8-10GB of VRAM. +- During fine-tuning, the memory usage when specifying the same number of blocks has increased slightly, but the training speed when specifying block swap has been significantly improved. +- There may be bugs due to the significant changes. Feedback is welcome. ## FLUX.1 training @@ -190,7 +51,8 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t --pretrained_model_name_or_path flux1-dev.safetensors --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.safetensors --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 ---network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 +--network_module networks.lora_flux --network_dim 4 --network_train_unet_only +--optimizer_type adamw8bit --learning_rate 1e-4 --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name @@ -198,23 +60,39 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t ``` (The command is multi-line for readability. Please combine it into one line.) -The training can be done with 16GB VRAM GPUs with Adafactor optimizer. Please use settings like below: +We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. + +The trained LoRA model can be used with ComfyUI. + +When training LoRA for Text Encoder (without `--network_train_unet_only`), more VRAM is required. Please refer to the settings below to reduce VRAM usage. + +__Options for GPUs with less VRAM:__ + +By specifying `--block_to_swap`, you can save VRAM by swapping some blocks between CPU and GPU. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. + +Specify a number like `--block_to_swap 10`. A larger number will swap more blocks, saving more VRAM, but training will be slower. In FLUX.1, you can swap up to 35 blocks. + +`--cpu_offload_checkpointing` offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--block_to_swap`. + +Adafactor optimizer may reduce the VRAM usage than 8bit AdamW. Please use settings like below: ``` --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 ``` -The training can be done with 12GB VRAM GPUs with Adafactor optimizer, `--split_mode` and `train_blocks=single` options. Please use settings like below: +The training can be done with 16GB VRAM GPUs with the batch size of 1. Please change your dataset configuration. + +The training can be done with 12GB VRAM GPUs with `--block_to_swap 16` with 8bit AdamW. Please use settings like below: ``` ---optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --split_mode --network_args "train_blocks=single" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 +--blocks_to_swap 16 ``` -`--cpu_offload_checkpointing` offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--split_mode`. +For GPUs with less than 10GB of VRAM, it is recommended to use an fp8 checkpoint for T5XXL. You can download `t5xxl_fp8_e4m3fn.safetensors` from [comfyanonymous/flux_text_encoders](https://huggingface.co/comfyanonymous/flux_text_encoders) (please use without `scaled`). -We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. +10GB VRAM GPUs will work with 22 blocks swapped, and 8GB VRAM GPUs will work with 28 blocks swapped. -The trained LoRA model can be used with ComfyUI. +__`--split_mode` is deprecated. This option is still available, but they will be removed in the future. Please use `--blocks_to_swap` instead. If this option is specified and `--blocks_to_swap` is not specified, `--blocks_to_swap 18` is automatically enabled.__ #### Key Options for FLUX.1 LoRA training @@ -239,6 +117,7 @@ There are many unknown points in FLUX.1 training, so some settings can be specif - `additive`: add to noisy input - `sigma_scaled`: apply sigma scaling, same as SD3 - `--discrete_flow_shift` is the discrete flow shift for the Euler Discrete Scheduler, default is 3.0 (same as SD3). +- `--blocks_to_swap`. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. The existing `--loss_type` option may be useful for FLUX.1 training. The default is `l2`. @@ -426,9 +305,9 @@ Options are almost the same as LoRA training. The difference is `--full_bf16`, ` `--blockwise_fused_optimizers` enables the fusing of the optimizer step into the backward pass for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency and stochastic rounding. `--blockwise_fused_optimizers` cannot be used with `--fused_backward_pass`. Stochastic rounding is not supported for now. -`--blocks_to_swap` is the number of blocks to swap. The default is None (no swap). These options must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. The recommended maximum value is 36. +`--blocks_to_swap` is the number of blocks to swap. The default is None (no swap). The maximum value is 35. -`--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage. +`--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage. This option cannot be used with `--blocks_to_swap`. All these options are experimental and may change in the future. @@ -448,13 +327,13 @@ There are two possible ways to use block swap. It is unknown which is better. 2. Swap many blocks to increase the batch size and shorten the training speed per data. - For example, swapping 20 blocks seems to increase the batch size to about 6. In this case, the training speed per data will be relatively faster than 1. + For example, swapping 35 blocks seems to increase the batch size to about 5. In this case, the training speed per data will be relatively faster than 1. #### Training with <24GB VRAM GPUs Swap 28 blocks without cpu offload checkpointing may be working with 12GB VRAM GPUs. Please try different settings according to VRAM size of your GPU. -T5XXL requires about 10GB of VRAM, so 10GB of VRAM will be minimum requirement for FLUX.1 fine-tuning. +T5XXL requires about 10GB of VRAM, so 10GB of VRAM will be minimum requirement for FLUX.1 fine-tuning. #### Key Features for FLUX.1 fine-tuning @@ -465,17 +344,19 @@ T5XXL requires about 10GB of VRAM, so 10GB of VRAM will be minimum requirement f - Since the transfer between CPU and GPU takes time, the training will be slower. - `--blocks_to_swap` specify the number of blocks to swap. - About 640MB of memory can be saved per block. - - Since the memory usage of one double block and two single blocks is almost the same, the transfer of single blocks is done in units of two. For example, consider the case of `--blocks_to_swap 6`. - - Before the forward pass, all double blocks and 26 (=38-12) single blocks are on the GPU. The last 12 single blocks are on the CPU. - - In the forward pass, the 6 double blocks that have finished calculation (the first 6 blocks) are transferred to the CPU, and the 12 single blocks to be calculated (the last 12 blocks) are transferred to the GPU. - - The same is true for the backward pass, but in reverse order. The 12 single blocks that have finished calculation are transferred to the CPU, and the 6 double blocks to be calculated are transferred to the GPU. - - After the backward pass, the blocks are back to their original locations. + - (Update 1: Nov 12, 2024) + - The maximum number of blocks that can be swapped is 35. + - We are exchanging only the data of the weights (weight.data) in reference to the implementation of OneTrainer (thanks to OneTrainer). However, the mechanism of the exchange is a custom implementation. + - Since it takes time to free CUDA memory (torch.cuda.empty_cache()), we reuse the CUDA memory allocated to weight.data as it is and exchange the weights between modules. + - This shortens the time it takes to exchange weights between modules. + - Since the weights must be almost identical to be exchanged, FLUX.1 exchanges the weights between double blocks and single blocks. + - In SD3, all blocks are similar, but some weights are different, so there are weights that always remain on the GPU. 2. Sample Image Generation: - Sample image generation during training is now supported. - The prompts are cached and used for generation if `--cache_latents` is specified. So changing the prompts during training will not affect the generated images. - Specify options such as `--sample_prompts` and `--sample_every_n_epochs`. - - Note: It will be very slow when `--split_mode` is specified. + - Note: It will be very slow when `--blocks_to_swap` is specified. 3. Experimental Memory-Efficient Saving: - `--mem_eff_save` option can further reduce memory consumption during model saving (about 22GB). @@ -621,20 +502,19 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 sd3_tr --pretrained_model_name_or_path path/to/sd3.5_large.safetensors --clip_l sd3/clip_l.safetensors --clip_g sd3/clip_g.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 ---network_module networks.lora_sd3 --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 +--network_module networks.lora_sd3 --network_dim 4 --network_train_unet_only +--optimizer_type adamw8bit --learning_rate 1e-4 --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name sd3-lora-name ``` (The command is multi-line for readability. Please combine it into one line.) -The training can be done with 12GB VRAM GPUs with Adafactor optimizer. Please use settings like below: +Like FLUX.1 training, the `--blocks_to_swap` option for memory reduction is available. The maximum number of blocks that can be swapped is 36 for SD3.5L and 22 for SD3.5M. -``` ---optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 -``` +Adafactor optimizer is also available. -`--cpu_offload_checkpointing` and `--split_mode` are not available for SD3 LoRA training. +`--cpu_offload_checkpointing` option is not available. We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. diff --git a/flux_train.py b/flux_train.py index 346fe8fbd..ad2c7722b 100644 --- a/flux_train.py +++ b/flux_train.py @@ -78,6 +78,10 @@ def train(args): ) args.gradient_checkpointing = True + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + cache_latents = args.cache_latents use_dreambooth_method = args.in_json is None @@ -518,47 +522,6 @@ def grad_hook(parameter: torch.Tensor): parameter_optimizer_map[parameter] = opt_idx num_parameters_per_group[opt_idx] += 1 - # add hooks for block swapping: this hook is called after fused_backward_pass hook or blockwise_fused_optimizers hook - if False: # is_swapping_blocks: - import library.custom_offloading_utils as custom_offloading_utils - - num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks) - num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks) - double_blocks_to_swap = args.blocks_to_swap // 2 - single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2 - - offloader_double = custom_offloading_utils.TrainOffloader(num_double_blocks, double_blocks_to_swap, accelerator.device) - offloader_single = custom_offloading_utils.TrainOffloader(num_single_blocks, single_blocks_to_swap, accelerator.device) - - param_name_pairs = [] - if not args.blockwise_fused_optimizers: - for param_group, param_name_group in zip(optimizer.param_groups, param_names): - param_name_pairs.extend(zip(param_group["params"], param_name_group)) - else: - # named_parameters is a list of (name, parameter) pairs - param_name_pairs.extend([(p, n) for n, p in flux.named_parameters()]) - - for parameter, param_name in param_name_pairs: - if not parameter.requires_grad: - continue - - is_double = param_name.startswith("double_blocks") - is_single = param_name.startswith("single_blocks") - if not is_double and not is_single: - continue - - block_index = int(param_name.split(".")[1]) - if is_double: - blocks = flux.double_blocks - offloader = offloader_double - else: - blocks = flux.single_blocks - offloader = offloader_single - - grad_hook = offloader.create_grad_hook(blocks, block_index) - if grad_hook is not None: - parameter.register_post_accumulate_grad_hook(grad_hook) - # epoch数を計算する num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) @@ -827,6 +790,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) add_custom_train_arguments(parser) # TODO remove this from here + train_util.add_dit_training_arguments(parser) flux_train_utils.add_flux_train_arguments(parser) parser.add_argument( @@ -851,16 +815,6 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", ) - parser.add_argument( - "--blocks_to_swap", - type=int, - default=None, - help="[EXPERIMENTAL] " - "Sets the number of blocks (~640MB) to swap during the forward and backward passes." - "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." - " / 順伝播および逆伝播中にスワップするブロック(約640MB)の数を設定します。" - "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", - ) parser.add_argument( "--double_blocks_to_swap", type=int, diff --git a/flux_train_network.py b/flux_train_network.py index 376cc1597..9bcd59282 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -52,10 +52,23 @@ def assert_extra_args(self, args, train_dataset_group): if args.max_token_length is not None: logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") - assert not args.split_mode or not args.cpu_offload_checkpointing, ( - "split_mode and cpu_offload_checkpointing cannot be used together" - " / split_modeとcpu_offload_checkpointingは同時に使用できません" - ) + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + + # deprecated split_mode option + if args.split_mode: + if args.blocks_to_swap is not None: + logger.warning( + "split_mode is deprecated. Because `--blocks_to_swap` is set, `--split_mode` is ignored." + " / split_modeは非推奨です。`--blocks_to_swap`が設定されているため、`--split_mode`は無視されます。" + ) + else: + logger.warning( + "split_mode is deprecated. Please use `--blocks_to_swap` instead. `--blocks_to_swap 18` is automatically set." + " / split_modeは非推奨です。代わりに`--blocks_to_swap`を使用してください。`--blocks_to_swap 18`が自動的に設定されました。" + ) + args.blocks_to_swap = 18 # 18 is safe for most cases train_dataset_group.verify_bucket_reso_steps(32) # TODO check this @@ -75,9 +88,15 @@ def load_target_model(self, args, weight_dtype, accelerator): raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") elif model.dtype == torch.float8_e4m3fn: logger.info("Loaded fp8 FLUX model") + else: + logger.info( + "Cast FLUX model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint." + " / FLUXモデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。" + ) + model.to(torch.float8_e4m3fn) - if args.split_mode: - model = self.prepare_split_model(model, weight_dtype, accelerator) + # if args.split_mode: + # model = self.prepare_split_model(model, weight_dtype, accelerator) self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 if self.is_swapping_blocks: @@ -108,6 +127,7 @@ def load_target_model(self, args, weight_dtype, accelerator): return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model + """ def prepare_split_model(self, model, weight_dtype, accelerator): from accelerate import init_empty_weights @@ -144,6 +164,7 @@ def prepare_split_model(self, model, weight_dtype, accelerator): logger.info("split model prepared") return flux_lower + """ def get_tokenize_strategy(self, args): _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) @@ -291,14 +312,12 @@ def sample_images(self, accelerator, args, epoch, global_step, device, ae, token text_encoders = text_encoder # for compatibility text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) - if not args.split_mode: - if self.is_swapping_blocks: - accelerator.unwrap_model(flux).prepare_block_swap_before_forward() - flux_train_utils.sample_images( - accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs - ) - return + flux_train_utils.sample_images( + accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs + ) + # return + """ class FluxUpperLowerWrapper(torch.nn.Module): def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device): super().__init__() @@ -325,6 +344,7 @@ def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_a accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs ) clean_memory_on_device(accelerator.device) + """ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) @@ -383,20 +403,21 @@ def get_noise_pred_and_target( t5_attn_mask = None def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): - if not args.split_mode: - # normal forward - with accelerator.autocast(): - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) - model_pred = unet( - img=img, - img_ids=img_ids, - txt=t5_out, - txt_ids=txt_ids, - y=l_pooled, - timesteps=timesteps / 1000, - guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, - ) + # if not args.split_mode: + # normal forward + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = unet( + img=img, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + """ else: # split forward to reduce memory usage assert network.train_blocks == "single", "train_blocks must be single for split mode" @@ -430,6 +451,7 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t vec.requires_grad_(True) pe.requires_grad_(True) model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) + """ return model_pred @@ -558,30 +580,23 @@ def prepare_unet_with_accelerator( flux: flux_models.Flux = unet flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks]) accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + accelerator.unwrap_model(flux).prepare_block_swap_before_forward() return flux def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() + train_util.add_dit_training_arguments(parser) flux_train_utils.add_flux_train_arguments(parser) parser.add_argument( "--split_mode", action="store_true", - help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required" - + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要", - ) - - parser.add_argument( - "--blocks_to_swap", - type=int, - default=None, - help="[EXPERIMENTAL] " - "Sets the number of blocks to swap during the forward and backward passes." - "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." - " / 順伝播および逆伝播中にスワップするブロックの数を設定します。" - "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", + # help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required" + # + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要", + help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead." + " / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。", ) return parser diff --git a/library/custom_offloading_utils.py b/library/custom_offloading_utils.py index 70da93902..84c2b743e 100644 --- a/library/custom_offloading_utils.py +++ b/library/custom_offloading_utils.py @@ -16,13 +16,29 @@ def synchronize_device(device: torch.device): torch.mps.synchronize() -def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): +def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): assert layer_to_cpu.__class__ == layer_to_cuda.__class__ weight_swap_jobs = [] - for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): - if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: - weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + # This is not working for all cases (e.g. SD3), so we need to find the corresponding modules + # for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + # print(module_to_cpu.__class__, module_to_cuda.__class__) + # if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + # weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + modules_to_cpu = {k: v for k, v in layer_to_cpu.named_modules()} + for module_to_cuda_name, module_to_cuda in layer_to_cuda.named_modules(): + if hasattr(module_to_cuda, "weight") and module_to_cuda.weight is not None: + module_to_cpu = modules_to_cpu.get(module_to_cuda_name, None) + if module_to_cpu is not None and module_to_cpu.weight.shape == module_to_cuda.weight.shape: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + else: + if module_to_cuda.weight.data.device.type != device.type: + # print( + # f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device" + # ) + module_to_cuda.weight.data = module_to_cuda.weight.data.to(device) torch.cuda.current_stream().synchronize() # this prevents the illegal loss value @@ -92,7 +108,7 @@ def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, d def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module): if self.cuda_available: - swap_weight_devices(block_to_cpu, block_to_cuda) + swap_weight_devices_cuda(self.device, block_to_cpu, block_to_cuda) else: swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda) @@ -132,52 +148,6 @@ def _wait_blocks_move(self, block_idx): print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s") -class TrainOffloader(Offloader): - """ - supports backward offloading - """ - - def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): - super().__init__(num_blocks, blocks_to_swap, device, debug) - self.hook_added = set() - - def create_grad_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]: - if block_index in self.hook_added: - return None - self.hook_added.add(block_index) - - # -1 for 0-based index, -1 for current block is not fully backpropagated yet - num_blocks_propagated = self.num_blocks - block_index - 2 - swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap - waiting = block_index > 0 and block_index <= self.blocks_to_swap - - if not swapping and not waiting: - return None - - # create hook - block_idx_to_cpu = self.num_blocks - num_blocks_propagated - block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated - block_idx_to_wait = block_index - 1 - - if self.debug: - print( - f"Backward: Created grad hook for block {block_index} with {block_idx_to_cpu}, {block_idx_to_cuda}, {block_idx_to_wait}" - ) - if swapping: - - def grad_hook(tensor: torch.Tensor): - self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda) - - return grad_hook - - else: - - def grad_hook(tensor: torch.Tensor): - self._wait_blocks_move(block_idx_to_wait) - - return grad_hook - - class ModelOffloader(Offloader): """ supports forward offloading @@ -228,6 +198,9 @@ def prepare_block_devices_before_forward(self, blocks: list[nn.Module]): if self.blocks_to_swap is None or self.blocks_to_swap == 0: return + if self.debug: + print("Prepare block devices before forward") + for b in blocks[0 : self.num_blocks - self.blocks_to_swap]: b.to(self.device) weighs_to_device(b, self.device) # make sure weights are on device diff --git a/library/flux_models.py b/library/flux_models.py index 4fa272522..fa3c7ad2b 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -970,11 +970,16 @@ def enable_block_swap(self, num_blocks: int, device: torch.device): double_blocks_to_swap = num_blocks // 2 single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, ( + f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. " + f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks." + ) + self.offloader_double = custom_offloading_utils.ModelOffloader( - self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device #, debug=True + self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True ) self.offloader_single = custom_offloading_utils.ModelOffloader( - self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device #, debug=True + self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True ) print( f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." @@ -1061,10 +1066,11 @@ def forward( return img +""" class FluxUpper(nn.Module): - """ + "" Transformer model for flow matching on sequences. - """ + "" def __init__(self, params: FluxParams): super().__init__() @@ -1168,9 +1174,9 @@ def forward( class FluxLower(nn.Module): - """ + "" Transformer model for flow matching on sequences. - """ + "" def __init__(self, params: FluxParams): super().__init__() @@ -1228,3 +1234,4 @@ def forward( img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) return img +""" diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index fa673a2f0..d90644a25 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -257,14 +257,9 @@ def sample_image_inference( wandb_tracker = accelerator.get_tracker("wandb") import wandb + # not to commit images to avoid inconsistency between training and logging steps - wandb_tracker.log( - {f"sample_{i}": wandb.Image( - image, - caption=prompt # positive prompt as a caption - )}, - commit=False - ) + wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption def time_shift(mu: float, sigma: float, t: torch.Tensor): @@ -324,7 +319,7 @@ def denoise( ) img = img + (t_prev - t_curr) * pred - + model.prepare_block_swap_before_forward() return img @@ -549,44 +544,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): action="store_true", help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する", ) - parser.add_argument( - "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" - ) - parser.add_argument( - "--cache_text_encoder_outputs_to_disk", - action="store_true", - help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", - ) - parser.add_argument( - "--text_encoder_batch_size", - type=int, - default=None, - help="text encoder batch size (default: None, use dataset's batch size)" - + " / text encoderのバッチサイズ(デフォルト: None, データセットのバッチサイズを使用)", - ) - parser.add_argument( - "--disable_mmap_load_safetensors", - action="store_true", - help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる", - ) - # copy from Diffusers - parser.add_argument( - "--weighting_scheme", - type=str, - default="none", - choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], - ) - parser.add_argument( - "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." - ) - parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.") - parser.add_argument( - "--mode_scale", - type=float, - default=1.29, - help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", - ) parser.add_argument( "--guidance_scale", type=float, diff --git a/library/sd3_models.py b/library/sd3_models.py index 89225fe4d..8b90205db 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -18,6 +18,7 @@ from torch.utils.checkpoint import checkpoint from transformers import CLIPTokenizer, T5TokenizerFast +from library import custom_offloading_utils from library.device_utils import clean_memory_on_device from .utils import setup_logging @@ -862,7 +863,8 @@ def __init__( # self.initialize_weights() self.blocks_to_swap = None - self.thread_pool: Optional[ThreadPoolExecutor] = None + self.offloader = None + self.num_blocks = len(self.joint_blocks) def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Optional[list[int]]): self.use_scaled_pos_embed = use_scaled_pos_embed @@ -1055,14 +1057,20 @@ def cropped_scaled_pos_embed(self, h, w, device=None, dtype=None, random_crop: b # ) return spatial_pos_embed - def enable_block_swap(self, num_blocks: int): + def enable_block_swap(self, num_blocks: int, device: torch.device): self.blocks_to_swap = num_blocks - n = 1 # async block swap. 1 is enough - self.thread_pool = ThreadPoolExecutor(max_workers=n) + assert ( + self.blocks_to_swap <= self.num_blocks - 2 + ), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks." + + self.offloader = custom_offloading_utils.ModelOffloader( + self.joint_blocks, self.num_blocks, self.blocks_to_swap, device # , debug=True + ) + print(f"SD3: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.") def move_to_device_except_swap_blocks(self, device: torch.device): - # assume model is on cpu + # assume model is on cpu. do not move blocks to device to reduce temporary memory usage if self.blocks_to_swap: save_blocks = self.joint_blocks self.joint_blocks = None @@ -1073,16 +1081,9 @@ def move_to_device_except_swap_blocks(self, device: torch.device): self.joint_blocks = save_blocks def prepare_block_swap_before_forward(self): - # make: first n blocks are on cuda, and last n blocks are on cpu if self.blocks_to_swap is None or self.blocks_to_swap == 0: - # raise ValueError("Block swap is not enabled.") return - num_blocks = len(self.joint_blocks) - for i in range(num_blocks - self.blocks_to_swap): - self.joint_blocks[i].to(self.device) - for i in range(num_blocks - self.blocks_to_swap, num_blocks): - self.joint_blocks[i].to("cpu") - clean_memory_on_device(self.device) + self.offloader.prepare_block_devices_before_forward(self.joint_blocks) def forward( self, @@ -1122,57 +1123,19 @@ def forward( if self.register_length > 0: context = torch.cat( - ( - einops.repeat(self.register, "1 ... -> b ...", b=x.shape[0]), - default(context, torch.Tensor([]).type_as(x)), - ), - 1, + (einops.repeat(self.register, "1 ... -> b ...", b=x.shape[0]), default(context, torch.Tensor([]).type_as(x))), 1 ) if not self.blocks_to_swap: for block in self.joint_blocks: context, x = block(context, x, c) else: - futures = {} - - def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda): - def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): - # print(f"Moving {bidx_to_cpu} to cpu.") - block_to_cpu.to("cpu", non_blocking=True) - torch.cuda.empty_cache() - - # print(f"Moving {bidx_to_cuda} to cuda.") - block_to_cuda.to(self.device, non_blocking=True) - - torch.cuda.synchronize() - # print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.") - return block_idx_to_cpu, block_idx_to_cuda - - block_to_cpu = self.joint_blocks[block_idx_to_cpu] - block_to_cuda = self.joint_blocks[block_idx_to_cuda] - # print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.") - return self.thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda) - - def wait_for_blocks_move(block_idx, ftrs): - if block_idx not in ftrs: - return - # print(f"Waiting for move blocks: {block_idx}") - # start_time = time.perf_counter() - ftr = ftrs.pop(block_idx) - ftr.result() - # torch.cuda.synchronize() - # print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds") - for block_idx, block in enumerate(self.joint_blocks): - wait_for_blocks_move(block_idx, futures) + self.offloader.wait_for_block(block_idx) context, x = block(context, x, c) - if block_idx < self.blocks_to_swap: - block_idx_to_cpu = block_idx - block_idx_to_cuda = len(self.joint_blocks) - self.blocks_to_swap + block_idx - future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda) - futures[block_idx_to_cuda] = future + self.offloader.submit_move_blocks(self.joint_blocks, block_idx) x = self.final_layer(x, c, H, W) # Our final layer combined UnPatchify return x[:, :, :H, :W] diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 38f3c25f4..c40798846 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -142,27 +142,6 @@ def sd_saver(ckpt_file, epoch_no, global_step): def add_sd3_training_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" - ) - parser.add_argument( - "--cache_text_encoder_outputs_to_disk", - action="store_true", - help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", - ) - parser.add_argument( - "--text_encoder_batch_size", - type=int, - default=None, - help="text encoder batch size (default: None, use dataset's batch size)" - + " / text encoderのバッチサイズ(デフォルト: None, データセットのバッチサイズを使用)", - ) - parser.add_argument( - "--disable_mmap_load_safetensors", - action="store_true", - help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる", - ) - parser.add_argument( "--clip_l", type=str, @@ -253,32 +232,8 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): " / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります", ) - # Dependencies of Diffusers noise sampler has been removed for clarity. - parser.add_argument( - "--weighting_scheme", - type=str, - default="uniform", - choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "uniform"], - help="weighting scheme for timestep distribution and loss / タイムステップ分布と損失のための重み付けスキーム", - ) - parser.add_argument( - "--logit_mean", - type=float, - default=0.0, - help="mean to use when using the `'logit_normal'` weighting scheme for timestep distribution. / タイムステップ分布のために`'logit_normal'`重み付けスキームを使用する場合の平均", - ) - parser.add_argument( - "--logit_std", - type=float, - default=1.0, - help="std to use when using the `'logit_normal'` weighting scheme for timestep distribution. / タイムステップ分布のために`'logit_normal'`重み付けスキームを使用する場合のstd", - ) - parser.add_argument( - "--mode_scale", - type=float, - default=1.29, - help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`. / モード重み付けスキームのスケール。`'mode'`を`weighting_scheme`として使用する場合のみ有効", - ) + # Dependencies of Diffusers noise sampler has been removed for clarity in training + parser.add_argument( "--training_shift", type=float, diff --git a/library/train_util.py b/library/train_util.py index a5d6fdd21..e1dfeecdb 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1887,7 +1887,9 @@ def load_dreambooth_dir(subset: DreamBoothSubset): # make image path to npz path mapping npz_paths = glob.glob(os.path.join(subset.image_dir, "*" + strategy.cache_suffix)) - npz_paths.sort(key=lambda item: item.rsplit("_", maxsplit=2)[0]) # sort by name excluding resolution and cache_suffix + npz_paths.sort( + key=lambda item: item.rsplit("_", maxsplit=2)[0] + ) # sort by name excluding resolution and cache_suffix npz_path_index = 0 size_set_count = 0 @@ -3537,8 +3539,8 @@ def int_or_float(value): parser.add_argument( "--fused_backward_pass", action="store_true", - help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL" - + " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効", + help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL, SD3 and FLUX" + " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXL、SD3、FLUXでのみ利用可能", ) parser.add_argument( "--lr_scheduler_timescale", @@ -4027,6 +4029,72 @@ def add_masked_loss_arguments(parser: argparse.ArgumentParser): ) +def add_dit_training_arguments(parser: argparse.ArgumentParser): + # Text encoder related arguments + parser.add_argument( + "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" + ) + parser.add_argument( + "--cache_text_encoder_outputs_to_disk", + action="store_true", + help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", + ) + parser.add_argument( + "--text_encoder_batch_size", + type=int, + default=None, + help="text encoder batch size (default: None, use dataset's batch size)" + + " / text encoderのバッチサイズ(デフォルト: None, データセットのバッチサイズを使用)", + ) + + # Model loading optimization + parser.add_argument( + "--disable_mmap_load_safetensors", + action="store_true", + help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる", + ) + + # Training arguments. partial copy from Diffusers + parser.add_argument( + "--weighting_scheme", + type=str, + default="uniform", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none", "uniform"], + help="weighting scheme for timestep distribution. Default is uniform, uniform and none are the same behavior" + " / タイムステップ分布の重み付けスキーム、デフォルトはuniform、uniform と none は同じ挙動", + ) + parser.add_argument( + "--logit_mean", + type=float, + default=0.0, + help="mean to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合の平均", + ) + parser.add_argument( + "--logit_std", + type=float, + default=1.0, + help="std to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合のstd", + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme` / モード重み付けスキームのスケール", + ) + + # offloading + parser.add_argument( + "--blocks_to_swap", + type=int, + default=None, + help="[EXPERIMENTAL] " + "Sets the number of blocks to swap during the forward and backward passes." + "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." + " / 順伝播および逆伝播中にスワップするブロックの数を設定します。" + "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", + ) + + def get_sanitized_config_or_none(args: argparse.Namespace): # if `--log_config` is enabled, return args for logging. if not, return None. # when `--log_config is enabled, filter out sensitive values from args diff --git a/sd3_train.py b/sd3_train.py index 24ecbfb7d..a4fc2eec8 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -201,21 +201,6 @@ def train(args): # モデルを読み込む # t5xxl_dtype = weight_dtype - # if args.t5xxl_dtype is not None: - # if args.t5xxl_dtype == "fp16": - # t5xxl_dtype = torch.float16 - # elif args.t5xxl_dtype == "bf16": - # t5xxl_dtype = torch.bfloat16 - # elif args.t5xxl_dtype == "fp32" or args.t5xxl_dtype == "float": - # t5xxl_dtype = torch.float32 - # else: - # raise ValueError(f"unexpected t5xxl_dtype: {args.t5xxl_dtype}") - # t5xxl_device = accelerator.device if args.t5xxl_device is None else args.t5xxl_device - # clip_dtype = weight_dtype # if not args.train_text_encoder else None - - # if clip_l is not specified, the checkpoint must contain clip_l, so we load state dict here - # if full_fp16/bf16, model_dtype is casted to fp16/bf16. If not, model_dtype is None (float32). - # by loading with model_dtype, we can reduce memory usage. model_dtype = match_mixed_precision(args, weight_dtype) # None (default) or fp16/bf16 (full_xxxx) if args.clip_l is None: sd3_state_dict = utils.load_safetensors( @@ -384,7 +369,7 @@ def train(args): # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. # This idea is based on 2kpr's great work. Thank you! logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") - mmdit.enable_block_swap(args.blocks_to_swap) + mmdit.enable_block_swap(args.blocks_to_swap, accelerator.device) if not cache_latents: # move to accelerator device @@ -611,108 +596,21 @@ def train(args): # resumeする train_util.resume_from_local_or_hf_if_specified(accelerator, args) - # memory efficient block swapping - - def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, blocks, device): - def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda, dvc): - # print(f"Backward: Move block {bidx_to_cpu} to CPU") - block_to_cpu = block_to_cpu.to("cpu", non_blocking=True) - torch.cuda.empty_cache() - - # print(f"Backward: Move block {bidx_to_cuda} to CUDA") - block_to_cuda = block_to_cuda.to(dvc, non_blocking=True) - torch.cuda.synchronize() - # print(f"Backward: Done moving blocks {bidx_to_cpu} and {bidx_to_cuda}") - return bidx_to_cpu, bidx_to_cuda - - block_to_cpu = blocks[block_idx_to_cpu] - block_to_cuda = blocks[block_idx_to_cuda] - - futures[block_idx_to_cuda] = thread_pool.submit( - move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda, device - ) - - def wait_blocks_move(block_idx, futures): - if block_idx not in futures: - return - future = futures.pop(block_idx) - future.result() - if args.fused_backward_pass: # use fused optimizer for backward pass: other optimizers will be supported in the future import library.adafactor_fused library.adafactor_fused.patch_adafactor_fused(optimizer) - blocks_to_swap = args.blocks_to_swap - num_blocks = len(accelerator.unwrap_model(mmdit).joint_blocks) - handled_block_indices = set() - - n = 1 # only asynchronous purpose, no need to increase this number - # n = 2 - # n = max(1, os.cpu_count() // 2) - thread_pool = ThreadPoolExecutor(max_workers=n) - futures = {} - for param_group, param_name_group in zip(optimizer.param_groups, param_names): for parameter, param_name in zip(param_group["params"], param_name_group): if parameter.requires_grad: - grad_hook = None - - if blocks_to_swap: - is_block = param_name.startswith("joint_blocks") - if is_block: - block_idx = int(param_name.split(".")[1]) - if block_idx not in handled_block_indices: - # swap following (already backpropagated) block - handled_block_indices.add(block_idx) - - # if n blocks were already backpropagated - num_blocks_propagated = num_blocks - block_idx - 1 - swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap - waiting = block_idx > 0 and block_idx <= blocks_to_swap - if swapping or waiting: - block_idx_to_cpu = num_blocks - num_blocks_propagated - block_idx_to_cuda = blocks_to_swap - num_blocks_propagated - block_idx_to_wait = block_idx - 1 - - # create swap hook - def create_swap_grad_hook( - bidx_to_cpu, bidx_to_cuda, bidx_to_wait, bidx: int, swpng: bool, wtng: bool - ): - def __grad_hook(tensor: torch.Tensor): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None - - if swpng: - submit_move_blocks( - futures, - thread_pool, - bidx_to_cpu, - bidx_to_cuda, - mmdit.joint_blocks, - accelerator.device, - ) - if wtng: - wait_blocks_move(bidx_to_wait, futures) - - return __grad_hook - - grad_hook = create_swap_grad_hook( - block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, block_idx, swapping, waiting - ) - - if grad_hook is None: - - def __grad_hook(tensor: torch.Tensor, param_group=param_group): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None - grad_hook = __grad_hook + def grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None parameter.register_post_accumulate_grad_hook(grad_hook) @@ -731,59 +629,22 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): num_parameters_per_group = [0] * len(optimizers) parameter_optimizer_map = {} - blocks_to_swap = args.blocks_to_swap - num_blocks = len(accelerator.unwrap_model(mmdit).joint_blocks) - - n = 1 # only asynchronous purpose, no need to increase this number - # n = max(1, os.cpu_count() // 2) - thread_pool = ThreadPoolExecutor(max_workers=n) - futures = {} - for opt_idx, optimizer in enumerate(optimizers): for param_group in optimizer.param_groups: for parameter in param_group["params"]: if parameter.requires_grad: - block_type, block_idx = block_types_and_indices[opt_idx] - - def create_optimizer_hook(btype, bidx): - def optimizer_hook(parameter: torch.Tensor): - # print(f"optimizer_hook: {btype}, {bidx}") - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(parameter, args.max_grad_norm) - - i = parameter_optimizer_map[parameter] - optimizer_hooked_count[i] += 1 - if optimizer_hooked_count[i] == num_parameters_per_group[i]: - optimizers[i].step() - optimizers[i].zero_grad(set_to_none=True) - - # swap blocks if necessary - if blocks_to_swap and btype == "joint": - num_blocks_propagated = num_blocks - bidx - - swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap - waiting = bidx > 0 and bidx <= blocks_to_swap - - if swapping: - block_idx_to_cpu = num_blocks - num_blocks_propagated - block_idx_to_cuda = blocks_to_swap - num_blocks_propagated - # print(f"Backward: Swap blocks {block_idx_to_cpu} and {block_idx_to_cuda}") - submit_move_blocks( - futures, - thread_pool, - block_idx_to_cpu, - block_idx_to_cuda, - mmdit.joint_blocks, - accelerator.device, - ) - - if waiting: - block_idx_to_wait = bidx - 1 - wait_blocks_move(block_idx_to_wait, futures) - - return optimizer_hook - - parameter.register_post_accumulate_grad_hook(create_optimizer_hook(block_type, block_idx)) + + def grad_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(grad_hook) parameter_optimizer_map[parameter] = opt_idx num_parameters_per_group[opt_idx] += 1 @@ -1130,6 +991,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) add_custom_train_arguments(parser) + train_util.add_dit_training_arguments(parser) sd3_train_utils.add_sd3_training_arguments(parser) parser.add_argument( @@ -1190,16 +1052,6 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", ) - parser.add_argument( - "--blocks_to_swap", - type=int, - default=None, - help="[EXPERIMENTAL] " - "Sets the number of blocks (~640MB) to swap during the forward and backward passes." - "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." - " / 順伝播および逆伝播中にスワップするブロック(約640MB)の数を設定します。" - "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", - ) parser.add_argument( "--num_last_block_to_freeze", type=int, diff --git a/sd3_train_network.py b/sd3_train_network.py index bb02c7ac7..1726e325f 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -51,6 +51,10 @@ def assert_extra_args(self, args, train_dataset_group: train_util.DatasetGroup): if args.max_token_length is not None: logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + train_dataset_group.verify_bucket_reso_steps(32) # TODO check this # enumerate resolutions from dataset for positional embeddings @@ -83,6 +87,17 @@ def load_target_model(self, args, weight_dtype, accelerator): raise ValueError(f"Unsupported fp8 model dtype: {mmdit.dtype}") elif mmdit.dtype == torch.float8_e4m3fn: logger.info("Loaded fp8 SD3 model") + else: + logger.info( + "Cast SD3 model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint." + " / SD3モデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。" + ) + mmdit.to(torch.float8_e4m3fn) + self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if self.is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + mmdit.enable_block_swap(args.blocks_to_swap, accelerator.device) clip_l = sd3_utils.load_clip_l( args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict @@ -432,9 +447,24 @@ def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) batch["text_encoder_outputs_list"] = text_encoder_outputs_list + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + if not self.is_swapping_blocks: + return super().prepare_unet_with_accelerator(args, accelerator, unet) + + # if we doesn't swap blocks, we can move the model to device + mmdit: sd3_models.MMDiT = unet + mmdit = accelerator.prepare(mmdit, device_placement=[not self.is_swapping_blocks]) + accelerator.unwrap_model(mmdit).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + accelerator.unwrap_model(mmdit).prepare_block_swap_before_forward() + + return mmdit + def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() + train_util.add_dit_training_arguments(parser) sd3_train_utils.add_sd3_training_arguments(parser) return parser diff --git a/tools/cache_latents.py b/tools/cache_latents.py index e2faa58a7..c034f949a 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -164,6 +164,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_dataset_arguments(parser, True, True, True) train_util.add_masked_loss_arguments(parser) config_util.add_config_arguments(parser) + train_util.add_dit_training_arguments(parser) flux_train_utils.add_flux_train_arguments(parser) parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py index 7be9ad781..5888b8e3d 100644 --- a/tools/cache_text_encoder_outputs.py +++ b/tools/cache_text_encoder_outputs.py @@ -191,6 +191,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_dataset_arguments(parser, True, True, True) train_util.add_masked_loss_arguments(parser) config_util.add_config_arguments(parser) + train_util.add_dit_training_arguments(parser) flux_train_utils.add_flux_train_arguments(parser) parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") diff --git a/train_network.py b/train_network.py index d70f14ad3..bbf381f99 100644 --- a/train_network.py +++ b/train_network.py @@ -601,8 +601,10 @@ def train(self, args): # unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM # unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory - logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}") - unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above + # logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}") + # unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above + logger.info(f"set U-Net weight dtype to {unet_weight_dtype}") + unet.to(dtype=unet_weight_dtype) # do not move to device because unet is not prepared by accelerator unet.requires_grad_(False) unet.to(dtype=unet_weight_dtype) From 2bb0f547d72cd0256cafebd46d0f61fbe54012ac Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 14 Nov 2024 19:33:12 +0900 Subject: [PATCH 234/348] update grad hook creation to fix TE lr in sd3 fine tuning --- flux_train.py | 19 ++++++++++++------- library/train_util.py | 1 + sd3_train.py | 15 +++++++++------ 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/flux_train.py b/flux_train.py index ad2c7722b..a89e2f139 100644 --- a/flux_train.py +++ b/flux_train.py @@ -80,7 +80,9 @@ def train(args): assert ( args.blocks_to_swap is None or args.blocks_to_swap == 0 - ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + ) or not args.cpu_offload_checkpointing, ( + "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + ) cache_latents = args.cache_latents use_dreambooth_method = args.in_json is None @@ -480,13 +482,16 @@ def train(args): for parameter, param_name in zip(param_group["params"], param_name_group): if parameter.requires_grad: - def grad_hook(tensor: torch.Tensor, param_group=param_group): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None + def create_grad_hook(p_name, p_group): + def grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, p_group) + tensor.grad = None + + return grad_hook - parameter.register_post_accumulate_grad_hook(grad_hook) + parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group)) elif args.blockwise_fused_optimizers: # prepare for additional optimizers and lr schedulers diff --git a/library/train_util.py b/library/train_util.py index e1dfeecdb..25cf7640d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5913,6 +5913,7 @@ def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True): names.append("unet") names.append("text_encoder1") names.append("text_encoder2") + names.append("text_encoder3") # SD3 append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names) diff --git a/sd3_train.py b/sd3_train.py index a4fc2eec8..96ec951b9 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -606,13 +606,16 @@ def train(args): for parameter, param_name in zip(param_group["params"], param_name_group): if parameter.requires_grad: - def grad_hook(tensor: torch.Tensor, param_group=param_group): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None + def create_grad_hook(p_name, p_group): + def grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, p_group) + tensor.grad = None + + return grad_hook - parameter.register_post_accumulate_grad_hook(grad_hook) + parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group)) elif args.blockwise_fused_optimizers: # prepare for additional optimizers and lr schedulers From 5c5b544b91ac434c12a372cbf1dc123a367ec878 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 14 Nov 2024 19:35:43 +0900 Subject: [PATCH 235/348] refactor: remove unused prepare_split_model method from FluxNetworkTrainer --- flux_train_network.py | 39 --------------------------------------- 1 file changed, 39 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index 9bcd59282..704c4d32e 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -127,45 +127,6 @@ def load_target_model(self, args, weight_dtype, accelerator): return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model - """ - def prepare_split_model(self, model, weight_dtype, accelerator): - from accelerate import init_empty_weights - - logger.info("prepare split model") - with init_empty_weights(): - flux_upper = flux_models.FluxUpper(model.params) - flux_lower = flux_models.FluxLower(model.params) - sd = model.state_dict() - - # lower (trainable) - logger.info("load state dict for lower") - flux_lower.load_state_dict(sd, strict=False, assign=True) - flux_lower.to(dtype=weight_dtype) - - # upper (frozen) - logger.info("load state dict for upper") - flux_upper.load_state_dict(sd, strict=False, assign=True) - - logger.info("prepare upper model") - target_dtype = torch.float8_e4m3fn if args.fp8_base else weight_dtype - flux_upper.to(accelerator.device, dtype=target_dtype) - flux_upper.eval() - - if args.fp8_base: - # this is required to run on fp8 - flux_upper = accelerator.prepare(flux_upper) - - flux_upper.to("cpu") - - self.flux_upper = flux_upper - del model # we don't need model anymore - clean_memory_on_device(accelerator.device) - - logger.info("split model prepared") - - return flux_lower - """ - def get_tokenize_strategy(self, args): _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) From fd2d879ac883b8bdf1e03b6ca545c33200dbdff2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 14 Nov 2024 19:43:08 +0900 Subject: [PATCH 236/348] docs: update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1e63b5830..81a3199bc 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ The command to install PyTorch is as follows: ### Recent Updates -Nov 12, 2024: +Nov 14, 2024: - Improved the implementation of block swap and made it available for both FLUX.1 and SD3 LoRA training. See [FLUX.1 LoRA training](#flux1-lora-training) etc. for how to use the new options. Training is possible with about 8-10GB of VRAM. - During fine-tuning, the memory usage when specifying the same number of blocks has increased slightly, but the training speed when specifying block swap has been significantly improved. From ccfaa001e74f80798e528b4b3ea6ef811017c07b Mon Sep 17 00:00:00 2001 From: minux302 Date: Fri, 15 Nov 2024 20:21:28 +0900 Subject: [PATCH 237/348] add flux controlnet base module --- flux_train_control_net.py | 573 ++++++++++++++++++++++++++++++++++++++ flux_train_network.py | 5 +- library/flux_models.py | 257 ++++++++++++++++- library/flux_utils.py | 8 + 4 files changed, 841 insertions(+), 2 deletions(-) create mode 100644 flux_train_control_net.py diff --git a/flux_train_control_net.py b/flux_train_control_net.py new file mode 100644 index 000000000..704c4d32e --- /dev/null +++ b/flux_train_control_net.py @@ -0,0 +1,573 @@ +import argparse +import copy +import math +import random +from typing import Any, Optional + +import torch +from accelerate import Accelerator +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util +import train_network +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class FluxNetworkTrainer(train_network.NetworkTrainer): + def __init__(self): + super().__init__() + self.sample_prompts_te_outputs = None + self.is_schnell: Optional[bool] = None + self.is_swapping_blocks: bool = False + + def assert_extra_args(self, args, train_dataset_group): + super().assert_extra_args(args, train_dataset_group) + # sdxl_train_util.verify_sdxl_training_args(args) + + if args.fp8_base_unet: + args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1 + + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # prepare CLIP-L/T5XXL training flags + self.train_clip_l = not args.network_train_unet_only + self.train_t5xxl = False # default is False even if args.network_train_unet_only is False + + if args.max_token_length is not None: + logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") + + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + + # deprecated split_mode option + if args.split_mode: + if args.blocks_to_swap is not None: + logger.warning( + "split_mode is deprecated. Because `--blocks_to_swap` is set, `--split_mode` is ignored." + " / split_modeは非推奨です。`--blocks_to_swap`が設定されているため、`--split_mode`は無視されます。" + ) + else: + logger.warning( + "split_mode is deprecated. Please use `--blocks_to_swap` instead. `--blocks_to_swap 18` is automatically set." + " / split_modeは非推奨です。代わりに`--blocks_to_swap`を使用してください。`--blocks_to_swap 18`が自動的に設定されました。" + ) + args.blocks_to_swap = 18 # 18 is safe for most cases + + train_dataset_group.verify_bucket_reso_steps(32) # TODO check this + + def load_target_model(self, args, weight_dtype, accelerator): + # currently offload to cpu for some models + + # if the file is fp8 and we are using fp8_base, we can load it as is (fp8) + loading_dtype = None if args.fp8_base else weight_dtype + + # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future + self.is_schnell, model = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors + ) + if args.fp8_base: + # check dtype of model + if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") + elif model.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 FLUX model") + else: + logger.info( + "Cast FLUX model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint." + " / FLUXモデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。" + ) + model.to(torch.float8_e4m3fn) + + # if args.split_mode: + # model = self.prepare_split_model(model, weight_dtype, accelerator) + + self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if self.is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + model.enable_block_swap(args.blocks_to_swap, accelerator.device) + + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + clip_l.eval() + + # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8) + if args.fp8_base and not args.fp8_base_unet: + loading_dtype = None # as is + else: + loading_dtype = weight_dtype + + # loading t5xxl to cpu takes a long time, so we should load to gpu in future + t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + t5xxl.eval() + if args.fp8_base and not args.fp8_base_unet: + # check dtype of model + if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") + elif t5xxl.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 T5XXL model") + + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + + return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model + + def get_tokenize_strategy(self, args): + _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) + + if args.t5xxl_max_token_length is None: + if is_schnell: + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}") + return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy): + return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask) + + def post_process_network(self, args, accelerator, network, text_encoders, unet): + # check t5xxl is trained or not + self.train_t5xxl = network.train_t5xxl + + if self.train_t5xxl and args.cache_text_encoder_outputs: + raise ValueError( + "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません" + ) + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + if args.cache_text_encoder_outputs: + if self.train_clip_l and not self.train_t5xxl: + return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached + else: + return None # no text encoders are needed for encoding because both are cached + else: + return text_encoders # both CLIP-L and T5XXL are needed for encoding + + def get_text_encoders_train_flags(self, args, text_encoders): + return [self.train_clip_l, self.train_t5xxl] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + # if the text encoders is trained, we need tokenization, so is_partial is True + return strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + is_partial=self.train_clip_l or self.train_t5xxl, + apply_t5_attn_mask=args.apply_t5_attn_mask, + ) + else: + return None + + def cache_text_encoder_outputs_if_needed( + self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype + ): + if args.cache_text_encoder_outputs: + if not args.lowram: + # メモリ消費を減らす + logger.info("move vae and unet to cpu to save memory") + org_vae_device = vae.device + org_unet_device = unet.device + vae.to("cpu") + unet.to("cpu") + clean_memory_on_device(accelerator.device) + + # When TE is not be trained, it will not be prepared so we need to use explicit autocast + logger.info("move text encoders to gpu") + text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 + text_encoders[1].to(accelerator.device) + + if text_encoders[1].dtype == torch.float8_e4m3fn: + # if we load fp8 weights, the model is already fp8, so we use it as is + self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype) + else: + # otherwise, we need to convert it to target dtype + text_encoders[1].to(weight_dtype) + + with accelerator.autocast(): + dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) + + # cache sample prompts + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask + ) + self.sample_prompts_te_outputs = sample_prompts_te_outputs + + accelerator.wait_for_everyone() + + # move back to cpu + if not self.is_train_text_encoder(args): + logger.info("move CLIP-L back to cpu") + text_encoders[0].to("cpu") + logger.info("move t5XXL back to cpu") + text_encoders[1].to("cpu") + clean_memory_on_device(accelerator.device) + + if not args.lowram: + logger.info("move vae and unet back to original device") + vae.to(org_vae_device) + unet.to(org_unet_device) + else: + # Text Encoderから毎回出力を取得するので、GPUに乗せておく + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device) + + # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + + # # get size embeddings + # orig_size = batch["original_sizes_hw"] + # crop_size = batch["crop_top_lefts"] + # target_size = batch["target_sizes_hw"] + # embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # # concat embeddings + # encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds + # vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + # text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) + # return noise_pred + + def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux): + text_encoders = text_encoder # for compatibility + text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) + + flux_train_utils.sample_images( + accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs + ) + # return + + """ + class FluxUpperLowerWrapper(torch.nn.Module): + def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device): + super().__init__() + self.flux_upper = flux_upper + self.flux_lower = flux_lower + self.target_device = device + + def prepare_block_swap_before_forward(self): + pass + + def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None): + self.flux_lower.to("cpu") + clean_memory_on_device(self.target_device) + self.flux_upper.to(self.target_device) + img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask) + self.flux_upper.to("cpu") + clean_memory_on_device(self.target_device) + self.flux_lower.to(self.target_device) + return self.flux_lower(img, txt, vec, pe, txt_attention_mask) + + wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device) + clean_memory_on_device(accelerator.device) + flux_train_utils.sample_images( + accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs + ) + clean_memory_on_device(accelerator.device) + """ + + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) + self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) + return noise_scheduler + + def encode_images_to_latents(self, args, accelerator, vae, images): + return vae.encode(images) + + def shift_scale_latents(self, args, latents): + return latents + + def get_noise_pred_and_target( + self, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet: flux_models.Flux, + network, + weight_dtype, + train_unet, + ): + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, accelerator.device, weight_dtype + ) + + # pack latents and get img_ids + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + + # get guidance + # ensure guidance_scale in args is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + noisy_model_input.requires_grad_(True) + for t in text_encoder_conds: + if t is not None and t.dtype.is_floating_point: + t.requires_grad_(True) + img_ids.requires_grad_(True) + guidance_vec.requires_grad_(True) + + # Predict the noise residual + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None + + def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): + # if not args.split_mode: + # normal forward + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = unet( + img=img, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + """ + else: + # split forward to reduce memory usage + assert network.train_blocks == "single", "train_blocks must be single for split mode" + with accelerator.autocast(): + # move flux lower to cpu, and then move flux upper to gpu + unet.to("cpu") + clean_memory_on_device(accelerator.device) + self.flux_upper.to(accelerator.device) + + # upper model does not require grad + with torch.no_grad(): + intermediate_img, intermediate_txt, vec, pe = self.flux_upper( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + + # move flux upper back to cpu, and then move flux lower to gpu + self.flux_upper.to("cpu") + clean_memory_on_device(accelerator.device) + unet.to(accelerator.device) + + # lower model requires grad + intermediate_img.requires_grad_(True) + intermediate_txt.requires_grad_(True) + vec.requires_grad_(True) + pe.requires_grad_(True) + model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) + """ + + return model_pred + + model_pred = call_dit( + img=packed_noisy_model_input, + img_ids=img_ids, + t5_out=t5_out, + txt_ids=txt_ids, + l_pooled=l_pooled, + timesteps=timesteps, + guidance_vec=guidance_vec, + t5_attn_mask=t5_attn_mask, + ) + + # unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + + # flow matching loss: this is different from SD3 + target = noise - latents + + # differential output preservation + if "custom_attributes" in batch: + diff_output_pr_indices = [] + for i, custom_attributes in enumerate(batch["custom_attributes"]): + if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: + diff_output_pr_indices.append(i) + + if len(diff_output_pr_indices) > 0: + network.set_multiplier(0.0) + with torch.no_grad(): + model_pred_prior = call_dit( + img=packed_noisy_model_input[diff_output_pr_indices], + img_ids=img_ids[diff_output_pr_indices], + t5_out=t5_out[diff_output_pr_indices], + txt_ids=txt_ids[diff_output_pr_indices], + l_pooled=l_pooled[diff_output_pr_indices], + timesteps=timesteps[diff_output_pr_indices], + guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None, + t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None, + ) + network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step + + model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width) + model_pred_prior, _ = flux_train_utils.apply_model_prediction_type( + args, + model_pred_prior, + noisy_model_input[diff_output_pr_indices], + sigmas[diff_output_pr_indices] if sigmas is not None else None, + ) + target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) + + return model_pred, target, timesteps, None, weighting + + def post_process_loss(self, loss, args, timesteps, noise_scheduler): + return loss + + def get_sai_model_spec(self, args): + return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev") + + def update_metadata(self, metadata, args): + metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask + metadata["ss_weighting_scheme"] = args.weighting_scheme + metadata["ss_logit_mean"] = args.logit_mean + metadata["ss_logit_std"] = args.logit_std + metadata["ss_mode_scale"] = args.mode_scale + metadata["ss_guidance_scale"] = args.guidance_scale + metadata["ss_timestep_sampling"] = args.timestep_sampling + metadata["ss_sigmoid_scale"] = args.sigmoid_scale + metadata["ss_model_prediction_type"] = args.model_prediction_type + metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift + + def is_text_encoder_not_needed_for_training(self, args): + return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) + + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + if index == 0: # CLIP-L + return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder) + else: # T5XXL + text_encoder.encoder.embed_tokens.requires_grad_(True) + + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + if index == 0: # CLIP-L + logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}") + text_encoder.to(te_weight_dtype) # fp8 + text_encoder.text_model.embeddings.to(dtype=weight_dtype) + else: # T5XXL + + def prepare_fp8(text_encoder, target_dtype): + def forward_hook(module): + def forward(hidden_states): + hidden_gelu = module.act(module.wi_0(hidden_states)) + hidden_linear = module.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = module.dropout(hidden_states) + + hidden_states = module.wo(hidden_states) + return hidden_states + + return forward + + for module in text_encoder.modules(): + if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["T5DenseGatedActDense"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype: + logger.info(f"T5XXL already prepared for fp8") + else: + logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks") + text_encoder.to(te_weight_dtype) # fp8 + prepare_fp8(text_encoder, weight_dtype) + + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + if not self.is_swapping_blocks: + return super().prepare_unet_with_accelerator(args, accelerator, unet) + + # if we doesn't swap blocks, we can move the model to device + flux: flux_models.Flux = unet + flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks]) + accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + accelerator.unwrap_model(flux).prepare_block_swap_before_forward() + + return flux + + +def setup_parser() -> argparse.ArgumentParser: + parser = train_network.setup_parser() + train_util.add_dit_training_arguments(parser) + flux_train_utils.add_flux_train_arguments(parser) + + parser.add_argument( + "--split_mode", + action="store_true", + # help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required" + # + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要", + help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead." + " / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + trainer = FluxNetworkTrainer() + trainer.train(args) diff --git a/flux_train_network.py b/flux_train_network.py index 704c4d32e..0feb9b011 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -125,7 +125,10 @@ def load_target_model(self, args, weight_dtype, accelerator): ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) - return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model + controlnet = flux_utils.load_controlnet() + controlnet.train() + + return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model, controlnet def get_tokenize_strategy(self, args): _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) diff --git a/library/flux_models.py b/library/flux_models.py index fa3c7ad2b..a3bd19743 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1013,6 +1013,8 @@ def forward( txt_ids: Tensor, timesteps: Tensor, y: Tensor, + block_controlnet_hidden_states=None, + block_controlnet_single_hidden_states=None, guidance: Tensor | None = None, txt_attention_mask: Tensor | None = None, ) -> Tensor: @@ -1031,18 +1033,29 @@ def forward( ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) + if block_controlnet_hidden_states is not None: + controlnet_depth = len(block_controlnet_hidden_states) + if block_controlnet_single_hidden_states is not None: + controlnet_single_depth = len(block_controlnet_single_hidden_states) if not self.blocks_to_swap: - for block in self.double_blocks: + for block_idx, block in enumerate(self.double_blocks): img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + if block_controlnet_hidden_states is not None: + img = img + block_controlnet_hidden_states[block_idx % controlnet_depth] + img = torch.cat((txt, img), 1) for block in self.single_blocks: img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + if block_controlnet_single_hidden_states is not None: + img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth] else: for block_idx, block in enumerate(self.double_blocks): self.offloader_double.wait_for_block(block_idx) img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + if block_controlnet_hidden_states is not None: + img = img + block_controlnet_hidden_states[block_idx % controlnet_depth] self.offloader_double.submit_move_blocks(self.double_blocks, block_idx) @@ -1052,6 +1065,8 @@ def forward( self.offloader_single.wait_for_block(block_idx) img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + if block_controlnet_single_hidden_states is not None: + img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth] self.offloader_single.submit_move_blocks(self.single_blocks, block_idx) @@ -1066,6 +1081,246 @@ def forward( return img +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +class ControlNetFlux(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, params: FluxParams, controlnet_depth=2): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) + for _ in range(0) # TMP + ] + ) + + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + self.blocks_to_swap = None + + self.offloader_double = None + self.offloader_single = None + self.num_double_blocks = len(self.double_blocks) + self.num_single_blocks = len(self.single_blocks) + + # add ControlNet blocks + self.controlnet_blocks_for_double = nn.ModuleList([]) + for _ in range(controlnet_depth): + controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) + controlnet_block = zero_module(controlnet_block) + self.controlnet_blocks_for_double.append(controlnet_block) + self.controlnet_blocks_for_single = nn.ModuleList([]) + for _ in range(controlnet_depth): + controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) + controlnet_block = zero_module(controlnet_block) + self.controlnet_blocks_for_single.append(controlnet_block) + self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.gradient_checkpointing = False + self.input_hint_block = nn.Sequential( + nn.Conv2d(3, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + zero_module(nn.Conv2d(16, 16, 3, padding=1)) + ) + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload + + self.time_in.enable_gradient_checkpointing() + self.vector_in.enable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.enable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.enable_gradient_checkpointing(cpu_offload=cpu_offload) + + print(f"FLUX: Gradient checkpointing enabled. CPU offload: {cpu_offload}") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + self.time_in.disable_gradient_checkpointing() + self.vector_in.disable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.disable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.disable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing disabled.") + + def enable_block_swap(self, num_blocks: int, device: torch.device): + self.blocks_to_swap = num_blocks + double_blocks_to_swap = num_blocks // 2 + single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + + assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, ( + f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. " + f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks." + ) + + self.offloader_double = custom_offloading_utils.ModelOffloader( + self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True + ) + self.offloader_single = custom_offloading_utils.ModelOffloader( + self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True + ) + print( + f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." + ) + + def move_to_device_except_swap_blocks(self, device: torch.device): + # assume model is on cpu. do not move blocks to device to reduce temporary memory usage + if self.blocks_to_swap: + save_double_blocks = self.double_blocks + save_single_blocks = self.single_blocks + self.double_blocks = None + self.single_blocks = None + + self.to(device) + + if self.blocks_to_swap: + self.double_blocks = save_double_blocks + self.single_blocks = save_single_blocks + + def prepare_block_swap_before_forward(self): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + self.offloader_double.prepare_block_devices_before_forward(self.double_blocks) + self.offloader_single.prepare_block_devices_before_forward(self.single_blocks) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + controlnet_cond: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor | None = None, + txt_attention_mask: Tensor | None = None, + ) -> tuple[tuple[Tensor]]: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + controlnet_cond = self.input_hint_block(controlnet_cond) + controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + controlnet_cond = self.pos_embed_input(controlnet_cond) + img = img + controlnet_cond + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + block_samples = () + block_single_samples = () + if not self.blocks_to_swap: + for block_idx, block in enumerate(self.double_blocks): + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + block_samples = block_samples + (img,) + + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + block_single_samples = block_single_samples + (img,) + else: + for block_idx, block in enumerate(self.double_blocks): + self.offloader_double.wait_for_block(block_idx) + + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + block_samples = block_samples + (img,) + + self.offloader_double.submit_move_blocks(self.double_blocks, block_idx) + + img = torch.cat((txt, img), 1) + + for block_idx, block in enumerate(self.single_blocks): + self.offloader_single.wait_for_block(block_idx) + + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + block_single_samples = block_single_samples + (img,) + + self.offloader_single.submit_move_blocks(self.single_blocks, block_idx) + + controlnet_block_samples = () + controlnet_single_block_samples = () + for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_double): + block_sample = controlnet_block(block_sample) + controlnet_block_samples = controlnet_block_samples + (block_sample,) + for block_sample, controlnet_block in zip(block_samples, self.controlnet_single_blocks_for_single): + block_sample = controlnet_block(block_sample) + controlnet_single_block_samples = controlnet_single_block_samples + (block_sample,) + + return controlnet_block_samples, controlnet_single_block_samples + + """ class FluxUpper(nn.Module): "" diff --git a/library/flux_utils.py b/library/flux_utils.py index f3093615d..678efbc8a 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -153,6 +153,14 @@ def load_ae( return ae +def load_controlnet(name, device, transformer=None): + with torch.device(device): + controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params) + if transformer is not None: + controlnet.load_state_dict(transformer.state_dict(), strict=False) + return controlnet + + def load_clip_l( ckpt_path: Optional[str], dtype: torch.dtype, From 42f6edf3a886287b99770bc7a8c0bafd3fa03f39 Mon Sep 17 00:00:00 2001 From: minux302 Date: Fri, 15 Nov 2024 23:48:51 +0900 Subject: [PATCH 238/348] fix for adding controlnet --- flux_train_control_net.py | 1270 +++++++++++++++++++++-------------- flux_train_network.py | 3 - library/flux_train_utils.py | 32 +- library/flux_utils.py | 11 +- 4 files changed, 820 insertions(+), 496 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 704c4d32e..8a7be75f2 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -1,563 +1,860 @@ +# training with captions + +# Swap blocks between CPU and GPU: +# This implementation is inspired by and based on the work of 2kpr. +# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading. +# The original idea has been adapted and extended to fit the current project's needs. + +# Key features: +# - CPU offloading during forward and backward passes +# - Use of fused optimizer and grad_hook for efficient gradient processing +# - Per-block fused optimizer instances + import argparse +from concurrent.futures import ThreadPoolExecutor import copy import math -import random -from typing import Any, Optional +import os +from multiprocessing import Value +import time +from typing import List, Optional, Tuple, Union +import toml + +from tqdm import tqdm import torch -from accelerate import Accelerator +import torch.nn as nn +from library import utils from library.device_utils import init_ipex, clean_memory_on_device init_ipex() -from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util -import train_network -from library.utils import setup_logging +from accelerate.utils import set_seed +from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux +from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler + +import library.train_util as train_util + +from library.utils import setup_logging, add_logging_arguments setup_logging() import logging logger = logging.getLogger(__name__) +import library.config_util as config_util + +# import library.sdxl_train_util as sdxl_train_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + # sdxl_train_util.verify_sdxl_training_args(args) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) + + # temporary: backward compatibility for deprecated options. remove in the future + if not args.skip_cache_check: + args.skip_cache_check = args.skip_latents_validity_check + + # assert ( + # not args.weighted_captions + # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True -class FluxNetworkTrainer(train_network.NetworkTrainer): - def __init__(self): - super().__init__() - self.sample_prompts_te_outputs = None - self.is_schnell: Optional[bool] = None - self.is_swapping_blocks: bool = False + if args.cpu_offload_checkpointing and not args.gradient_checkpointing: + logger.warning( + "cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります" + ) + args.gradient_checkpointing = True - def assert_extra_args(self, args, train_dataset_group): - super().assert_extra_args(args, train_dataset_group) - # sdxl_train_util.verify_sdxl_training_args(args) + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, ( + "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + ) - if args.fp8_base_unet: - args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1 + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None - if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: - logger.warning( - "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" - ) - args.cache_text_encoder_outputs = True + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if args.cache_latents: + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + logger.info("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + logger.info("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認 + + _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) + if args.debug_dataset: if args.cache_text_encoder_outputs: - assert ( - train_dataset_group.is_text_encoder_output_cacheable() - ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" - - # prepare CLIP-L/T5XXL training flags - self.train_clip_l = not args.network_train_unet_only - self.train_t5xxl = False # default is False even if args.network_train_unet_only is False + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( + strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False + ) + ) + t5xxl_max_token_length = ( + args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512) + ) + strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length)) + + train_dataset_group.set_current_strategies() + train_util.debug_dataset(train_dataset_group, True) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return - if args.max_token_length is not None: - logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + if args.cache_text_encoder_outputs: assert ( - args.blocks_to_swap is None or args.blocks_to_swap == 0 - ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" - # deprecated split_mode option - if args.split_mode: - if args.blocks_to_swap is not None: - logger.warning( - "split_mode is deprecated. Because `--blocks_to_swap` is set, `--split_mode` is ignored." - " / split_modeは非推奨です。`--blocks_to_swap`が設定されているため、`--split_mode`は無視されます。" - ) - else: - logger.warning( - "split_mode is deprecated. Please use `--blocks_to_swap` instead. `--blocks_to_swap 18` is automatically set." - " / split_modeは非推奨です。代わりに`--blocks_to_swap`を使用してください。`--blocks_to_swap 18`が自動的に設定されました。" - ) - args.blocks_to_swap = 18 # 18 is safe for most cases + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) - train_dataset_group.verify_bucket_reso_steps(32) # TODO check this + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) - def load_target_model(self, args, weight_dtype, accelerator): - # currently offload to cpu for some models + # モデルを読み込む - # if the file is fp8 and we are using fp8_base, we can load it as is (fp8) - loading_dtype = None if args.fp8_base else weight_dtype + # load VAE for caching latents + ae = None + if cache_latents: + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + ae.to(accelerator.device, dtype=weight_dtype) + ae.requires_grad_(False) + ae.eval() - # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future - self.is_schnell, model = flux_utils.load_flow_model( - args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors + train_dataset_group.new_cache_latents(ae, accelerator) + + ae.to("cpu") # if no sampling, vae can be deleted + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # prepare tokenize strategy + if args.t5xxl_max_token_length is None: + if is_schnell: + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + flux_tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length) + strategy_base.TokenizeStrategy.set_strategy(flux_tokenize_strategy) + + # load clip_l, t5xxl for caching text encoder outputs + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + clip_l.eval() + t5xxl.eval() + clip_l.requires_grad_(False) + t5xxl.requires_grad_(False) + + text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + # cache text encoder outputs + sample_prompts_te_outputs = None + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad here + clip_l.to(accelerator.device) + t5xxl.to(accelerator.device) + + text_encoder_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False, args.apply_t5_attn_mask ) - if args.fp8_base: - # check dtype of model - if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz: - raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") - elif model.dtype == torch.float8_e4m3fn: - logger.info("Loaded fp8 FLUX model") - else: - logger.info( - "Cast FLUX model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint." - " / FLUXモデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。" - ) - model.to(torch.float8_e4m3fn) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) + + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator) + + # cache sample prompt's embeddings to free text encoder's memory + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = flux_tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + flux_tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + + accelerator.wait_for_everyone() + + # now we can delete Text Encoders to free memory + clip_l = None + t5xxl = None + clean_memory_on_device(accelerator.device) - # if args.split_mode: - # model = self.prepare_split_model(model, weight_dtype, accelerator) + # load FLUX + _, flux = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors + ) + flux.requires_grad_(False) - self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 - if self.is_swapping_blocks: - # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. - logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") - model.enable_block_swap(args.blocks_to_swap, accelerator.device) + # load controlnet + controlnet = flux_utils.load_controlnet() + controlnet.requires_grad_(True) - clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) - clip_l.eval() + if args.gradient_checkpointing: + controlnet.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing) - # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8) - if args.fp8_base and not args.fp8_base_unet: - loading_dtype = None # as is - else: - loading_dtype = weight_dtype + # block swap - # loading t5xxl to cpu takes a long time, so we should load to gpu in future - t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) - t5xxl.eval() - if args.fp8_base and not args.fp8_base_unet: - # check dtype of model - if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz: - raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") - elif t5xxl.dtype == torch.float8_e4m3fn: - logger.info("Loaded fp8 T5XXL model") + # backward compatibility + if args.blocks_to_swap is None: + blocks_to_swap = args.double_blocks_to_swap or 0 + if args.single_blocks_to_swap is not None: + blocks_to_swap += args.single_blocks_to_swap // 2 + if blocks_to_swap > 0: + logger.warning( + "double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead." + " / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。" + ) + logger.info( + f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}." + ) + args.blocks_to_swap = blocks_to_swap + del blocks_to_swap + + is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + # This idea is based on 2kpr's great work. Thank you! + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + flux.enable_block_swap(args.blocks_to_swap, accelerator.device) + controlnet.enable_block_swap(args.blocks_to_swap, accelerator.device) + + if not cache_latents: + # load VAE here if not cached + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu") + ae.requires_grad_(False) + ae.eval() + ae.to(accelerator.device, dtype=weight_dtype) + + training_models = [] + params_to_optimize = [] + training_models.append(controlnet) + name_and_params = list(controlnet.named_parameters()) + # single param group for now + params_to_optimize.append({"params": [p for _, p in name_and_params], "lr": args.learning_rate}) + param_names = [[n for n, _ in name_and_params]] + + # calculate number of trainable parameters + n_params = 0 + for group in params_to_optimize: + for p in group["params"]: + n_params += p.numel() + + accelerator.print(f"number of trainable parameters: {n_params}") + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + if args.blockwise_fused_optimizers: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters. + # This balances memory usage and management complexity. + + # split params into groups. currently different learning rates are not supported + grouped_params = [] + param_group = {} + for group in params_to_optimize: + named_parameters = list(controlnet.named_parameters()) + assert len(named_parameters) == len(group["params"]), "number of parameters does not match" + for p, np in zip(group["params"], named_parameters): + # determine target layer and block index for each parameter + block_type = "other" # double, single or other + if np[0].startswith("double_blocks"): + block_index = int(np[0].split(".")[1]) + block_type = "double" + elif np[0].startswith("single_blocks"): + block_index = int(np[0].split(".")[1]) + block_type = "single" + else: + block_index = -1 + + param_group_key = (block_type, block_index) + if param_group_key not in param_group: + param_group[param_group_key] = [] + param_group[param_group_key].append(p) + + block_types_and_indices = [] + for param_group_key, param_group in param_group.items(): + block_types_and_indices.append(param_group_key) + grouped_params.append({"params": param_group, "lr": args.learning_rate}) + + num_params = 0 + for p in param_group: + num_params += p.numel() + accelerator.print(f"block {param_group_key}: {num_params} parameters") + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers") + + if train_util.is_schedulefree_optimizer(optimizers[0], args): + raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers") + optimizer_train_fn = lambda: None # dummy function + optimizer_eval_fn = lambda: None # dummy function + else: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) - ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) - return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) - def get_tokenize_strategy(self, args): - _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) + # lr schedulerを用意する + if args.blockwise_fused_optimizers: + # prepare lr schedulers for each optimizer + lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) - if args.t5xxl_max_token_length is None: - if is_schnell: - t5xxl_max_token_length = 256 - else: - t5xxl_max_token_length = 512 - else: - t5xxl_max_token_length = args.t5xxl_max_token_length + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + flux.to(weight_dtype) + controlnet.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + t5xxl.to(weight_dtype) # TODO check works with fp16 or not + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + flux.to(weight_dtype) + controlnet.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + t5xxl.to(weight_dtype) + + # if we don't cache text encoder outputs, move them to device + if not args.cache_text_encoder_outputs: + clip_l.to(accelerator.device) + t5xxl.to(accelerator.device) + + clean_memory_on_device(accelerator.device) + + if args.deepspeed: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=controlnet) + # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] - logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}") - return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir) + else: + # accelerator does some magic + # if we doesn't swap blocks, we can move the model to device + controlnet = accelerator.prepare(controlnet, device_placement=[not is_swapping_blocks]) + if is_swapping_blocks: + accelerator.unwrap_model(controlnet).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. + # -> But we think it's ok to patch accelerator even if deepspeed is enabled. + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + + for param_group, param_name_group in zip(optimizer.param_groups, param_names): + for parameter, param_name in zip(param_group["params"], param_name_group): + if parameter.requires_grad: + + def create_grad_hook(p_name, p_group): + def grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, p_group) + tensor.grad = None + + return grad_hook + + parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group)) + + elif args.blockwise_fused_optimizers: + # prepare for additional optimizers and lr schedulers + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + # counters are used to determine when to step the optimizer + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def grad_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(grad_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # accelerator.print( + # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + # ) + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) - def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy): - return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl] + if is_swapping_blocks: + accelerator.unwrap_model(controlnet).prepare_block_swap_before_forward() - def get_latents_caching_strategy(self, args): - latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False) - return latents_caching_strategy + # For --sample_at_first + optimizer_eval_fn() + flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) + optimizer_train_fn() + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) - def get_text_encoding_strategy(self, args): - return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask) + loss_recorder = train_util.LossRecorder() + epoch = 0 # avoid error when max_train_steps is 0 + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 - def post_process_network(self, args, accelerator, network, text_encoders, unet): - # check t5xxl is trained or not - self.train_t5xxl = network.train_t5xxl + for m in training_models: + m.train() - if self.train_t5xxl and args.cache_text_encoder_outputs: - raise ValueError( - "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません" - ) + for step, batch in enumerate(train_dataloader): + current_step.value = global_step - def get_models_for_text_encoding(self, args, accelerator, text_encoders): - if args.cache_text_encoder_outputs: - if self.train_clip_l and not self.train_t5xxl: - return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached - else: - return None # no text encoders are needed for encoding because both are cached - else: - return text_encoders # both CLIP-L and T5XXL are needed for encoding + if args.blockwise_fused_optimizers: + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step - def get_text_encoders_train_flags(self, args, text_encoders): - return [self.train_clip_l, self.train_t5xxl] + with accelerator.accumulate(*training_models): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) + else: + with torch.no_grad(): + # encode images to latents. images are [-1, 1] + latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list + else: + # not cached or training, so get from text encoders + tokens_and_masks = batch["input_ids_list"] + with torch.no_grad(): + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + text_encoder_conds = text_encoding_strategy.encode_tokens( + flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask + ) + if args.full_fp16: + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] - def get_text_encoder_outputs_caching_strategy(self, args): - if args.cache_text_encoder_outputs: - # if the text encoders is trained, we need tokenization, so is_partial is True - return strategy_flux.FluxTextEncoderOutputsCachingStrategy( - args.cache_text_encoder_outputs_to_disk, - args.text_encoder_batch_size, - args.skip_cache_check, - is_partial=self.train_clip_l or self.train_t5xxl, - apply_t5_attn_mask=args.apply_t5_attn_mask, - ) - else: - return None + # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps - def cache_text_encoder_outputs_if_needed( - self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype - ): - if args.cache_text_encoder_outputs: - if not args.lowram: - # メモリ消費を減らす - logger.info("move vae and unet to cpu to save memory") - org_vae_device = vae.device - org_unet_device = unet.device - vae.to("cpu") - unet.to("cpu") - clean_memory_on_device(accelerator.device) - - # When TE is not be trained, it will not be prepared so we need to use explicit autocast - logger.info("move text encoders to gpu") - text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 - text_encoders[1].to(accelerator.device) - - if text_encoders[1].dtype == torch.float8_e4m3fn: - # if we load fp8 weights, the model is already fp8, so we use it as is - self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype) - else: - # otherwise, we need to convert it to target dtype - text_encoders[1].to(weight_dtype) - - with accelerator.autocast(): - dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) - - # cache sample prompts - if args.sample_prompts is not None: - logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") - - tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() - text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() - - prompts = train_util.load_prompts(args.sample_prompts) - sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs - with accelerator.autocast(), torch.no_grad(): - for prompt_dict in prompts: - for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: - if p not in sample_prompts_te_outputs: - logger.info(f"cache Text Encoder outputs for prompt: {p}") - tokens_and_masks = tokenize_strategy.tokenize(p) - sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( - tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask - ) - self.sample_prompts_te_outputs = sample_prompts_te_outputs - - accelerator.wait_for_everyone() - - # move back to cpu - if not self.is_train_text_encoder(args): - logger.info("move CLIP-L back to cpu") - text_encoders[0].to("cpu") - logger.info("move t5XXL back to cpu") - text_encoders[1].to("cpu") - clean_memory_on_device(accelerator.device) - - if not args.lowram: - logger.info("move vae and unet back to original device") - vae.to(org_vae_device) - unet.to(org_unet_device) - else: - # Text Encoderから毎回出力を取得するので、GPUに乗せておく - text_encoders[0].to(accelerator.device, dtype=weight_dtype) - text_encoders[1].to(accelerator.device) + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] - # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): - # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype + ) - # # get size embeddings - # orig_size = batch["original_sizes_hw"] - # crop_size = batch["crop_top_lefts"] - # target_size = batch["target_sizes_hw"] - # embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + # pack latents and get img_ids + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) - # # concat embeddings - # encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds - # vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) - # text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + # get guidance: ensure args.guidance_scale is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) - # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) - # return noise_pred + # call model + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None - def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux): - text_encoders = text_encoder # for compatibility - text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) + with accelerator.autocast(): + block_samples, block_single_samples = controlnet( + img=packed_noisy_model_input, + img_ids=img_ids, + controlnet_cond=batch["control_image"].to(accelerator.device), + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = flux( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + block_controlnet_hidden_states=block_samples, + block_controlnet_single_hidden_states=block_single_samples, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) - flux_train_utils.sample_images( - accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs - ) - # return - - """ - class FluxUpperLowerWrapper(torch.nn.Module): - def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device): - super().__init__() - self.flux_upper = flux_upper - self.flux_lower = flux_lower - self.target_device = device - - def prepare_block_swap_before_forward(self): - pass - - def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None): - self.flux_lower.to("cpu") - clean_memory_on_device(self.target_device) - self.flux_upper.to(self.target_device) - img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask) - self.flux_upper.to("cpu") - clean_memory_on_device(self.target_device) - self.flux_lower.to(self.target_device) - return self.flux_lower(img, txt, vec, pe, txt_attention_mask) - - wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device) - clean_memory_on_device(accelerator.device) - flux_train_utils.sample_images( - accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs - ) - clean_memory_on_device(accelerator.device) - """ - - def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: - noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) - self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) - return noise_scheduler - - def encode_images_to_latents(self, args, accelerator, vae, images): - return vae.encode(images) - - def shift_scale_latents(self, args, latents): - return latents - - def get_noise_pred_and_target( - self, - args, - accelerator, - noise_scheduler, - latents, - batch, - text_encoder_conds, - unet: flux_models.Flux, - network, - weight_dtype, - train_unet, - ): - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) - bsz = latents.shape[0] - - # get noisy model input and timesteps - noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents, noise, accelerator.device, weight_dtype - ) + # unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) - # pack latents and get img_ids - packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 - packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 - img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) - - # get guidance - # ensure guidance_scale in args is float - guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) - - # ensure the hidden state will require grad - if args.gradient_checkpointing: - noisy_model_input.requires_grad_(True) - for t in text_encoder_conds: - if t is not None and t.dtype.is_floating_point: - t.requires_grad_(True) - img_ids.requires_grad_(True) - guidance_vec.requires_grad_(True) - - # Predict the noise residual - l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds - if not args.apply_t5_attn_mask: - t5_attn_mask = None - - def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): - # if not args.split_mode: - # normal forward - with accelerator.autocast(): - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) - model_pred = unet( - img=img, - img_ids=img_ids, - txt=t5_out, - txt_ids=txt_ids, - y=l_pooled, - timesteps=timesteps / 1000, - guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, + # flow matching loss: this is different from SD3 + target = noise - latents + + # calculate loss + loss = train_util.conditional_loss( + model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None + ) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + loss = loss.mean() + + # backward + accelerator.backward(loss) + + if not (args.fused_backward_pass or args.blockwise_fused_optimizers): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.blockwise_fused_optimizers: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + optimizer_eval_fn() + flux_train_utils.sample_images( + accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs ) - """ - else: - # split forward to reduce memory usage - assert network.train_blocks == "single", "train_blocks must be single for split mode" - with accelerator.autocast(): - # move flux lower to cpu, and then move flux upper to gpu - unet.to("cpu") - clean_memory_on_device(accelerator.device) - self.flux_upper.to(accelerator.device) - # upper model does not require grad - with torch.no_grad(): - intermediate_img, intermediate_txt, vec, pe = self.flux_upper( - img=packed_noisy_model_input, - img_ids=img_ids, - txt=t5_out, - txt_ids=txt_ids, - y=l_pooled, - timesteps=timesteps / 1000, - guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(flux), ) + optimizer_train_fn() - # move flux upper back to cpu, and then move flux lower to gpu - self.flux_upper.to("cpu") - clean_memory_on_device(accelerator.device) - unet.to(accelerator.device) - - # lower model requires grad - intermediate_img.requires_grad_(True) - intermediate_txt.requires_grad_(True) - vec.requires_grad_(True) - pe.requires_grad_(True) - model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) - """ - - return model_pred - - model_pred = call_dit( - img=packed_noisy_model_input, - img_ids=img_ids, - t5_out=t5_out, - txt_ids=txt_ids, - l_pooled=l_pooled, - timesteps=timesteps, - guidance_vec=guidance_vec, - t5_attn_mask=t5_attn_mask, - ) + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if len(accelerator.trackers) > 0: + logs = {"loss": current_loss} + train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) - # unpack latents - model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) - - # apply model prediction type - model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) - - # flow matching loss: this is different from SD3 - target = noise - latents - - # differential output preservation - if "custom_attributes" in batch: - diff_output_pr_indices = [] - for i, custom_attributes in enumerate(batch["custom_attributes"]): - if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: - diff_output_pr_indices.append(i) - - if len(diff_output_pr_indices) > 0: - network.set_multiplier(0.0) - with torch.no_grad(): - model_pred_prior = call_dit( - img=packed_noisy_model_input[diff_output_pr_indices], - img_ids=img_ids[diff_output_pr_indices], - t5_out=t5_out[diff_output_pr_indices], - txt_ids=txt_ids[diff_output_pr_indices], - l_pooled=l_pooled[diff_output_pr_indices], - timesteps=timesteps[diff_output_pr_indices], - guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None, - t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None, - ) - network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step + accelerator.log(logs, step=global_step) - model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width) - model_pred_prior, _ = flux_train_utils.apply_model_prediction_type( + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if len(accelerator.trackers) > 0: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + optimizer_eval_fn() + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( args, - model_pred_prior, - noisy_model_input[diff_output_pr_indices], - sigmas[diff_output_pr_indices] if sigmas is not None else None, + True, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(flux), ) - target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) - - return model_pred, target, timesteps, None, weighting - - def post_process_loss(self, loss, args, timesteps, noise_scheduler): - return loss - - def get_sai_model_spec(self, args): - return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev") - - def update_metadata(self, metadata, args): - metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask - metadata["ss_weighting_scheme"] = args.weighting_scheme - metadata["ss_logit_mean"] = args.logit_mean - metadata["ss_logit_std"] = args.logit_std - metadata["ss_mode_scale"] = args.mode_scale - metadata["ss_guidance_scale"] = args.guidance_scale - metadata["ss_timestep_sampling"] = args.timestep_sampling - metadata["ss_sigmoid_scale"] = args.sigmoid_scale - metadata["ss_model_prediction_type"] = args.model_prediction_type - metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift - - def is_text_encoder_not_needed_for_training(self, args): - return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) - - def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): - if index == 0: # CLIP-L - return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder) - else: # T5XXL - text_encoder.encoder.embed_tokens.requires_grad_(True) - - def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): - if index == 0: # CLIP-L - logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}") - text_encoder.to(te_weight_dtype) # fp8 - text_encoder.text_model.embeddings.to(dtype=weight_dtype) - else: # T5XXL - - def prepare_fp8(text_encoder, target_dtype): - def forward_hook(module): - def forward(hidden_states): - hidden_gelu = module.act(module.wi_0(hidden_states)) - hidden_linear = module.wi_1(hidden_states) - hidden_states = hidden_gelu * hidden_linear - hidden_states = module.dropout(hidden_states) - - hidden_states = module.wo(hidden_states) - return hidden_states - - return forward - - for module in text_encoder.modules(): - if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: - # print("set", module.__class__.__name__, "to", target_dtype) - module.to(target_dtype) - if module.__class__.__name__ in ["T5DenseGatedActDense"]: - # print("set", module.__class__.__name__, "hooks") - module.forward = forward_hook(module) - - if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype: - logger.info(f"T5XXL already prepared for fp8") - else: - logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks") - text_encoder.to(te_weight_dtype) # fp8 - prepare_fp8(text_encoder, weight_dtype) - def prepare_unet_with_accelerator( - self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module - ) -> torch.nn.Module: - if not self.is_swapping_blocks: - return super().prepare_unet_with_accelerator(args, accelerator, unet) + flux_train_utils.sample_images( + accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs + ) + optimizer_train_fn() + + is_main_process = accelerator.is_main_process + # if is_main_process: + controlnet = accelerator.unwrap_model(controlnet) - # if we doesn't swap blocks, we can move the model to device - flux: flux_models.Flux = unet - flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks]) - accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage - accelerator.unwrap_model(flux).prepare_block_swap_before_forward() + accelerator.end_training() + optimizer_eval_fn() + + if args.save_state or args.save_state_on_train_end: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す - return flux + if is_main_process: + flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux) + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: - parser = train_network.setup_parser() + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) # TODO split this + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + add_custom_train_arguments(parser) # TODO remove this from here train_util.add_dit_training_arguments(parser) flux_train_utils.add_flux_train_arguments(parser) parser.add_argument( - "--split_mode", + "--mem_eff_save", + action="store_true", + help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う", + ) + + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます", + ) + parser.add_argument( + "--blockwise_fused_optimizers", + action="store_true", + help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする", + ) + parser.add_argument( + "--skip_latents_validity_check", + action="store_true", + help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", + ) + parser.add_argument( + "--double_blocks_to_swap", + type=int, + default=None, + help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください", + ) + parser.add_argument( + "--single_blocks_to_swap", + type=int, + default=None, + help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください", + ) + parser.add_argument( + "--cpu_offload_checkpointing", action="store_true", - # help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required" - # + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要", - help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead." - " / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。", + help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", ) return parser @@ -569,5 +866,4 @@ def setup_parser() -> argparse.ArgumentParser: train_util.verify_command_line_training_args(args) args = train_util.read_config_from_file(args, parser) - trainer = FluxNetworkTrainer() - trainer.train(args) + train(args) diff --git a/flux_train_network.py b/flux_train_network.py index 0feb9b011..6668012e4 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -125,9 +125,6 @@ def load_target_model(self, args, weight_dtype, accelerator): ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) - controlnet = flux_utils.load_controlnet() - controlnet.train() - return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model, controlnet def get_tokenize_strategy(self, args): diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index d90644a25..cc3bcb0ec 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -40,6 +40,7 @@ def sample_images( text_encoders, sample_prompts_te_outputs, prompt_replacement=None, + controlnet=None ): if steps == 0: if not args.sample_at_first: @@ -67,6 +68,8 @@ def sample_images( flux = accelerator.unwrap_model(flux) if text_encoders is not None: text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] + if controlnet is not None: + controlnet = accelerator.unwrap_model(controlnet) # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) prompts = train_util.load_prompts(args.sample_prompts) @@ -98,6 +101,7 @@ def sample_images( steps, sample_prompts_te_outputs, prompt_replacement, + controlnet ) else: # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) @@ -121,6 +125,7 @@ def sample_images( steps, sample_prompts_te_outputs, prompt_replacement, + controlnet ) torch.set_rng_state(rng_state) @@ -142,6 +147,7 @@ def sample_image_inference( steps, sample_prompts_te_outputs, prompt_replacement, + controlnet ): assert isinstance(prompt_dict, dict) # negative_prompt = prompt_dict.get("negative_prompt") @@ -150,7 +156,7 @@ def sample_image_inference( height = prompt_dict.get("height", 512) scale = prompt_dict.get("scale", 3.5) seed = prompt_dict.get("seed") - # controlnet_image = prompt_dict.get("controlnet_image") + controlnet_image = prompt_dict.get("controlnet_image") prompt: str = prompt_dict.get("prompt", "") # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) @@ -169,6 +175,9 @@ def sample_image_inference( # if negative_prompt is None: # negative_prompt = "" + if controlnet_image is not None: + controlnet_image = Image.open(controlnet_image).convert("RGB") + controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) height = max(64, height - height % 16) # round to divisible by 16 width = max(64, width - width % 16) # round to divisible by 16 @@ -224,7 +233,7 @@ def sample_image_inference( t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None with accelerator.autocast(), torch.no_grad(): - x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask) + x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image) x = x.float() x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) @@ -301,18 +310,37 @@ def denoise( timesteps: list[float], guidance: float = 4.0, t5_attn_mask: Optional[torch.Tensor] = None, + controlnet: Optional[flux_models.ControlNetFlux] = None, + controlnet_img: Optional[torch.Tensor] = None, ): # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) model.prepare_block_swap_before_forward() + if controlnet is not None: + block_samples, block_single_samples = controlnet( + img=img, + img_ids=img_ids, + controlnet_cond=controlnet_img, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + else: + block_samples = None + block_single_samples = None pred = model( img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, + block_controlnet_hidden_states=block_samples, + block_controlnet_single_hidden_states=block_single_samples, timesteps=t_vec, guidance=guidance_vec, txt_attention_mask=t5_attn_mask, diff --git a/library/flux_utils.py b/library/flux_utils.py index 678efbc8a..7b538d133 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -153,11 +153,14 @@ def load_ae( return ae -def load_controlnet(name, device, transformer=None): - with torch.device(device): +def load_controlnet(): + # TODO + is_schnell = False + name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL + with torch.device("meta"): controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params) - if transformer is not None: - controlnet.load_state_dict(transformer.state_dict(), strict=False) + # if transformer is not None: + # controlnet.load_state_dict(transformer.state_dict(), strict=False) return controlnet From e358b118afbc93f63dbb5ab6d2412ec553ea9cd7 Mon Sep 17 00:00:00 2001 From: minux302 Date: Sat, 16 Nov 2024 14:49:29 +0900 Subject: [PATCH 239/348] fix dataloader --- flux_train_control_net.py | 84 ++++++++++++++++++++------------------- library/flux_models.py | 17 ++++---- 2 files changed, 52 insertions(+), 49 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 8a7be75f2..ee4d0ebf3 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -11,31 +11,36 @@ # - Per-block fused optimizer instances import argparse -from concurrent.futures import ThreadPoolExecutor import copy import math import os -from multiprocessing import Value import time +from concurrent.futures import ThreadPoolExecutor +from multiprocessing import Value from typing import List, Optional, Tuple, Union -import toml - -from tqdm import tqdm +import toml import torch import torch.nn as nn +from tqdm import tqdm + from library import utils -from library.device_utils import init_ipex, clean_memory_on_device +from library.device_utils import clean_memory_on_device, init_ipex init_ipex() from accelerate.utils import set_seed -from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux -from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler import library.train_util as train_util - -from library.utils import setup_logging, add_logging_arguments +from library import ( + deepspeed_utils, + flux_train_utils, + flux_utils, + strategy_base, + strategy_flux, +) +from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler +from library.utils import add_logging_arguments, setup_logging setup_logging() import logging @@ -46,10 +51,10 @@ # import library.sdxl_train_util as sdxl_train_util from library.config_util import ( - ConfigSanitizer, BlueprintGenerator, + ConfigSanitizer, ) -from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments +from library.custom_train_functions import add_custom_train_arguments, apply_masked_loss def train(args): @@ -85,7 +90,6 @@ def train(args): ) cache_latents = args.cache_latents - use_dreambooth_method = args.in_json is None if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する @@ -103,7 +107,7 @@ def train(args): if args.dataset_config is not None: logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "in_json"] + ignored = ["train_data_dir", "conditioing_data_dir"] if any(getattr(args, attr) is not None for attr in ignored): logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( @@ -111,31 +115,17 @@ def train(args): ) ) else: - if use_dreambooth_method: - logger.info("Using DreamBooth method.") - user_config = { - "datasets": [ - { - "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( - args.train_data_dir, args.reg_data_dir - ) - } - ] - } - else: - logger.info("Training with captions.") - user_config = { - "datasets": [ - { - "subsets": [ - { - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - } - ] - } - ] - } + user_config = { + "datasets": [ + { + "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( + args.train_data_dir, + args.conditioning_data_dir, + args.caption_extension + ) + } + ] + } blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) @@ -648,12 +638,12 @@ def grad_hook(parameter: torch.Tensor): l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds if not args.apply_t5_attn_mask: t5_attn_mask = None - + with accelerator.autocast(): block_samples, block_single_samples = controlnet( img=packed_noisy_model_input, img_ids=img_ids, - controlnet_cond=batch["control_image"].to(accelerator.device), + controlnet_img=batch["conditioing_image"].to(accelerator.device), txt=t5_out, txt_ids=txt_ids, y=l_pooled, @@ -856,6 +846,18 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", ) + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="controlnet model name or path / controlnetのモデル名またはパス", + ) + parser.add_argument( + "--conditioning_data_dir", + type=str, + default=None, + help="conditioning data directory / 条件付けデータのディレクトリ", + ) return parser diff --git a/library/flux_models.py b/library/flux_models.py index a3bd19743..b52ea6f0b 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -2,15 +2,15 @@ # license: Apache-2.0 License -from concurrent.futures import Future, ThreadPoolExecutor -from dataclasses import dataclass import math import os import time +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass from typing import Dict, List, Optional, Union from library import utils -from library.device_utils import init_ipex, clean_memory_on_device +from library.device_utils import clean_memory_on_device, init_ipex init_ipex() @@ -18,6 +18,7 @@ from einops import rearrange from torch import Tensor, nn from torch.utils.checkpoint import checkpoint + from library import custom_offloading_utils # USE_REENTRANT = True @@ -1251,7 +1252,7 @@ def forward( self, img: Tensor, img_ids: Tensor, - controlnet_cond: Tensor, + controlnet_img: Tensor, txt: Tensor, txt_ids: Tensor, timesteps: Tensor, @@ -1264,10 +1265,10 @@ def forward( # running on sequences img img = self.img_in(img) - controlnet_cond = self.input_hint_block(controlnet_cond) - controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) - controlnet_cond = self.pos_embed_input(controlnet_cond) - img = img + controlnet_cond + controlnet_img = self.input_hint_block(controlnet_img) + controlnet_img = rearrange(controlnet_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + controlnet_img = self.pos_embed_input(controlnet_img) + img = img + controlnet_img vec = self.time_in(timestep_embedding(timesteps, 256)) if self.params.guidance_embed: if guidance is None: From 2a188f07e682ed5dd958821a223d48c17a9aeb83 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 17 Nov 2024 16:12:10 +0900 Subject: [PATCH 240/348] Fix to work DOP with bock swap --- flux_train_network.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flux_train_network.py b/flux_train_network.py index 704c4d32e..679db62b6 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -445,6 +445,7 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t if len(diff_output_pr_indices) > 0: network.set_multiplier(0.0) + unet.prepare_block_swap_before_forward() with torch.no_grad(): model_pred_prior = call_dit( img=packed_noisy_model_input[diff_output_pr_indices], From b2660bbe7410d7ffa40906a7a09f84a17139cb46 Mon Sep 17 00:00:00 2001 From: minux302 Date: Sun, 17 Nov 2024 10:24:57 +0000 Subject: [PATCH 241/348] train run --- flux_train_control_net.py | 39 ++++++++++++++++++++++--------------- library/flux_models.py | 30 ++++++++++++++-------------- library/flux_train_utils.py | 2 +- library/flux_utils.py | 2 +- 4 files changed, 40 insertions(+), 33 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index ee4d0ebf3..205ff6b6a 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -103,11 +103,11 @@ def train(args): # データセットを準備する if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) if args.dataset_config is not None: logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "conditioing_data_dir"] + ignored = ["train_data_dir", "conditioning_data_dir"] if any(getattr(args, attr) is not None for attr in ignored): logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( @@ -263,10 +263,11 @@ def train(args): args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors ) flux.requires_grad_(False) + flux.to(accelerator.device) # load controlnet controlnet = flux_utils.load_controlnet() - controlnet.requires_grad_(True) + controlnet.train() if args.gradient_checkpointing: controlnet.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing) @@ -443,7 +444,8 @@ def train(args): clean_memory_on_device(accelerator.device) - if args.deepspeed: + # if args.deepspeed: + if True: ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=controlnet) # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( @@ -612,8 +614,10 @@ def grad_hook(parameter: torch.Tensor): text_encoder_conds = text_encoding_strategy.encode_tokens( flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask ) - if args.full_fp16: - text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] + # if args.full_fp16: + # text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] + # TODO: check + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps @@ -629,10 +633,10 @@ def grad_hook(parameter: torch.Tensor): # pack latents and get img_ids packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 - img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device).to(weight_dtype) # get guidance: ensure args.guidance_scale is float - guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device, dtype=weight_dtype) # call model l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds @@ -640,10 +644,11 @@ def grad_hook(parameter: torch.Tensor): t5_attn_mask = None with accelerator.autocast(): + print("control start") block_samples, block_single_samples = controlnet( img=packed_noisy_model_input, img_ids=img_ids, - controlnet_img=batch["conditioing_image"].to(accelerator.device), + controlnet_cond=batch["conditioning_images"].to(accelerator.device).to(weight_dtype), txt=t5_out, txt_ids=txt_ids, y=l_pooled, @@ -651,6 +656,8 @@ def grad_hook(parameter: torch.Tensor): guidance=guidance_vec, txt_attention_mask=t5_attn_mask, ) + print("control end") + print("dit start") # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = flux( img=packed_noisy_model_input, @@ -796,7 +803,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) # TODO split this - train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) train_util.add_masked_loss_arguments(parser) deepspeed_utils.add_deepspeed_arguments(parser) @@ -852,12 +859,12 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="controlnet model name or path / controlnetのモデル名またはパス", ) - parser.add_argument( - "--conditioning_data_dir", - type=str, - default=None, - help="conditioning data directory / 条件付けデータのディレクトリ", - ) + # parser.add_argument( + # "--conditioning_data_dir", + # type=str, + # default=None, + # help="conditioning data directory / 条件付けデータのディレクトリ", + # ) return parser diff --git a/library/flux_models.py b/library/flux_models.py index b52ea6f0b..2fc21db9d 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1042,20 +1042,20 @@ def forward( if not self.blocks_to_swap: for block_idx, block in enumerate(self.double_blocks): img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if block_controlnet_hidden_states is not None: + if block_controlnet_hidden_states is not None and controlnet_depth > 0: img = img + block_controlnet_hidden_states[block_idx % controlnet_depth] img = torch.cat((txt, img), 1) - for block in self.single_blocks: + for block_idx, block in enumerate(self.single_blocks): img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if block_controlnet_single_hidden_states is not None: + if block_controlnet_single_hidden_states is not None and controlnet_single_depth > 0: img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth] else: for block_idx, block in enumerate(self.double_blocks): self.offloader_double.wait_for_block(block_idx) img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if block_controlnet_hidden_states is not None: + if block_controlnet_hidden_states is not None and controlnet_depth > 0: img = img + block_controlnet_hidden_states[block_idx % controlnet_depth] self.offloader_double.submit_move_blocks(self.double_blocks, block_idx) @@ -1066,7 +1066,7 @@ def forward( self.offloader_single.wait_for_block(block_idx) img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if block_controlnet_single_hidden_states is not None: + if block_controlnet_single_hidden_states is not None and controlnet_single_depth > 0: img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth] self.offloader_single.submit_move_blocks(self.single_blocks, block_idx) @@ -1121,14 +1121,14 @@ def __init__(self, params: FluxParams, controlnet_depth=2): mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias, ) - for _ in range(params.depth) + for _ in range(controlnet_depth) ] ) self.single_blocks = nn.ModuleList( [ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) - for _ in range(0) # TMP + for _ in range(0) # TODO ] ) @@ -1148,7 +1148,7 @@ def __init__(self, params: FluxParams, controlnet_depth=2): controlnet_block = zero_module(controlnet_block) self.controlnet_blocks_for_double.append(controlnet_block) self.controlnet_blocks_for_single = nn.ModuleList([]) - for _ in range(controlnet_depth): + for _ in range(0): # TODO controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) controlnet_block = zero_module(controlnet_block) self.controlnet_blocks_for_single.append(controlnet_block) @@ -1252,7 +1252,7 @@ def forward( self, img: Tensor, img_ids: Tensor, - controlnet_img: Tensor, + controlnet_cond: Tensor, txt: Tensor, txt_ids: Tensor, timesteps: Tensor, @@ -1265,10 +1265,10 @@ def forward( # running on sequences img img = self.img_in(img) - controlnet_img = self.input_hint_block(controlnet_img) - controlnet_img = rearrange(controlnet_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) - controlnet_img = self.pos_embed_input(controlnet_img) - img = img + controlnet_img + controlnet_cond = self.input_hint_block(controlnet_cond) + controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + controlnet_cond = self.pos_embed_input(controlnet_cond) + img = img + controlnet_cond vec = self.time_in(timestep_embedding(timesteps, 256)) if self.params.guidance_embed: if guidance is None: @@ -1283,7 +1283,7 @@ def forward( block_samples = () block_single_samples = () if not self.blocks_to_swap: - for block_idx, block in enumerate(self.double_blocks): + for block in self.double_blocks: img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) block_samples = block_samples + (img,) @@ -1315,7 +1315,7 @@ def forward( for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_double): block_sample = controlnet_block(block_sample) controlnet_block_samples = controlnet_block_samples + (block_sample,) - for block_sample, controlnet_block in zip(block_samples, self.controlnet_single_blocks_for_single): + for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_single): block_sample = controlnet_block(block_sample) controlnet_single_block_samples = controlnet_single_block_samples + (block_sample,) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index cc3bcb0ec..d82bde91c 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -460,7 +460,7 @@ def get_noisy_model_input_and_timesteps( sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents - return noisy_model_input, timesteps, sigmas + return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas): diff --git a/library/flux_utils.py b/library/flux_utils.py index 7b538d133..4a3817fdb 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -157,7 +157,7 @@ def load_controlnet(): # TODO is_schnell = False name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL - with torch.device("meta"): + with torch.device("cuda:0"): controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params) # if transformer is not None: # controlnet.load_state_dict(transformer.state_dict(), strict=False) From 35778f021897796410372aed8540547ba317c2a3 Mon Sep 17 00:00:00 2001 From: minux302 Date: Sun, 17 Nov 2024 11:09:05 +0000 Subject: [PATCH 242/348] fix sample_images type --- flux_train_control_net.py | 31 ++++++++++++++----------------- library/flux_train_utils.py | 2 +- 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 205ff6b6a..791900d17 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -444,8 +444,7 @@ def train(args): clean_memory_on_device(accelerator.device) - # if args.deepspeed: - if True: + if args.deepspeed: ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=controlnet) # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( @@ -644,7 +643,6 @@ def grad_hook(parameter: torch.Tensor): t5_attn_mask = None with accelerator.autocast(): - print("control start") block_samples, block_single_samples = controlnet( img=packed_noisy_model_input, img_ids=img_ids, @@ -656,8 +654,6 @@ def grad_hook(parameter: torch.Tensor): guidance=guidance_vec, txt_attention_mask=t5_attn_mask, ) - print("control end") - print("dit start") # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = flux( img=packed_noisy_model_input, @@ -763,18 +759,19 @@ def grad_hook(parameter: torch.Tensor): accelerator.wait_for_everyone() optimizer_eval_fn() - if args.save_every_n_epochs is not None: - if accelerator.is_main_process: - flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( - args, - True, - accelerator, - save_dtype, - epoch, - num_train_epochs, - global_step, - accelerator.unwrap_model(flux), - ) + # TODO: save cn models + # if args.save_every_n_epochs is not None: + # if accelerator.is_main_process: + # flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + # args, + # True, + # accelerator, + # save_dtype, + # epoch, + # num_train_epochs, + # global_step, + # accelerator.unwrap_model(flux), + # ) flux_train_utils.sample_images( accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index d82bde91c..de2ee030a 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -235,7 +235,7 @@ def sample_image_inference( with accelerator.autocast(), torch.no_grad(): x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image) - x = x.float() + # x = x.float() # TODO: check x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) # latent to image From 4dd4cd6ec8c55fa94b53217181ed9c95e59eed56 Mon Sep 17 00:00:00 2001 From: minux302 Date: Mon, 18 Nov 2024 12:47:01 +0000 Subject: [PATCH 243/348] work cn load and validation --- flux_train_control_net.py | 20 ++++---------------- library/flux_models.py | 6 +++--- library/flux_train_utils.py | 18 ++++++++++++++---- library/flux_utils.py | 25 ++++++++++++++++--------- 4 files changed, 37 insertions(+), 32 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 791900d17..cbfac418f 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -266,7 +266,7 @@ def train(args): flux.to(accelerator.device) # load controlnet - controlnet = flux_utils.load_controlnet() + controlnet = flux_utils.load_controlnet(args.controlnet, weight_dtype, "cpu", args.disable_mmap_load_safetensors) controlnet.train() if args.gradient_checkpointing: @@ -568,7 +568,7 @@ def grad_hook(parameter: torch.Tensor): # For --sample_at_first optimizer_eval_fn() - flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) + flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet) optimizer_train_fn() if len(accelerator.trackers) > 0: # log empty object to commit the sample images to wandb @@ -718,7 +718,7 @@ def grad_hook(parameter: torch.Tensor): optimizer_eval_fn() flux_train_utils.sample_images( - accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs + accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet ) # 指定ステップごとにモデルを保存 @@ -774,7 +774,7 @@ def grad_hook(parameter: torch.Tensor): # ) flux_train_utils.sample_images( - accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs + accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet ) optimizer_train_fn() @@ -850,18 +850,6 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", ) - parser.add_argument( - "--controlnet_model_name_or_path", - type=str, - default=None, - help="controlnet model name or path / controlnetのモデル名またはパス", - ) - # parser.add_argument( - # "--conditioning_data_dir", - # type=str, - # default=None, - # help="conditioning data directory / 条件付けデータのディレクトリ", - # ) return parser diff --git a/library/flux_models.py b/library/flux_models.py index 2fc21db9d..4123b40e5 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1142,11 +1142,11 @@ def __init__(self, params: FluxParams, controlnet_depth=2): self.num_single_blocks = len(self.single_blocks) # add ControlNet blocks - self.controlnet_blocks_for_double = nn.ModuleList([]) + self.controlnet_blocks = nn.ModuleList([]) for _ in range(controlnet_depth): controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) controlnet_block = zero_module(controlnet_block) - self.controlnet_blocks_for_double.append(controlnet_block) + self.controlnet_blocks.append(controlnet_block) self.controlnet_blocks_for_single = nn.ModuleList([]) for _ in range(0): # TODO controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) @@ -1312,7 +1312,7 @@ def forward( controlnet_block_samples = () controlnet_single_block_samples = () - for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_double): + for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks): block_sample = controlnet_block(block_sample) controlnet_block_samples = controlnet_block_samples + (block_sample,) for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_single): diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index de2ee030a..dbbaba734 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -175,10 +175,6 @@ def sample_image_inference( # if negative_prompt is None: # negative_prompt = "" - if controlnet_image is not None: - controlnet_image = Image.open(controlnet_image).convert("RGB") - controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) - height = max(64, height - height % 16) # round to divisible by 16 width = max(64, width - width % 16) # round to divisible by 16 logger.info(f"prompt: {prompt}") @@ -232,6 +228,12 @@ def sample_image_inference( img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype) t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None + if controlnet_image is not None: + controlnet_image = Image.open(controlnet_image).convert("RGB") + controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) + controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1) + controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device) + with accelerator.autocast(), torch.no_grad(): x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image) @@ -315,6 +317,8 @@ def denoise( ): # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + + for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) model.prepare_block_swap_before_forward() @@ -560,6 +564,12 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)、float16が前提", ) parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)") + parser.add_argument( + "--controlnet", + type=str, + default=None, + help="path to controlnet (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)" + ) parser.add_argument( "--t5xxl_max_token_length", type=int, diff --git a/library/flux_utils.py b/library/flux_utils.py index 4a3817fdb..fb7a30749 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -153,15 +153,22 @@ def load_ae( return ae -def load_controlnet(): - # TODO - is_schnell = False - name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL - with torch.device("cuda:0"): - controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params) - # if transformer is not None: - # controlnet.load_state_dict(transformer.state_dict(), strict=False) - return controlnet +def load_controlnet( + ckpt_path: Optional[str], dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False +): + logger.info("Building ControlNet") + # is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path) + is_schnell = False + name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL + with torch.device("meta"): + controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params).to(dtype) + + if ckpt_path is not None: + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + info = controlnet.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded ControlNet: {info}") + return controlnet def load_clip_l( From 31ca899b6b5425466c814d0d9e2e4e8bfbf93001 Mon Sep 17 00:00:00 2001 From: minux302 Date: Mon, 18 Nov 2024 13:03:28 +0000 Subject: [PATCH 244/348] fix depth value --- library/flux_models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/library/flux_models.py b/library/flux_models.py index 4123b40e5..328ad481d 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1093,7 +1093,7 @@ class ControlNetFlux(nn.Module): Transformer model for flow matching on sequences. """ - def __init__(self, params: FluxParams, controlnet_depth=2): + def __init__(self, params: FluxParams, controlnet_depth=2, controlnet_single_depth=0): super().__init__() self.params = params @@ -1128,7 +1128,7 @@ def __init__(self, params: FluxParams, controlnet_depth=2): self.single_blocks = nn.ModuleList( [ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) - for _ in range(0) # TODO + for _ in range(controlnet_single_depth) ] ) @@ -1148,7 +1148,7 @@ def __init__(self, params: FluxParams, controlnet_depth=2): controlnet_block = zero_module(controlnet_block) self.controlnet_blocks.append(controlnet_block) self.controlnet_blocks_for_single = nn.ModuleList([]) - for _ in range(0): # TODO + for _ in range(controlnet_single_depth): controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) controlnet_block = zero_module(controlnet_block) self.controlnet_blocks_for_single.append(controlnet_block) From 2a61fc07846dc919ea64b568f7e18c010e5c8e06 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Wed, 20 Nov 2024 21:20:35 +0900 Subject: [PATCH 245/348] docs: fix typo from block_to_swap to blocks_to_swap in README --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 81a3199bc..f9c85e3ac 100644 --- a/README.md +++ b/README.md @@ -68,11 +68,11 @@ When training LoRA for Text Encoder (without `--network_train_unet_only`), more __Options for GPUs with less VRAM:__ -By specifying `--block_to_swap`, you can save VRAM by swapping some blocks between CPU and GPU. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. +By specifying `--blocks_to_swap`, you can save VRAM by swapping some blocks between CPU and GPU. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. -Specify a number like `--block_to_swap 10`. A larger number will swap more blocks, saving more VRAM, but training will be slower. In FLUX.1, you can swap up to 35 blocks. +Specify a number like `--blocks_to_swap 10`. A larger number will swap more blocks, saving more VRAM, but training will be slower. In FLUX.1, you can swap up to 35 blocks. -`--cpu_offload_checkpointing` offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--block_to_swap`. +`--cpu_offload_checkpointing` offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--blocks_to_swap`. Adafactor optimizer may reduce the VRAM usage than 8bit AdamW. Please use settings like below: @@ -82,7 +82,7 @@ Adafactor optimizer may reduce the VRAM usage than 8bit AdamW. Please use settin The training can be done with 16GB VRAM GPUs with the batch size of 1. Please change your dataset configuration. -The training can be done with 12GB VRAM GPUs with `--block_to_swap 16` with 8bit AdamW. Please use settings like below: +The training can be done with 12GB VRAM GPUs with `--blocks_to_swap 16` with 8bit AdamW. Please use settings like below: ``` --blocks_to_swap 16 From 0b5229a9550cb921b83d22472c4785a15c42ba90 Mon Sep 17 00:00:00 2001 From: minux302 Date: Thu, 21 Nov 2024 15:55:27 +0000 Subject: [PATCH 246/348] save cn --- flux_train_control_net.py | 34 +++++++++++++++------------------- library/flux_train_utils.py | 1 - 2 files changed, 15 insertions(+), 20 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index cbfac418f..0f38b7094 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -266,7 +266,7 @@ def train(args): flux.to(accelerator.device) # load controlnet - controlnet = flux_utils.load_controlnet(args.controlnet, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + controlnet = flux_utils.load_controlnet(args.controlnet, torch.float32, "cpu", args.disable_mmap_load_safetensors) controlnet.train() if args.gradient_checkpointing: @@ -613,9 +613,6 @@ def grad_hook(parameter: torch.Tensor): text_encoder_conds = text_encoding_strategy.encode_tokens( flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask ) - # if args.full_fp16: - # text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] - # TODO: check text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps @@ -733,7 +730,7 @@ def grad_hook(parameter: torch.Tensor): epoch, num_train_epochs, global_step, - accelerator.unwrap_model(flux), + accelerator.unwrap_model(controlnet), ) optimizer_train_fn() @@ -759,19 +756,18 @@ def grad_hook(parameter: torch.Tensor): accelerator.wait_for_everyone() optimizer_eval_fn() - # TODO: save cn models - # if args.save_every_n_epochs is not None: - # if accelerator.is_main_process: - # flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( - # args, - # True, - # accelerator, - # save_dtype, - # epoch, - # num_train_epochs, - # global_step, - # accelerator.unwrap_model(flux), - # ) + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(controlnet), + ) flux_train_utils.sample_images( accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet @@ -791,7 +787,7 @@ def grad_hook(parameter: torch.Tensor): del accelerator # この後メモリを使うのでこれは消す if is_main_process: - flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux) + flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, controlnet) logger.info("model saved.") diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index dbbaba734..5e25c7feb 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -237,7 +237,6 @@ def sample_image_inference( with accelerator.autocast(), torch.no_grad(): x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image) - # x = x.float() # TODO: check x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) # latent to image From 420a180d938c7b5a6e3006b1719dbfeaae72a2cc Mon Sep 17 00:00:00 2001 From: recris Date: Wed, 27 Nov 2024 18:11:51 +0000 Subject: [PATCH 247/348] Implement pseudo Huber loss for Flux and SD3 --- fine_tune.py | 6 +-- flux_train.py | 2 +- flux_train_network.py | 2 +- library/train_util.py | 74 ++++++++++++++++------------ sd3_train.py | 2 +- sd3_train_network.py | 2 +- sdxl_train.py | 6 +-- sdxl_train_control_net.py | 4 +- sdxl_train_control_net_lllite.py | 4 +- sdxl_train_control_net_lllite_old.py | 6 ++- train_controlnet.py | 6 +-- train_db.py | 4 +- train_network.py | 9 ++-- train_textual_inversion.py | 4 +- train_textual_inversion_XTI.py | 6 ++- 15 files changed, 76 insertions(+), 61 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 0090bd190..70959a751 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -380,7 +380,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( args, noise_scheduler, latents ) @@ -397,7 +397,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss: # do not mean over batch dimension for snr weight or scale v-pred loss loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) loss = loss.mean([1, 2, 3]) @@ -411,7 +411,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): loss = loss.mean() # mean over batch dimension else: loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) accelerator.backward(loss) diff --git a/flux_train.py b/flux_train.py index a89e2f139..f6e43b27a 100644 --- a/flux_train.py +++ b/flux_train.py @@ -667,7 +667,7 @@ def grad_hook(parameter: torch.Tensor): # calculate loss loss = train_util.conditional_loss( - model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) if weighting is not None: loss = loss * weighting diff --git a/flux_train_network.py b/flux_train_network.py index 679db62b6..04287f399 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -468,7 +468,7 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t ) target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) - return model_pred, target, timesteps, None, weighting + return model_pred, target, timesteps, weighting def post_process_loss(self, loss, args, timesteps, noise_scheduler): return loss diff --git a/library/train_util.py b/library/train_util.py index 25cf7640d..c204ebd38 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3905,7 +3905,14 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: "--huber_c", type=float, default=0.1, - help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", + help="The Huber loss decay parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", + ) + + parser.add_argument( + "--huber_scale", + type=float, + default=1.0, + help="The Huber loss scale parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 1.0 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", ) parser.add_argument( @@ -5821,29 +5828,10 @@ def save_sd_model_on_train_end_common( huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) -def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device): - timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") - - if args.loss_type == "huber" or args.loss_type == "smooth_l1": - if args.huber_schedule == "exponential": - alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps - huber_c = torch.exp(-alpha * timesteps) - elif args.huber_schedule == "snr": - alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps) - sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 - huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c - elif args.huber_schedule == "constant": - huber_c = torch.full((b_size,), args.huber_c) - else: - raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") - huber_c = huber_c.to(device) - elif args.loss_type == "l2": - huber_c = None # may be anything, as it's not used - else: - raise NotImplementedError(f"Unknown loss type {args.loss_type}") - - timesteps = timesteps.long().to(device) - return timesteps, huber_c +def get_timesteps(min_timestep, max_timestep, b_size, device): + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device) + timesteps = timesteps.long() + return timesteps def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): @@ -5865,7 +5853,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): min_timestep = 0 if args.min_timestep is None else args.min_timestep max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep - timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device) + timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device) # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) @@ -5878,24 +5866,46 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): else: noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - return noise, noisy_latents, timesteps, huber_c + return noise, noisy_latents, timesteps + + +def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch.Tensor: + b_size = timesteps.shape[0] + if args.huber_schedule == "exponential": + alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps + result = torch.exp(-alpha * timesteps) * args.huber_scale + elif args.huber_schedule == "snr": + if not hasattr(noise_scheduler, 'alphas_cumprod'): + raise NotImplementedError(f"Huber schedule 'snr' is not supported with the current model.") + alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) + sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 + result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c + result = result.to(timesteps.device) + elif args.huber_schedule == "constant": + result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device) + else: + raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") + + return result def conditional_loss( - model_pred: torch.Tensor, target: torch.Tensor, reduction: str, loss_type: str, huber_c: Optional[torch.Tensor] + args, model_pred: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor, reduction: str, noise_scheduler ): - if loss_type == "l2": + if args.loss_type == "l2": loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction) - elif loss_type == "l1": + elif args.loss_type == "l1": loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction) - elif loss_type == "huber": + elif args.loss_type == "huber": + huber_c = get_huber_threshold(args, timesteps, noise_scheduler) huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": loss = torch.mean(loss) elif reduction == "sum": loss = torch.sum(loss) - elif loss_type == "smooth_l1": + elif args.loss_type == "smooth_l1": + huber_c = get_huber_threshold(args, timesteps, noise_scheduler) huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": @@ -5903,7 +5913,7 @@ def conditional_loss( elif reduction == "sum": loss = torch.sum(loss) else: - raise NotImplementedError(f"Unsupported Loss Type {loss_type}") + raise NotImplementedError(f"Unsupported Loss Type: {args.loss_type}") return loss diff --git a/sd3_train.py b/sd3_train.py index 96ec951b9..cf2bdf938 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -845,7 +845,7 @@ def grad_hook(parameter: torch.Tensor): # ) # calculate loss loss = train_util.conditional_loss( - model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None + args, model_pred.float(), target.float(), timesteps, "none", noise_scheduler ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) diff --git a/sd3_train_network.py b/sd3_train_network.py index 1726e325f..fb7711bda 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -378,7 +378,7 @@ def get_noise_pred_and_target( target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) - return model_pred, target, timesteps, None, weighting + return model_pred, target, timesteps, weighting def post_process_loss(self, loss, args, timesteps, noise_scheduler): return loss diff --git a/sdxl_train.py b/sdxl_train.py index e26f4aa19..1bc27ec6c 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -695,7 +695,7 @@ def optimizer_hook(parameter: torch.Tensor): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( args, noise_scheduler, latents ) @@ -720,7 +720,7 @@ def optimizer_hook(parameter: torch.Tensor): ): # do not mean over batch dimension for snr weight or scale v-pred loss loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) @@ -738,7 +738,7 @@ def optimizer_hook(parameter: torch.Tensor): loss = loss.mean() # mean over batch dimension else: loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) accelerator.backward(loss) diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index 24080afbd..d0051d18f 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -512,7 +512,7 @@ def remove_model(old_ckpt_name): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( args, noise_scheduler, latents ) @@ -534,7 +534,7 @@ def remove_model(old_ckpt_name): target = noise loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) loss = loss.mean([1, 2, 3]) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 2946c97d4..66214f5df 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -463,7 +463,7 @@ def remove_model(old_ckpt_name): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( args, noise_scheduler, latents ) @@ -485,7 +485,7 @@ def remove_model(old_ckpt_name): target = noise loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) loss = loss.mean([1, 2, 3]) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 2d4465234..5e10654b9 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -406,7 +406,7 @@ def remove_model(old_ckpt_name): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -426,7 +426,9 @@ def remove_model(old_ckpt_name): else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler + ) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/train_controlnet.py b/train_controlnet.py index 8c7882c8f..da7a08d69 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -464,8 +464,8 @@ def remove_model(old_ckpt_name): ) # Sample a random timestep for each image - timesteps, huber_c = train_util.get_timesteps_and_huber_c( - args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device + timesteps = train_util.get_timesteps( + 0, noise_scheduler.config.num_train_timesteps, b_size, latents.device ) # Add noise to the latents according to the noise magnitude at each timestep @@ -499,7 +499,7 @@ def remove_model(old_ckpt_name): target = noise loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) loss = loss.mean([1, 2, 3]) diff --git a/train_db.py b/train_db.py index 51e209f34..a185b31b3 100644 --- a/train_db.py +++ b/train_db.py @@ -370,7 +370,7 @@ def train(args): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( args, noise_scheduler, latents ) @@ -385,7 +385,7 @@ def train(args): target = noise loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) diff --git a/train_network.py b/train_network.py index bbf381f99..c7d4f5dc5 100644 --- a/train_network.py +++ b/train_network.py @@ -192,7 +192,7 @@ def get_noise_pred_and_target( ): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) # ensure the hidden state will require grad if args.gradient_checkpointing: @@ -244,7 +244,7 @@ def get_noise_pred_and_target( network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) - return noise_pred, target, timesteps, huber_c, None + return noise_pred, target, timesteps, None def post_process_loss(self, loss, args, timesteps, noise_scheduler): if args.min_snr_gamma: @@ -806,6 +806,7 @@ def load_model_hook(models, input_dir): "ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength, "ss_loss_type": args.loss_type, "ss_huber_schedule": args.huber_schedule, + "ss_huber_scale": args.huber_scale, "ss_huber_c": args.huber_c, "ss_fp8_base": bool(args.fp8_base), "ss_fp8_base_unet": bool(args.fp8_base_unet), @@ -1193,7 +1194,7 @@ def remove_model(old_ckpt_name): text_encoder_conds[i] = encoded_text_encoder_conds[i] # sample noise, call unet, get target - noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target( + noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target( args, accelerator, noise_scheduler, @@ -1207,7 +1208,7 @@ def remove_model(old_ckpt_name): ) loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) if weighting is not None: loss = loss * weighting diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 5f4657eb9..9e1e57c48 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -585,7 +585,7 @@ def remove_model(old_ckpt_name): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( args, noise_scheduler, latents ) @@ -602,7 +602,7 @@ def remove_model(old_ckpt_name): target = noise loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 52d525fc5..944733602 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -461,7 +461,7 @@ def remove_model(old_ckpt_name): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) # Predict the noise residual with accelerator.autocast(): @@ -473,7 +473,9 @@ def remove_model(old_ckpt_name): else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler + ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) From 740ec1d5265fa321659589ae6a75a4a9898ef8be Mon Sep 17 00:00:00 2001 From: recris Date: Thu, 28 Nov 2024 20:38:32 +0000 Subject: [PATCH 248/348] Fix issues found in review --- fine_tune.py | 2 +- library/train_util.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 70959a751..401a40f08 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -411,7 +411,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): loss = loss.mean() # mean over batch dimension else: loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler + args, noise_pred.float(), target.float(), timesteps, "mean", noise_scheduler ) accelerator.backward(loss) diff --git a/library/train_util.py b/library/train_util.py index c204ebd38..eaf6ec004 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5829,8 +5829,8 @@ def save_sd_model_on_train_end_common( def get_timesteps(min_timestep, max_timestep, b_size, device): - timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device) - timesteps = timesteps.long() + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") + timesteps = timesteps.long().to(device) return timesteps @@ -5875,8 +5875,8 @@ def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps result = torch.exp(-alpha * timesteps) * args.huber_scale elif args.huber_schedule == "snr": - if not hasattr(noise_scheduler, 'alphas_cumprod'): - raise NotImplementedError(f"Huber schedule 'snr' is not supported with the current model.") + if not hasattr(noise_scheduler, "alphas_cumprod"): + raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c From 575f583fd9cbaf7f7b644a31437ed9094810b99a Mon Sep 17 00:00:00 2001 From: minux302 Date: Fri, 29 Nov 2024 23:55:52 +0900 Subject: [PATCH 249/348] add README --- README.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/README.md b/README.md index f9c85e3ac..2b1ca3f8c 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ Nov 14, 2024: - [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) - [Specify rank for each layer in FLUX.1](#specify-rank-for-each-layer-in-flux1) - [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training) +- [FLUX.1 ControlNet training](#flux1-controlnet-training) - [FLUX.1 OFT training](#flux1-oft-training) - [Inference for FLUX.1 with LoRA model](#inference-for-flux1-with-lora-model) - [FLUX.1 fine-tuning](#flux1-fine-tuning) @@ -245,6 +246,22 @@ example: If you specify one of `train_double_block_indices` or `train_single_block_indices`, the other will be trained as usual. +### FLUX.1 ControlNet training +We have added a new training script for ControlNet training. The script is flux_train_control_net.py. See --help for options. + +Sample command is below. It will work with 80GB VRAM GPUs. +``` +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_control_net.py +--pretrained_model_name_or_path flux1-dev.safetensors --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors +--ae ae.safetensors --save_model_as safetensors --sdpa --persistent_data_loader_workers +--max_data_loader_n_workers 1 --seed 42 --gradient_checkpointing --mixed_precision bf16 +--optimizer_type adamw8bit --learning_rate 2e-5 +--highvram --max_train_epochs 1 --save_every_n_steps 1000 --dataset_config dataset.toml +--output_dir /path/to/output/dir --output_name flux-cn +--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 --deepspeed +``` + + ### FLUX.1 OFT training You can train OFT with almost the same options as LoRA, such as `--timestamp_sampling`. The following points are different. From be5860f8e266c5562f123fe9e0cb3febef615290 Mon Sep 17 00:00:00 2001 From: minux302 Date: Sat, 30 Nov 2024 00:08:21 +0900 Subject: [PATCH 250/348] add schnell option to load_cn --- flux_train_control_net.py | 4 ++-- library/flux_utils.py | 14 ++++++-------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index a17c811e3..bb27c35ed 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -259,14 +259,14 @@ def train(args): clean_memory_on_device(accelerator.device) # load FLUX - _, flux = flux_utils.load_flow_model( + is_schnell, flux = flux_utils.load_flow_model( args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors ) flux.requires_grad_(False) flux.to(accelerator.device) # load controlnet - controlnet = flux_utils.load_controlnet(args.controlnet, torch.float32, accelerator.device, args.disable_mmap_load_safetensors) + controlnet = flux_utils.load_controlnet(args.controlnet, is_schnell, torch.float32, accelerator.device, args.disable_mmap_load_safetensors) controlnet.train() if args.gradient_checkpointing: diff --git a/library/flux_utils.py b/library/flux_utils.py index f2759c375..8be1d63ee 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -1,14 +1,14 @@ -from dataclasses import replace import json import os +from dataclasses import replace from typing import List, Optional, Tuple, Union + import einops import torch - -from safetensors.torch import load_file -from safetensors import safe_open from accelerate import init_empty_weights -from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config +from safetensors import safe_open +from safetensors.torch import load_file +from transformers import CLIPConfig, CLIPTextModel, T5Config, T5EncoderModel from library.utils import setup_logging @@ -154,11 +154,9 @@ def load_ae( def load_controlnet( - ckpt_path: Optional[str], dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False + ckpt_path: Optional[str], is_schnell: bool, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False ): logger.info("Building ControlNet") - # is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path) - is_schnell = False name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL with torch.device(device): controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params).to(dtype) From f40632bac6704886a7640c327d64820f8f017df8 Mon Sep 17 00:00:00 2001 From: minux302 Date: Sat, 30 Nov 2024 00:15:47 +0900 Subject: [PATCH 251/348] rm abundant arg --- flux_train_network.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index 314335366..fa3810e34 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -6,12 +6,21 @@ import torch from accelerate import Accelerator -from library.device_utils import init_ipex, clean_memory_on_device + +from library.device_utils import clean_memory_on_device, init_ipex init_ipex() -from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util import train_network +from library import ( + flux_models, + flux_train_utils, + flux_utils, + sd3_train_utils, + strategy_base, + strategy_flux, + train_util, +) from library.utils import setup_logging setup_logging() @@ -125,7 +134,7 @@ def load_target_model(self, args, weight_dtype, accelerator): ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) - return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model, controlnet + return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model def get_tokenize_strategy(self, args): _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) From 928b9393daac252d0b6c4c9dd277d549b3dad8e9 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 20 Nov 2024 11:15:30 -0500 Subject: [PATCH 252/348] Allow unknown schedule-free optimizers to continue to module loader --- library/train_util.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 25cf7640d..74050880a 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4600,7 +4600,7 @@ def task(): def get_optimizer(args, trainable_params): # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" - + optimizer_type = args.optimizer_type if args.use_8bit_adam: assert ( @@ -4874,6 +4874,7 @@ def get_optimizer(args, trainable_params): optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type.endswith("schedulefree".lower()): + should_train_optimizer = True try: import schedulefree as sf except ImportError: @@ -4885,10 +4886,10 @@ def get_optimizer(args, trainable_params): optimizer_class = sf.SGDScheduleFree logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}") else: - raise ValueError(f"Unknown optimizer type: {optimizer_type}") - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - # make optimizer as train mode: we don't need to call train again, because eval will not be called in training loop - optimizer.train() + optimizer_class = None + + if optimizer_class is not None: + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) if optimizer is None: # 任意のoptimizerを使う @@ -4990,6 +4991,10 @@ def __instancecheck__(self, instance): optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()]) + if hasattr(optimizer, 'train') and callable(optimizer.train): + # make optimizer as train mode: we don't need to call train again, because eval will not be called in training loop + optimizer.train() + return optimizer_name, optimizer_args, optimizer From 87f5224e2d19254748158939cbca75802fc024f2 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 20 Nov 2024 11:57:15 -0500 Subject: [PATCH 253/348] Support d*lr for ProdigyPlus optimizer --- train_network.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index bbf381f99..65962bd74 100644 --- a/train_network.py +++ b/train_network.py @@ -61,6 +61,7 @@ def generate_step_logs( avr_loss, lr_scheduler, lr_descriptions, + optimizer=None, keys_scaled=None, mean_norm=None, maximum_norm=None, @@ -93,6 +94,30 @@ def generate_step_logs( logs[f"lr/d*lr/{lr_desc}"] = ( lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] ) + if ( + args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None + ): # tracking d*lr value of unet. + logs["lr/d*lr"] = ( + optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] + ) + else: + idx = 0 + if not args.network_train_unet_only: + logs["lr/textencoder"] = float(lrs[0]) + idx = 1 + + for i in range(idx, len(lrs)): + logs[f"lr/group{i}"] = float(lrs[i]) + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): + logs[f"lr/d*lr/group{i}"] = ( + lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] + ) + if ( + args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None + ): + logs[f"lr/d*lr/group{i}"] = ( + optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] + ) return logs @@ -1279,7 +1304,7 @@ def remove_model(old_ckpt_name): if len(accelerator.trackers) > 0: logs = self.generate_step_logs( - args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm + args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm ) accelerator.log(logs, step=global_step) From 6593cfbec14c0be70407b5d6d85d569ecf8160f1 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 21 Nov 2024 14:41:37 -0500 Subject: [PATCH 254/348] Fix d * lr step log --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 65962bd74..c236a2c95 100644 --- a/train_network.py +++ b/train_network.py @@ -116,7 +116,7 @@ def generate_step_logs( args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None ): logs[f"lr/d*lr/group{i}"] = ( - optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] + optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] ) return logs From c7cadbc8c73b48eaacbfb44b18121d20df373e19 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 29 Nov 2024 15:52:03 -0500 Subject: [PATCH 255/348] Add pytest testing --- .github/workflows/tests.yml | 54 +++++++++++++ library/train_util.py | 4 +- pytest.ini | 7 ++ tests/test_optimizer.py | 153 ++++++++++++++++++++++++++++++++++++ 4 files changed, 216 insertions(+), 2 deletions(-) create mode 100644 .github/workflows/tests.yml create mode 100644 pytest.ini create mode 100644 tests/test_optimizer.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 000000000..50b08243a --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,54 @@ + +name: Python package + +on: [push] + +jobs: + build: + + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: ["3.10", "3.11"] + + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.x' + - name: Install dependencies + run: python -m pip install --upgrade pip setuptools wheel + + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.x' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.x' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + - name: Test with pytest + run: | + pip install pytest pytest-cov + pytest --junitxml=junit/test-results.xml --cov=com --cov-report=xml --cov-report=html + + - name: Upload pytest test results + uses: actions/upload-artifact@v4 + with: + name: pytest-results-${{ matrix.python-version }} + path: junit/test-results-${{ matrix.python-version }}.xml + # Use always() to always run this step to publish test results when there are test failures + if: ${{ always() }} diff --git a/library/train_util.py b/library/train_util.py index 25cf7640d..823cd3663 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -21,7 +21,7 @@ Optional, Sequence, Tuple, - Union, + Union ) from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState import glob @@ -4598,7 +4598,7 @@ def task(): accelerator.load_state(dirname) -def get_optimizer(args, trainable_params): +def get_optimizer(args, trainable_params) -> tuple[str, str, object]: # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" optimizer_type = args.optimizer_type diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 000000000..63e03efc5 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,7 @@ +[pytest] +minversion = 6.0 +testpaths = + tests +filterwarnings = + ignore::DeprecationWarning + ignore::UserWarning diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py new file mode 100644 index 000000000..f6ade91a6 --- /dev/null +++ b/tests/test_optimizer.py @@ -0,0 +1,153 @@ +from unittest.mock import patch +from library.train_util import get_optimizer +from train_network import setup_parser +import torch +from torch.nn import Parameter + +# Optimizer libraries +import bitsandbytes as bnb +from lion_pytorch import lion_pytorch +import schedulefree + +import dadaptation +import dadaptation.experimental as dadapt_experimental + +import prodigyopt +import schedulefree as sf +import transformers + + +def test_default_get_optimizer(): + with patch("sys.argv", [""]): + parser = setup_parser() + args = parser.parse_args() + params_t = torch.tensor([1.5, 1.5]) + + param = Parameter(params_t) + optimizer_name, optimizer_args, optimizer = get_optimizer(args, [param]) + assert optimizer_name == "torch.optim.adamw.AdamW" + assert optimizer_args == "" + assert isinstance(optimizer, torch.optim.AdamW) + + +def test_get_schedulefree_optimizer(): + with patch("sys.argv", ["", "--optimizer_type", "AdamWScheduleFree"]): + parser = setup_parser() + args = parser.parse_args() + params_t = torch.tensor([1.5, 1.5]) + + param = Parameter(params_t) + optimizer_name, optimizer_args, optimizer = get_optimizer(args, [param]) + assert optimizer_name == "schedulefree.adamw_schedulefree.AdamWScheduleFree" + assert optimizer_args == "" + assert isinstance(optimizer, schedulefree.adamw_schedulefree.AdamWScheduleFree) + + +def test_all_supported_optimizers(): + optimizers = [ + { + "name": "bitsandbytes.optim.adamw.AdamW8bit", + "alias": "AdamW8bit", + "instance": bnb.optim.AdamW8bit, + }, + { + "name": "lion_pytorch.lion_pytorch.Lion", + "alias": "Lion", + "instance": lion_pytorch.Lion, + }, + { + "name": "torch.optim.adamw.AdamW", + "alias": "AdamW", + "instance": torch.optim.AdamW, + }, + { + "name": "bitsandbytes.optim.lion.Lion8bit", + "alias": "Lion8bit", + "instance": bnb.optim.Lion8bit, + }, + { + "name": "bitsandbytes.optim.adamw.PagedAdamW8bit", + "alias": "PagedAdamW8bit", + "instance": bnb.optim.PagedAdamW8bit, + }, + { + "name": "bitsandbytes.optim.lion.PagedLion8bit", + "alias": "PagedLion8bit", + "instance": bnb.optim.PagedLion8bit, + }, + { + "name": "bitsandbytes.optim.adamw.PagedAdamW", + "alias": "PagedAdamW", + "instance": bnb.optim.PagedAdamW, + }, + { + "name": "bitsandbytes.optim.adamw.PagedAdamW32bit", + "alias": "PagedAdamW32bit", + "instance": bnb.optim.PagedAdamW32bit, + }, + {"name": "torch.optim.sgd.SGD", "alias": "SGD", "instance": torch.optim.SGD}, + { + "name": "dadaptation.experimental.dadapt_adam_preprint.DAdaptAdamPreprint", + "alias": "DAdaptAdamPreprint", + "instance": dadapt_experimental.DAdaptAdamPreprint, + }, + { + "name": "dadaptation.dadapt_adagrad.DAdaptAdaGrad", + "alias": "DAdaptAdaGrad", + "instance": dadaptation.DAdaptAdaGrad, + }, + { + "name": "dadaptation.dadapt_adan.DAdaptAdan", + "alias": "DAdaptAdan", + "instance": dadaptation.DAdaptAdan, + }, + { + "name": "dadaptation.experimental.dadapt_adan_ip.DAdaptAdanIP", + "alias": "DAdaptAdanIP", + "instance": dadapt_experimental.DAdaptAdanIP, + }, + { + "name": "dadaptation.dadapt_lion.DAdaptLion", + "alias": "DAdaptLion", + "instance": dadaptation.DAdaptLion, + }, + { + "name": "dadaptation.dadapt_sgd.DAdaptSGD", + "alias": "DAdaptSGD", + "instance": dadaptation.DAdaptSGD, + }, + { + "name": "prodigyopt.prodigy.Prodigy", + "alias": "Prodigy", + "instance": prodigyopt.Prodigy, + }, + { + "name": "transformers.optimization.Adafactor", + "alias": "Adafactor", + "instance": transformers.optimization.Adafactor, + }, + { + "name": "schedulefree.adamw_schedulefree.AdamWScheduleFree", + "alias": "AdamWScheduleFree", + "instance": sf.AdamWScheduleFree, + }, + { + "name": "schedulefree.sgd_schedulefree.SGDScheduleFree", + "alias": "SGDScheduleFree", + "instance": sf.SGDScheduleFree, + }, + ] + + for opt in optimizers: + with patch("sys.argv", ["", "--optimizer_type", opt.get("alias")]): + parser = setup_parser() + args = parser.parse_args() + params_t = torch.tensor([1.5, 1.5]) + + param = Parameter(params_t) + optimizer_name, _, optimizer = get_optimizer(args, [param]) + assert optimizer_name == opt.get("name") + + instance = opt.get("instance") + assert instance is not None + assert isinstance(optimizer, instance) From 2dd063a679effae2538c474fece1e7aacad0c9c5 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 29 Nov 2024 15:57:31 -0500 Subject: [PATCH 256/348] add torch torchvision accelerate versions --- .github/workflows/tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 50b08243a..96ab612d8 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -40,6 +40,7 @@ jobs: run: | python -m pip install --upgrade pip pip install -r requirements.txt + pip install torch==2.4.0 torchvision==0.19.0 accelerate==0.33.0 - name: Test with pytest run: | pip install pytest pytest-cov From e59e276fb948a1dc8a64672d8fd6d3a7eb166c80 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 29 Nov 2024 16:03:29 -0500 Subject: [PATCH 257/348] Add dadaptation --- .github/workflows/tests.yml | 26 +++++--------------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 96ab612d8..433c326bf 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -10,7 +10,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: ["3.10", "3.11"] + python-version: ["3.10"] steps: - uses: actions/checkout@v4 @@ -26,30 +26,14 @@ jobs: uses: actions/setup-python@v5 with: python-version: '3.x' + cache: 'pip' # caching pip dependencies - name: Install dependencies run: | python -m pip install --upgrade pip pip install -r requirements.txt - - - uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: '3.x' - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -r requirements.txt - pip install torch==2.4.0 torchvision==0.19.0 accelerate==0.33.0 + pip install dadaptation==3.2 torch==2.4.0 torchvision==0.19.0 accelerate==0.33.0 - name: Test with pytest run: | - pip install pytest pytest-cov - pytest --junitxml=junit/test-results.xml --cov=com --cov-report=xml --cov-report=html + pip install pytest + pytest - - name: Upload pytest test results - uses: actions/upload-artifact@v4 - with: - name: pytest-results-${{ matrix.python-version }} - path: junit/test-results-${{ matrix.python-version }}.xml - # Use always() to always run this step to publish test results when there are test failures - if: ${{ always() }} From dd3b846b54814b605bd33ae08ed480ea5075483b Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 29 Nov 2024 16:18:05 -0500 Subject: [PATCH 258/348] Install pytorch first to pin version --- .github/workflows/tests.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 433c326bf..9ae67b0e9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -18,6 +18,7 @@ jobs: uses: actions/setup-python@v5 with: python-version: '3.x' + - name: Install dependencies run: python -m pip install --upgrade pip setuptools wheel @@ -27,11 +28,13 @@ jobs: with: python-version: '3.x' cache: 'pip' # caching pip dependencies + - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt pip install dadaptation==3.2 torch==2.4.0 torchvision==0.19.0 accelerate==0.33.0 + pip install -r requirements.txt + - name: Test with pytest run: | pip install pytest From 89825d6898ba6629b18cc8c1f9fbd93a730ff36e Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 29 Nov 2024 16:27:13 -0500 Subject: [PATCH 259/348] Run typos workflows once where appropriate --- .github/workflows/typos.yml | 6 ++++-- pytest.ini | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml index 0149dcdd3..667146a7a 100644 --- a/.github/workflows/typos.yml +++ b/.github/workflows/typos.yml @@ -1,9 +1,11 @@ --- -# yamllint disable rule:line-length name: Typos -on: # yamllint disable-line rule:truthy +on: push: + branches: + - main + - dev pull_request: types: - opened diff --git a/pytest.ini b/pytest.ini index 63e03efc5..484d3aef6 100644 --- a/pytest.ini +++ b/pytest.ini @@ -5,3 +5,4 @@ testpaths = filterwarnings = ignore::DeprecationWarning ignore::UserWarning + ignore::FutureWarning From 4f7f248071c93f539c12c8a35380b6d983bfff4c Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 29 Nov 2024 16:28:51 -0500 Subject: [PATCH 260/348] Bump typos action --- .github/workflows/typos.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml index 667146a7a..87ebdf894 100644 --- a/.github/workflows/typos.yml +++ b/.github/workflows/typos.yml @@ -20,4 +20,4 @@ jobs: - uses: actions/checkout@v4 - name: typos-action - uses: crate-ci/typos@v1.24.3 + uses: crate-ci/typos@v1.28.1 From 9c885e549dbb5535b37f2a3220b5a8f53ad4d211 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 30 Nov 2024 18:25:50 +0900 Subject: [PATCH 261/348] fix: improve pos_embed handling for oversized images and update resolution_area_to_latent_size, when sample image size > train image size --- library/sd3_models.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index 8b90205db..2f3c82eed 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -1017,22 +1017,35 @@ def cropped_scaled_pos_embed(self, h, w, device=None, dtype=None, random_crop: b patched_size = patched_size_ break if patched_size is None: - raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.") + # raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.") + # use largest latent size + patched_size = self.resolution_area_to_latent_size[-1][1] pos_embed = self.resolution_pos_embeds[patched_size] - pos_embed_size = round(math.sqrt(pos_embed.shape[1])) + pos_embed_size = round(math.sqrt(pos_embed.shape[1])) # max size, patched_size * POS_EMBED_MAX_RATIO if h > pos_embed_size or w > pos_embed_size: # # fallback to normal pos_embed # return self.cropped_pos_embed(h * p, w * p, device=device, random_crop=random_crop) # extend pos_embed size logger.warning( - f"Using normal pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide." + f"Add new pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide." ) - pos_embed_size = max(h, w) - pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, pos_embed_size, sample_size=patched_size) + patched_size = max(h, w) + grid_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO) + pos_embed_size = grid_size + pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, grid_size, sample_size=patched_size) pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0) self.resolution_pos_embeds[patched_size] = pos_embed - logger.info(f"Updated pos_embed for size {pos_embed_size}x{pos_embed_size}") + logger.info(f"Added pos_embed for size {patched_size}x{patched_size}") + + # print(torch.allclose(pos_embed.to(torch.float32).cpu(), self.pos_embed.to(torch.float32).cpu(), atol=5e-2)) + # diff = pos_embed.to(torch.float32).cpu() - self.pos_embed.to(torch.float32).cpu() + # print(diff.abs().max(), diff.abs().mean()) + + # insert to resolution_area_to_latent_size, by adding and sorting + area = pos_embed_size**2 + self.resolution_area_to_latent_size.append((area, patched_size)) + self.resolution_area_to_latent_size = sorted(self.resolution_area_to_latent_size) if not random_crop: top = (pos_embed_size - h) // 2 From 7b61e9eb58e0a004b451e8f06c9f90b861f81b45 Mon Sep 17 00:00:00 2001 From: recris Date: Sat, 30 Nov 2024 11:36:40 +0000 Subject: [PATCH 262/348] Fix issues found in review (pt 2) --- library/train_util.py | 2 +- sd3_train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index eaf6ec004..d5e72323a 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5875,7 +5875,7 @@ def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps result = torch.exp(-alpha * timesteps) * args.huber_scale elif args.huber_schedule == "snr": - if not hasattr(noise_scheduler, "alphas_cumprod"): + if noise_scheduler is None or not hasattr(noise_scheduler, "alphas_cumprod"): raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 diff --git a/sd3_train.py b/sd3_train.py index cf2bdf938..909c5ead6 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -845,7 +845,7 @@ def grad_hook(parameter: torch.Tensor): # ) # calculate loss loss = train_util.conditional_loss( - args, model_pred.float(), target.float(), timesteps, "none", noise_scheduler + args, model_pred.float(), target.float(), timesteps, "none", None ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) From 14f642f88be888ce1a4157b550186347c159ca42 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 1 Dec 2024 13:30:35 +0900 Subject: [PATCH 263/348] fix: huber_schedule exponential not working on sd3_train.py --- library/train_util.py | 2 +- sd3_train.py | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index d5e72323a..eaf6ec004 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5875,7 +5875,7 @@ def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps result = torch.exp(-alpha * timesteps) * args.huber_scale elif args.huber_schedule == "snr": - if noise_scheduler is None or not hasattr(noise_scheduler, "alphas_cumprod"): + if not hasattr(noise_scheduler, "alphas_cumprod"): raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 diff --git a/sd3_train.py b/sd3_train.py index 909c5ead6..73a68aa6a 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -675,8 +675,8 @@ def grad_hook(parameter: torch.Tensor): progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 - # noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) - # noise_scheduler_copy = copy.deepcopy(noise_scheduler) + # only used to get timesteps, etc. TODO manage timesteps etc. separately + dummy_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) if accelerator.is_main_process: init_kwargs = {} @@ -844,9 +844,7 @@ def grad_hook(parameter: torch.Tensor): # 1, # ) # calculate loss - loss = train_util.conditional_loss( - args, model_pred.float(), target.float(), timesteps, "none", None - ) + loss = train_util.conditional_loss(args, model_pred.float(), target.float(), timesteps, "none", dummy_scheduler) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) From 0fe6320f09a61859c3faa134affb810cb42b62cd Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 1 Dec 2024 14:13:37 +0900 Subject: [PATCH 264/348] fix flux_train.py is not working --- flux_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flux_train.py b/flux_train.py index f6e43b27a..cfe14885e 100644 --- a/flux_train.py +++ b/flux_train.py @@ -667,7 +667,7 @@ def grad_hook(parameter: torch.Tensor): # calculate loss loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler + args, model_pred.float(), target.float(), timesteps, "none", noise_scheduler ) if weighting is not None: loss = loss * weighting From cc11989755d0dd61f10eeec85983c751fd7ebb47 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 1 Dec 2024 21:20:28 +0900 Subject: [PATCH 265/348] fix: refactor huber-loss calculation in multiple training scripts --- fine_tune.py | 13 ++++--------- flux_train.py | 5 ++--- library/train_util.py | 21 +++++++++++---------- sd3_train.py | 3 ++- sdxl_train.py | 13 ++++--------- sdxl_train_control_net.py | 9 +++------ sdxl_train_control_net_lllite.py | 9 +++------ sdxl_train_control_net_lllite_old.py | 10 ++++++---- train_controlnet.py | 11 +++++------ train_db.py | 9 +++------ train_network.py | 5 ++--- train_textual_inversion.py | 5 ++--- train_textual_inversion_XTI.py | 9 +++++---- 13 files changed, 52 insertions(+), 70 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 401a40f08..176087065 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -380,9 +380,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) # Predict the noise residual with accelerator.autocast(): @@ -394,11 +392,10 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): else: target = noise + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss: # do not mean over batch dimension for snr weight or scale v-pred loss - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: @@ -410,9 +407,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): loss = loss.mean() # mean over batch dimension else: - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "mean", noise_scheduler - ) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "mean", huber_c) accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: diff --git a/flux_train.py b/flux_train.py index cfe14885e..fced3bef9 100644 --- a/flux_train.py +++ b/flux_train.py @@ -666,9 +666,8 @@ def grad_hook(parameter: torch.Tensor): target = noise - latents # calculate loss - loss = train_util.conditional_loss( - args, model_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c) if weighting is not None: loss = loss * weighting if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): diff --git a/library/train_util.py b/library/train_util.py index eaf6ec004..fe74ddc7e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5869,7 +5869,10 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): return noise, noisy_latents, timesteps -def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch.Tensor: +def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]: + if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"): + return None + b_size = timesteps.shape[0] if args.huber_schedule == "exponential": alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps @@ -5890,22 +5893,20 @@ def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch def conditional_loss( - args, model_pred: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor, reduction: str, noise_scheduler + model_pred: torch.Tensor, target: torch.Tensor, loss_type: str, reduction: str, huber_c: Optional[torch.Tensor] = None ): - if args.loss_type == "l2": + if loss_type == "l2": loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction) - elif args.loss_type == "l1": + elif loss_type == "l1": loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction) - elif args.loss_type == "huber": - huber_c = get_huber_threshold(args, timesteps, noise_scheduler) + elif loss_type == "huber": huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": loss = torch.mean(loss) elif reduction == "sum": loss = torch.sum(loss) - elif args.loss_type == "smooth_l1": - huber_c = get_huber_threshold(args, timesteps, noise_scheduler) + elif loss_type == "smooth_l1": huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": @@ -5913,7 +5914,7 @@ def conditional_loss( elif reduction == "sum": loss = torch.sum(loss) else: - raise NotImplementedError(f"Unsupported Loss Type: {args.loss_type}") + raise NotImplementedError(f"Unsupported Loss Type: {loss_type}") return loss @@ -5923,7 +5924,7 @@ def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True): names.append("unet") names.append("text_encoder1") names.append("text_encoder2") - names.append("text_encoder3") # SD3 + names.append("text_encoder3") # SD3 append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names) diff --git a/sd3_train.py b/sd3_train.py index 73a68aa6a..120455e7b 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -844,7 +844,8 @@ def grad_hook(parameter: torch.Tensor): # 1, # ) # calculate loss - loss = train_util.conditional_loss(args, model_pred.float(), target.float(), timesteps, "none", dummy_scheduler) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, dummy_scheduler) + loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) diff --git a/sdxl_train.py b/sdxl_train.py index 1bc27ec6c..b9d529243 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -695,9 +695,7 @@ def optimizer_hook(parameter: torch.Tensor): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -711,6 +709,7 @@ def optimizer_hook(parameter: torch.Tensor): else: target = noise + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) if ( args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred @@ -719,9 +718,7 @@ def optimizer_hook(parameter: torch.Tensor): or args.masked_loss ): # do not mean over batch dimension for snr weight or scale v-pred loss - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) @@ -737,9 +734,7 @@ def optimizer_hook(parameter: torch.Tensor): loss = loss.mean() # mean over batch dimension else: - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "mean", huber_c) accelerator.backward(loss) diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index d0051d18f..01387409a 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -512,9 +512,7 @@ def remove_model(old_ckpt_name): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) @@ -533,9 +531,8 @@ def remove_model(old_ckpt_name): else: target = noise - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 66214f5df..365059b75 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -463,9 +463,7 @@ def remove_model(old_ckpt_name): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -484,9 +482,8 @@ def remove_model(old_ckpt_name): else: target = noise - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 5e10654b9..5b372befc 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -12,6 +12,7 @@ import torch from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP @@ -324,7 +325,9 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs + "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) loss_recorder = train_util.LossRecorder() @@ -426,9 +429,8 @@ def remove_model(old_ckpt_name): else: target = noise - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/train_controlnet.py b/train_controlnet.py index da7a08d69..177d2b11f 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -307,10 +307,12 @@ def __contains__(self, name): if args.fused_backward_pass: import library.adafactor_fused + library.adafactor_fused.patch_adafactor_fused(optimizer) for param_group in optimizer.param_groups: for parameter in param_group["params"]: if parameter.requires_grad: + def __grad_hook(tensor: torch.Tensor, param_group=param_group): if accelerator.sync_gradients and args.max_grad_norm != 0.0: accelerator.clip_grad_norm_(tensor, args.max_grad_norm) @@ -464,9 +466,7 @@ def remove_model(old_ckpt_name): ) # Sample a random timestep for each image - timesteps = train_util.get_timesteps( - 0, noise_scheduler.config.num_train_timesteps, b_size, latents.device - ) + timesteps = train_util.get_timesteps(0, noise_scheduler.config.num_train_timesteps, b_size, latents.device) # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) @@ -498,9 +498,8 @@ def remove_model(old_ckpt_name): else: target = noise - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/train_db.py b/train_db.py index a185b31b3..ad21f8d1b 100644 --- a/train_db.py +++ b/train_db.py @@ -370,9 +370,7 @@ def train(args): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) # Predict the noise residual with accelerator.autocast(): @@ -384,9 +382,8 @@ def train(args): else: target = noise - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) diff --git a/train_network.py b/train_network.py index c7d4f5dc5..0b4208187 100644 --- a/train_network.py +++ b/train_network.py @@ -1207,9 +1207,8 @@ def remove_model(old_ckpt_name): train_unet, ) - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) if weighting is not None: loss = loss * weighting if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 9e1e57c48..65da4859b 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -601,9 +601,8 @@ def remove_model(old_ckpt_name): else: target = noise - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 944733602..2a2b42310 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -407,7 +407,9 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs + "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) # function for saving/removing @@ -473,9 +475,8 @@ def remove_model(old_ckpt_name): else: target = noise - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) From 14760407871c7eaa26210c7db71ce2740a817c4c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 1 Dec 2024 21:26:39 +0900 Subject: [PATCH 266/348] fix: update help text for huber loss parameters in train_util.py --- library/train_util.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index fe74ddc7e..a40983a68 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3905,14 +3905,16 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: "--huber_c", type=float, default=0.1, - help="The Huber loss decay parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", + help="The Huber loss decay parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1" + " / Huber損失の減衰パラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", ) parser.add_argument( "--huber_scale", type=float, default=1.0, - help="The Huber loss scale parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 1.0 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", + help="The Huber loss scale parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 1.0" + " / Huber損失のスケールパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは1.0", ) parser.add_argument( From 34e7f509c41491f9a08c16c8ead2adf5cb210ec1 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 1 Dec 2024 21:36:24 +0900 Subject: [PATCH 267/348] docs: update README for huber loss --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index f9c85e3ac..89a96827c 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,11 @@ The command to install PyTorch is as follows: ### Recent Updates +1 Dec, 2024: + +- Pseudo Huber loss is now available for FLUX.1 and SD3.5 training. See [#1808](https://github.com/kohya-ss/sd-scripts/pull/1808) for details. Thanks to recris! + - Specify `--loss_type huber` or `--loss_type smooth_l1` to use it. `--huber_c` and `--huber_scale` are also available. + Nov 14, 2024: - Improved the implementation of block swap and made it available for both FLUX.1 and SD3 LoRA training. See [FLUX.1 LoRA training](#flux1-lora-training) etc. for how to use the new options. Training is possible with about 8-10GB of VRAM. From 1dc873d9b463d50e27ae8572c28a473ce9a1254f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 1 Dec 2024 22:00:44 +0900 Subject: [PATCH 268/348] update README and clean up code for schedulefree optimizer --- README.md | 4 +++- library/train_util.py | 7 +++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 89a96827c..8db5c4d42 100644 --- a/README.md +++ b/README.md @@ -16,9 +16,11 @@ The command to install PyTorch is as follows: 1 Dec, 2024: -- Pseudo Huber loss is now available for FLUX.1 and SD3.5 training. See [#1808](https://github.com/kohya-ss/sd-scripts/pull/1808) for details. Thanks to recris! +- Pseudo Huber loss is now available for FLUX.1 and SD3.5 training. See PR [#1808](https://github.com/kohya-ss/sd-scripts/pull/1808) for details. Thanks to recris! - Specify `--loss_type huber` or `--loss_type smooth_l1` to use it. `--huber_c` and `--huber_scale` are also available. +- [Prodigy + ScheduleFree](https://github.com/LoganBooker/prodigy-plus-schedule-free) is supported. See PR [#1811](https://github.com/kohya-ss/sd-scripts/pull/1811) for details. Thanks to rockerBOO! + Nov 14, 2024: - Improved the implementation of block swap and made it available for both FLUX.1 and SD3 LoRA training. See [FLUX.1 LoRA training](#flux1-lora-training) etc. for how to use the new options. Training is possible with about 8-10GB of VRAM. diff --git a/library/train_util.py b/library/train_util.py index 289ab8235..6cfd14d5e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4609,7 +4609,7 @@ def task(): def get_optimizer(args, trainable_params): # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" - + optimizer_type = args.optimizer_type if args.use_8bit_adam: assert ( @@ -4883,7 +4883,6 @@ def get_optimizer(args, trainable_params): optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type.endswith("schedulefree".lower()): - should_train_optimizer = True try: import schedulefree as sf except ImportError: @@ -5000,8 +4999,8 @@ def __instancecheck__(self, instance): optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()]) - if hasattr(optimizer, 'train') and callable(optimizer.train): - # make optimizer as train mode: we don't need to call train again, because eval will not be called in training loop + if hasattr(optimizer, "train") and callable(optimizer.train): + # make optimizer as train mode before training for schedulefree optimizer. the optimizer will be in eval mode in sampling and saving. optimizer.train() return optimizer_name, optimizer_args, optimizer From e369b9a252b90d1f57ea20dd6f5d05ec0c287ae1 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Mon, 2 Dec 2024 23:38:54 +0900 Subject: [PATCH 269/348] docs: update README with FLUX.1 ControlNet training details and improve argument help text --- README.md | 10 +++++++++- library/flux_train_utils.py | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 45e3cb7ab..6a5cdd342 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,15 @@ The command to install PyTorch is as follows: ### Recent Updates -1 Dec, 2024: +Dec 2, 2024: + +- FLUX.1 ControlNet training is supported. PR [#1813](https://github.com/kohya-ss/sd-scripts/pull/1813). Thanks to minux302! See PR and [here](#flux1-controlnet-training) for details. + - Not fully tested. Feedback is welcome. + - 80GB VRAM is required for 1024x1024 resolution, and 48GB VRAM is required for 512x512 resolution. + - Currently, it only works in Linux environment (or Windows WSL2) because DeepSpeed is required. + - Multi-GPU training is not tested. + +Dec 1, 2024: - Pseudo Huber loss is now available for FLUX.1 and SD3.5 training. See PR [#1808](https://github.com/kohya-ss/sd-scripts/pull/1808) for details. Thanks to recris! - Specify `--loss_type huber` or `--loss_type smooth_l1` to use it. `--huber_c` and `--huber_scale` are also available. diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 5e25c7feb..de2e2b48d 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -567,7 +567,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): "--controlnet", type=str, default=None, - help="path to controlnet (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)" + help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)" ) parser.add_argument( "--t5xxl_max_token_length", From 5ab00f9b49b5a3958bb0267fdb9236a96d503dbd Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 2 Dec 2024 13:39:51 -0500 Subject: [PATCH 270/348] Update workflow tests with cleanup and documentation --- .github/workflows/tests.yml | 48 +++++++++++++++++++------------------ 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9ae67b0e9..5a790d570 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,42 +1,44 @@ - -name: Python package - -on: [push] +name: Test with pytet + +on: + push: + branches: + - main + - dev + - sd3 + pull_request: + branches: + - main + - dev + - sd3 jobs: build: - runs-on: ${{ matrix.os }} strategy: matrix: os: [ubuntu-latest] - python-version: ["3.10"] + python-version: ["3.10"] # Python versions to test + pytorch-version: ["2.4.0"] # PyTorch versions to test steps: - uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v5 + - uses: actions/setup-python@v5 with: - python-version: '3.x' - - - name: Install dependencies - run: python -m pip install --upgrade pip setuptools wheel + python-version: ${{ matrix.python-version }} + cache: 'pip' - - uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: '3.x' - cache: 'pip' # caching pip dependencies + - name: Install and update pip, setuptools, wheel + run: | + # Setuptools, wheel for compiling some packages + python -m pip install --upgrade pip setuptools wheel - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install dadaptation==3.2 torch==2.4.0 torchvision==0.19.0 accelerate==0.33.0 + # Pre-install torch to pin version (requirements.txt has dependencies like transformers which requires pytorch) + pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision==0.19.0 pytest==8.3.4 pip install -r requirements.txt - name: Test with pytest - run: | - pip install pytest - pytest + run: pytest # See pytest.ini for configuration From 63738ecb0758a02555392d2c283a83bba1c6f98e Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 2 Dec 2024 13:48:30 -0500 Subject: [PATCH 271/348] Add tests documentation --- tests/README.md | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 tests/README.md diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 000000000..19eeab0e2 --- /dev/null +++ b/tests/README.md @@ -0,0 +1,32 @@ +# Tests + +## Install + +``` +pip install pytest +``` + +## Usage + +``` +pytest +``` + +## Contribution + +Pytest is configured to run tests in this directory. It might be a good idea to add tests closer in the code, as well as doctests. + +Tests are functions starting with `test_` and files with the pattern `test_*.py`. + +``` +def test_x(): + assert 1 == 2, "Invalid test response" +``` + +## Resources + +- https://circleci.com/blog/testing-pytorch-model-with-pytest/ +- https://pytorch.org/docs/stable/testing.html +- https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests +- https://github.com/huggingface/pytorch-image-models/tree/main/tests +- https://github.com/pytorch/pytorch/tree/main/test From 2610e96e9e3d0605d5a16615efa26ae8935ed3aa Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 2 Dec 2024 13:49:58 -0500 Subject: [PATCH 272/348] Pytest --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5a790d570..672a657bf 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,4 +1,4 @@ -name: Test with pytet +name: Test with pytest on: push: From 3e5d89c76c287872e20c4a967d36b51384285be8 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 2 Dec 2024 13:51:57 -0500 Subject: [PATCH 273/348] Add more resources --- tests/README.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/README.md b/tests/README.md index 19eeab0e2..9836da8b4 100644 --- a/tests/README.md +++ b/tests/README.md @@ -25,8 +25,17 @@ def test_x(): ## Resources +### pytest + +- https://docs.pytest.org/en/stable/index.html +- https://docs.pytest.org/en/stable/how-to/assert.html +- https://docs.pytest.org/en/stable/how-to/doctest.html + +### PyTorch testing + - https://circleci.com/blog/testing-pytorch-model-with-pytest/ - https://pytorch.org/docs/stable/testing.html - https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests - https://github.com/huggingface/pytorch-image-models/tree/main/tests - https://github.com/pytorch/pytorch/tree/main/test + From 8b36d907d8635dca64224574b5cb15013e00809d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 3 Dec 2024 08:43:26 +0900 Subject: [PATCH 274/348] feat: support block_to_swap for FLUX.1 ControlNet training --- README.md | 13 +++++++++++ flux_train_control_net.py | 46 +++++++++++++++++++++++++++------------ 2 files changed, 45 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 6a5cdd342..f02725191 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,11 @@ The command to install PyTorch is as follows: ### Recent Updates + +Dec 3, 2024: + +-`--blocks_to_swap` now works in FLUX.1 ControlNet training. Sample commands for 24GB VRAM and 16GB VRAM are added [here](#flux1-controlnet-training). + Dec 2, 2024: - FLUX.1 ControlNet training is supported. PR [#1813](https://github.com/kohya-ss/sd-scripts/pull/1813). Thanks to minux302! See PR and [here](#flux1-controlnet-training) for details. @@ -276,6 +281,14 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_tr --timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 --deepspeed ``` +For 24GB VRAM GPUs, you can train with 16 blocks swapped and caching latents and text encoder outputs with the batch size of 1. Remove `--deepspeed` . Sample command is below. Not fully tested. +``` + --blocks_to_swap 16 --cache_latents_to_disk --cache_text_encoder_outputs_to_disk +``` + +The training can be done with 16GB VRAM GPUs with around 30 blocks swapped. + +`--gradient_accumulation_steps` is also available. The default value is 1 (no accumulation), but according to the original PR, 8 is used. ### FLUX.1 OFT training diff --git a/flux_train_control_net.py b/flux_train_control_net.py index bb27c35ed..5548fd991 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -119,9 +119,7 @@ def train(args): "datasets": [ { "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( - args.train_data_dir, - args.conditioning_data_dir, - args.caption_extension + args.train_data_dir, args.conditioning_data_dir, args.caption_extension ) } ] @@ -263,13 +261,17 @@ def train(args): args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors ) flux.requires_grad_(False) - flux.to(accelerator.device) # load controlnet - controlnet = flux_utils.load_controlnet(args.controlnet, is_schnell, torch.float32, accelerator.device, args.disable_mmap_load_safetensors) + controlnet_dtype = torch.float32 if args.deepspeed else weight_dtype + controlnet = flux_utils.load_controlnet( + args.controlnet, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors + ) controlnet.train() if args.gradient_checkpointing: + if not args.deepspeed: + flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing) controlnet.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing) # block swap @@ -296,7 +298,11 @@ def train(args): # This idea is based on 2kpr's great work. Thank you! logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") flux.enable_block_swap(args.blocks_to_swap, accelerator.device) - controlnet.enable_block_swap(args.blocks_to_swap, accelerator.device) + flux.move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + # ControlNet only has two blocks, so we can keep it on GPU + # controlnet.enable_block_swap(args.blocks_to_swap, accelerator.device) + else: + flux.to(accelerator.device) if not cache_latents: # load VAE here if not cached @@ -455,9 +461,7 @@ def train(args): else: # accelerator does some magic # if we doesn't swap blocks, we can move the model to device - controlnet = accelerator.prepare(controlnet, device_placement=[not is_swapping_blocks]) - if is_swapping_blocks: - accelerator.unwrap_model(controlnet).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + controlnet = accelerator.prepare(controlnet) # , device_placement=[not is_swapping_blocks]) optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする @@ -564,11 +568,13 @@ def grad_hook(parameter: torch.Tensor): ) if is_swapping_blocks: - accelerator.unwrap_model(controlnet).prepare_block_swap_before_forward() + flux.prepare_block_swap_before_forward() # For --sample_at_first optimizer_eval_fn() - flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet) + flux_train_utils.sample_images( + accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet + ) optimizer_train_fn() if len(accelerator.trackers) > 0: # log empty object to commit the sample images to wandb @@ -629,7 +635,11 @@ def grad_hook(parameter: torch.Tensor): # pack latents and get img_ids packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 - img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device).to(weight_dtype) + img_ids = ( + flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width) + .to(device=accelerator.device) + .to(weight_dtype) + ) # get guidance: ensure args.guidance_scale is float guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device, dtype=weight_dtype) @@ -638,7 +648,7 @@ def grad_hook(parameter: torch.Tensor): l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds if not args.apply_t5_attn_mask: t5_attn_mask = None - + with accelerator.autocast(): block_samples, block_single_samples = controlnet( img=packed_noisy_model_input, @@ -715,7 +725,15 @@ def grad_hook(parameter: torch.Tensor): optimizer_eval_fn() flux_train_utils.sample_images( - accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet + accelerator, + args, + None, + global_step, + flux, + ae, + [clip_l, t5xxl], + sample_prompts_te_outputs, + controlnet=controlnet, ) # 指定ステップごとにモデルを保存 From 6bee18db4fbf62ebd2a1da88a5851c48f2e06c54 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 7 Dec 2024 15:12:27 +0900 Subject: [PATCH 275/348] fix: resolve model corruption issue with pos_embed when using --enable_scaled_pos_embed --- README.md | 2 ++ library/sd3_models.py | 6 ++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index f02725191..6162359d1 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,8 @@ The command to install PyTorch is as follows: ### Recent Updates +Dec 7, 2024: +- Fixed an issue where the saved model would be corrupted (pos_embed would not be saved) when `--enable_scaled_pos_embed` was specified in `sd3_train.py`. Dec 3, 2024: diff --git a/library/sd3_models.py b/library/sd3_models.py index 2f3c82eed..e4a931861 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -870,8 +870,10 @@ def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Opti self.use_scaled_pos_embed = use_scaled_pos_embed if self.use_scaled_pos_embed: - # remove pos_embed to free up memory up to 0.4 GB - self.pos_embed = None + # # remove pos_embed to free up memory up to 0.4 GB -> this causes error because pos_embed is not saved + # self.pos_embed = None + # move pos_embed to CPU to free up memory up to 0.4 GB + self.pos_embed = self.pos_embed.cpu() # remove duplicates and sort latent sizes in ascending order latent_sizes = list(set(latent_sizes)) From abff4b0ec7bb37b338924e38392593f2bea2b8d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Sat, 7 Dec 2024 16:12:46 +0800 Subject: [PATCH 276/348] Unify controlnet parameters name and change scripts name. (#1821) * Update sd3_train.py * add freeze block lr * Update train_util.py * update * Revert "add freeze block lr" This reverts commit 8b1653548f8f219e5be2cde96f65a8813cf9ea1f. # Conflicts: # library/train_util.py # sd3_train.py * use same control net model path * use controlnet_model_name_or_path --- flux_train_control_net.py | 2 +- library/flux_train_utils.py | 2 +- sdxl_train_control_net.py | 8 ++++---- train_controlnet.py => train_control_net.py | 0 4 files changed, 6 insertions(+), 6 deletions(-) rename train_controlnet.py => train_control_net.py (100%) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 5548fd991..9d36a41d3 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -265,7 +265,7 @@ def train(args): # load controlnet controlnet_dtype = torch.float32 if args.deepspeed else weight_dtype controlnet = flux_utils.load_controlnet( - args.controlnet, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors + args.controlnet_model_name_or_path, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors ) controlnet.train() diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index de2e2b48d..f7f06c5cf 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -564,7 +564,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): ) parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)") parser.add_argument( - "--controlnet", + "--controlnet_model_name_or_path", type=str, default=None, help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)" diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index 01387409a..ffbf03cab 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -184,12 +184,12 @@ def unwrap_model(model): # make control net logger.info("make ControlNet") - if args.controlnet_model_path: + if args.controlnet_model_name_or_path: with init_empty_weights(): control_net = SdxlControlNet() - logger.info(f"load ControlNet from {args.controlnet_model_path}") - filename = args.controlnet_model_path + logger.info(f"load ControlNet from {args.controlnet_model_name_or_path}") + filename = args.controlnet_model_name_or_path if os.path.splitext(filename)[1] == ".safetensors": state_dict = load_file(filename) else: @@ -675,7 +675,7 @@ def setup_parser() -> argparse.ArgumentParser: sdxl_train_util.add_sdxl_training_arguments(parser) parser.add_argument( - "--controlnet_model_path", + "--controlnet_model_name_or_path", type=str, default=None, help="controlnet model name or path / controlnetのモデル名またはパス", diff --git a/train_controlnet.py b/train_control_net.py similarity index 100% rename from train_controlnet.py rename to train_control_net.py From e425996a5953f0479384e70b6490e751c2d00b1f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 7 Dec 2024 17:28:19 +0900 Subject: [PATCH 277/348] feat: unify ControlNet model name option and deprecate old training script --- README.md | 7 +++++++ train_controlnet.py | 23 +++++++++++++++++++++++ 2 files changed, 30 insertions(+) create mode 100644 train_controlnet.py diff --git a/README.md b/README.md index 6162359d1..67836ddf0 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,13 @@ The command to install PyTorch is as follows: ### Recent Updates Dec 7, 2024: + +- The option to specify the model name during ControlNet training was different in each script. It has been unified. Please specify `--controlnet_model_name_or_path`. PR [#1821](https://github.com/kohya-ss/sd-scripts/pull/1821) Thanks to sdbds! + + - Fixed an issue where the saved model would be corrupted (pos_embed would not be saved) when `--enable_scaled_pos_embed` was specified in `sd3_train.py`. Dec 3, 2024: diff --git a/train_controlnet.py b/train_controlnet.py new file mode 100644 index 000000000..365e35c8c --- /dev/null +++ b/train_controlnet.py @@ -0,0 +1,23 @@ +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +from library import train_util +from train_control_net import setup_parser, train + +if __name__ == "__main__": + logger.warning( + "The module 'train_controlnet.py' is deprecated. Please use 'train_control_net.py' instead" + " / 'train_controlnet.py'は非推奨です。代わりに'train_control_net.py'を使用してください。" + ) + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) From 3cb8cb2d4fd697a49135193ac0873204e0139e62 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 9 Dec 2024 15:20:04 -0500 Subject: [PATCH 278/348] Prevent git credentials from leaking into other actions --- .github/workflows/tests.yml | 4 ++++ .github/workflows/typos.yml | 3 +++ 2 files changed, 7 insertions(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 672a657bf..2eddedc7b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -23,6 +23,10 @@ jobs: steps: - uses: actions/checkout@v4 + with: + # https://woodruffw.github.io/zizmor/audits/#artipacked + persist-credentials: false + - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml index 87ebdf894..f53cda218 100644 --- a/.github/workflows/typos.yml +++ b/.github/workflows/typos.yml @@ -18,6 +18,9 @@ jobs: steps: - uses: actions/checkout@v4 + with: + # https://woodruffw.github.io/zizmor/audits/#artipacked + persist-credentials: false - name: typos-action uses: crate-ci/typos@v1.28.1 From 8e378cf03df645cef897a342559dc5fa7f66a35d Mon Sep 17 00:00:00 2001 From: nhamanasu Date: Wed, 11 Dec 2024 19:43:44 +0900 Subject: [PATCH 279/348] add RAdamScheduleFree support --- library/train_util.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index a35388fee..72b5b24db 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4887,7 +4887,11 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]: import schedulefree as sf except ImportError: raise ImportError("No schedulefree / schedulefreeがインストールされていないようです") - if optimizer_type == "AdamWScheduleFree".lower(): + + if optimizer_type == "RAdamScheduleFree".lower(): + optimizer_class = sf.RAdamScheduleFree + logger.info(f"use RAdamScheduleFree optimizer | {optimizer_kwargs}") + elif optimizer_type == "AdamWScheduleFree".lower(): optimizer_class = sf.AdamWScheduleFree logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}") elif optimizer_type == "SGDScheduleFree".lower(): From e89653975ddf429cdf0c0fd268da0a5a3e8dba1f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 15 Dec 2024 19:39:47 +0900 Subject: [PATCH 280/348] update requirements.txt and README to include RAdamScheduleFree optimizer support --- README.md | 6 ++++++ requirements.txt | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 67836ddf0..bfb22bcf1 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,12 @@ The command to install PyTorch is as follows: ### Recent Updates +Dec 15, 2024: + +- RAdamScheduleFree optimizer is supported. PR [#1830](https://github.com/kohya-ss/sd-scripts/pull/1830) Thanks to nhamanasu! + - Update to `schedulefree==1.4` is required. Please update individually or with `pip install --use-pep517 --upgrade -r requirements.txt`. + - Available with `--optimizer_type=RAdamScheduleFree`. No need to specify warm up steps as well as learning rate scheduler. + Dec 7, 2024: - The option to specify the model name during ControlNet training was different in each script. It has been unified. Please specify `--controlnet_model_name_or_path`. PR [#1821](https://github.com/kohya-ss/sd-scripts/pull/1821) Thanks to sdbds! diff --git a/requirements.txt b/requirements.txt index 0dd1c69cc..e0091749a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,7 @@ pytorch-lightning==1.9.0 bitsandbytes==0.44.0 prodigyopt==1.0 lion-pytorch==0.0.6 -schedulefree==1.2.7 +schedulefree==1.4 tensorboard safetensors==0.4.4 # gradio==3.16.2 From 05bb9183fae18c62a1730fe5060f80c0b99a21f3 Mon Sep 17 00:00:00 2001 From: Hina Chen Date: Fri, 27 Dec 2024 16:47:59 +0800 Subject: [PATCH 281/348] Add Validation loss for LoRA training --- library/config_util.py | 78 +++++++++++++++++++++++- library/train_util.py | 54 ++++++++++++++++- train_network.py | 131 ++++++++++++++++++++++++++++++++++++++++- 3 files changed, 257 insertions(+), 6 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 12d0be173..a57cd36f0 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -73,6 +73,8 @@ class BaseSubsetParams: token_warmup_min: int = 1 token_warmup_step: float = 0 custom_attributes: Optional[Dict[str, Any]] = None + validation_seed: int = 0 + validation_split: float = 0.0 @dataclass @@ -102,6 +104,8 @@ class BaseDatasetParams: resolution: Optional[Tuple[int, int]] = None network_multiplier: float = 1.0 debug_dataset: bool = False + validation_seed: Optional[int] = None + validation_split: float = 0.0 @dataclass @@ -478,9 +482,27 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset_klass = FineTuningDataset subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, is_train=True, **asdict(dataset_blueprint.params)) datasets.append(dataset) + val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] + for dataset_blueprint in dataset_group_blueprint.datasets: + if dataset_blueprint.params.validation_split <= 0.0: + continue + if dataset_blueprint.is_controlnet: + subset_klass = ControlNetSubset + dataset_klass = ControlNetDataset + elif dataset_blueprint.is_dreambooth: + subset_klass = DreamBoothSubset + dataset_klass = DreamBoothDataset + else: + subset_klass = FineTuningSubset + dataset_klass = FineTuningDataset + + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) + val_datasets.append(dataset) + # print info info = "" for i, dataset in enumerate(datasets): @@ -566,6 +588,50 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu logger.info(f"{info}") + if len(val_datasets) > 0: + info = "" + + for i, dataset in enumerate(val_datasets): + info += dedent( + f"""\ + [Validation Dataset {i}] + batch_size: {dataset.batch_size} + resolution: {(dataset.width, dataset.height)} + enable_bucket: {dataset.enable_bucket} + network_multiplier: {dataset.network_multiplier} + """ + ) + + if dataset.enable_bucket: + info += indent( + dedent( + f"""\ + min_bucket_reso: {dataset.min_bucket_reso} + max_bucket_reso: {dataset.max_bucket_reso} + bucket_reso_steps: {dataset.bucket_reso_steps} + bucket_no_upscale: {dataset.bucket_no_upscale} + \n""" + ), + " ", + ) + else: + info += "\n" + + for j, subset in enumerate(dataset.subsets): + info += indent( + dedent( + f"""\ + [Subset {j} of Validation Dataset {i}] + image_dir: "{subset.image_dir}" + image_count: {subset.img_count} + num_repeats: {subset.num_repeats} + """ + ), + " ", + ) + + logger.info(f"{info}") + # make buckets first because it determines the length of dataset # and set the same seed for all datasets seed = random.randint(0, 2**31) # actual seed is seed + epoch_no @@ -574,7 +640,15 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset.make_buckets() dataset.set_seed(seed) - return DatasetGroup(datasets) + for i, dataset in enumerate(val_datasets): + logger.info(f"[Validation Dataset {i}]") + dataset.make_buckets() + dataset.set_seed(seed) + + return ( + DatasetGroup(datasets), + DatasetGroup(val_datasets) if val_datasets else None + ) def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None): diff --git a/library/train_util.py b/library/train_util.py index 72b5b24db..a3fa98e99 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -145,6 +145,17 @@ TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz" +def split_train_val(paths: List[str], validation_split: float, validation_seed: int) -> List[str]: + if validation_seed is not None: + print(f"Using validation seed: {validation_seed}") + prevstate = random.getstate() + random.seed(validation_seed) + random.shuffle(paths) + random.setstate(prevstate) + else: + random.shuffle(paths) + + return paths[len(paths) - round(len(paths) * validation_split):] class ImageInfo: def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: @@ -397,6 +408,8 @@ def __init__( token_warmup_min: int, token_warmup_step: Union[float, int], custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, ) -> None: self.image_dir = image_dir self.alpha_mask = alpha_mask if alpha_mask is not None else False @@ -424,6 +437,9 @@ def __init__( self.img_count = 0 + self.validation_seed = validation_seed + self.validation_split = validation_split + class DreamBoothSubset(BaseSubset): def __init__( @@ -453,6 +469,8 @@ def __init__( token_warmup_min, token_warmup_step, custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -478,6 +496,8 @@ def __init__( token_warmup_min, token_warmup_step, custom_attributes=custom_attributes, + validation_seed=validation_seed, + validation_split=validation_split, ) self.is_reg = is_reg @@ -518,6 +538,8 @@ def __init__( token_warmup_min, token_warmup_step, custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" @@ -543,6 +565,8 @@ def __init__( token_warmup_min, token_warmup_step, custom_attributes=custom_attributes, + validation_seed=validation_seed, + validation_split=validation_split, ) self.metadata_file = metadata_file @@ -579,6 +603,8 @@ def __init__( token_warmup_min, token_warmup_step, custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -604,6 +630,8 @@ def __init__( token_warmup_min, token_warmup_step, custom_attributes=custom_attributes, + validation_seed=validation_seed, + validation_split=validation_split, ) self.conditioning_data_dir = conditioning_data_dir @@ -1799,6 +1827,9 @@ def __init__( bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset: bool, + is_train: bool, + validation_seed: int, + validation_split: float, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -1808,6 +1839,9 @@ def __init__( self.size = min(self.width, self.height) # 短いほう self.prior_loss_weight = prior_loss_weight self.latents_cache = None + self.is_train = is_train + self.validation_seed = validation_seed + self.validation_split = validation_split self.enable_bucket = enable_bucket if self.enable_bucket: @@ -1992,6 +2026,9 @@ def load_dreambooth_dir(subset: DreamBoothSubset): ) continue + if self.is_train == False: + img_paths = split_train_val(img_paths, self.validation_split, self.validation_seed) + if subset.is_reg: num_reg_images += subset.num_repeats * len(img_paths) else: @@ -2009,7 +2046,11 @@ def load_dreambooth_dir(subset: DreamBoothSubset): subset.img_count = len(img_paths) self.subsets.append(subset) - logger.info(f"{num_train_images} train images with repeating.") + if self.is_train: + logger.info(f"{num_train_images} train images with repeating.") + else: + logger.info(f"{num_train_images} validation images with repeating.") + self.num_train_images = num_train_images logger.info(f"{num_reg_images} reg images.") @@ -2050,6 +2091,9 @@ def __init__( bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset: bool, + is_train: bool, + validation_seed: int, + validation_split: float, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -2276,6 +2320,9 @@ def __init__( bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset: float, + is_train: bool, + validation_seed: int, + validation_split: float, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -2324,6 +2371,9 @@ def __init__( bucket_no_upscale, 1.0, debug_dataset, + is_train, + validation_seed, + validation_split, ) # config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい) @@ -4887,7 +4937,7 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]: import schedulefree as sf except ImportError: raise ImportError("No schedulefree / schedulefreeがインストールされていないようです") - + if optimizer_type == "RAdamScheduleFree".lower(): optimizer_class = sf.RAdamScheduleFree logger.info(f"use RAdamScheduleFree optimizer | {optimizer_kwargs}") diff --git a/train_network.py b/train_network.py index 5e82b307c..776feaf76 100644 --- a/train_network.py +++ b/train_network.py @@ -9,6 +9,7 @@ from multiprocessing import Value from typing import Any, List import toml +import itertools from tqdm import tqdm @@ -114,7 +115,7 @@ def generate_step_logs( ) if ( args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None - ): + ): logs[f"lr/d*lr/group{i}"] = ( optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] ) @@ -373,10 +374,11 @@ def train(self, args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary dataset class train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -398,6 +400,11 @@ def train(self, args): train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + if val_dataset_group is not None: + assert ( + val_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + self.assert_extra_args(args, train_dataset_group) # may change some args # acceleratorを準備する @@ -444,6 +451,8 @@ def train(self, args): vae.eval() train_dataset_group.new_cache_latents(vae, accelerator) + if val_dataset_group is not None: + val_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -459,6 +468,8 @@ def train(self, args): if text_encoder_outputs_caching_strategy is not None: strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy) self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, train_dataset_group, weight_dtype) + if val_dataset_group is not None: + self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, val_dataset_group, weight_dtype) # prepare network net_kwargs = {} @@ -567,6 +578,8 @@ def train(self, args): # strategies are set here because they cannot be referenced in another process. Copy them with the dataset # some strategies can be None train_dataset_group.set_current_strategies() + if val_dataset_group is not None: + val_dataset_group.set_current_strategies() # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers @@ -580,6 +593,17 @@ def train(self, args): persistent_workers=args.persistent_data_loader_workers, ) + val_dataloader = torch.utils.data.DataLoader( + val_dataset_group if val_dataset_group is not None else [], + batch_size=1, + shuffle=False, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + cyclic_val_dataloader = itertools.cycle(val_dataloader) + # 学習ステップ数を計算する if args.max_train_epochs is not None: args.max_train_steps = args.max_train_epochs * math.ceil( @@ -592,6 +616,10 @@ def train(self, args): # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) + # Not for sure here. + # if val_dataset_group is not None: + # val_dataset_group.set_max_train_steps(args.max_train_steps) + # lr schedulerを用意する lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) @@ -1064,7 +1092,11 @@ def load_model_hook(models, input_dir): ) loss_recorder = train_util.LossRecorder() + # val_loss_recorder = train_util.LossRecorder() + del train_dataset_group + if val_dataset_group is not None: + del val_dataset_group # callback for step start if hasattr(accelerator.unwrap_model(network), "on_step_start"): @@ -1308,6 +1340,77 @@ def remove_model(old_ckpt_name): ) accelerator.log(logs, step=global_step) + if len(val_dataloader) > 0: + if ((args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps): + accelerator.print("\nValidating バリデーション処理...") + + total_loss = 0.0 + + with torch.no_grad(): + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + for val_step in tqdm(range(validation_steps), desc="Validation Steps バリデーションテップ"): + batch = next(cyclic_val_dataloader) + + timesteps_list = [10, 350, 500, 650, 990] + + val_loss = 0.0 + + for fixed_timesteps in timesteps_list: + with torch.set_grad_enabled(False), accelerator.autocast(): + noise = torch.randn_like(latents, device=latents.device) + b_size = latents.shape[0] + + timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device="cpu") + timesteps = timesteps.long().to(latents.device) + + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + with accelerator.autocast(): + noise_pred = self.call_unet( + args, + accelerator, + unet, + noisy_latents.requires_grad_(False), + timesteps, + text_encoder_conds, + batch, + weight_dtype, + ) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc. + loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + val_loss += loss / len(timesteps_list) + + total_loss += val_loss.detach().item() + + current_val_loss = total_loss / validation_steps + # val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_val_loss) + + if len(accelerator.trackers) > 0: + logs = {"loss/current_val_loss": current_val_loss} + accelerator.log(logs, step=global_step) + + # avr_loss: float = val_loss_recorder.moving_average + # logs = {"loss/average_val_loss": avr_loss} + # accelerator.log(logs, step=global_step) + if global_step >= args.max_train_steps: break @@ -1496,6 +1599,30 @@ def setup_parser() -> argparse.ArgumentParser: help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch." + " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする", ) + parser.add_argument( + "--validation_seed", + type=int, + default=None, + help="Validation seed / 検証シード" + ) + parser.add_argument( + "--validation_split", + type=float, + default=0.0, + help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合" + ) + parser.add_argument( + "--validation_every_n_step", + type=int, + default=None, + help="Number of train steps for counting validation loss. By default, validation per train epoch is performed / 学習エポックごとに検証を行う場合はNoneを指定する" + ) + parser.add_argument( + "--max_validation_steps", + type=int, + default=None, + help="Number of max validation steps for counting validation loss. By default, validation will run entire validation dataset / 検証データセット全体を検証する場合はNoneを指定する" + ) # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio") # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio") # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio") From 62164e57925125ed6268983ffa441f1ffecc0e6d Mon Sep 17 00:00:00 2001 From: Hina Chen Date: Fri, 27 Dec 2024 17:28:05 +0800 Subject: [PATCH 282/348] Change val loss calculate method --- train_network.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/train_network.py b/train_network.py index 776feaf76..5fd1b212f 100644 --- a/train_network.py +++ b/train_network.py @@ -1383,16 +1383,20 @@ def remove_model(old_ckpt_name): else: target = noise - huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) - loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) - if weighting is not None: - loss = loss * weighting - if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): - loss = apply_masked_loss(loss, batch) - loss = loss.mean([1, 2, 3]) + # huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + # loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + # if weighting is not None: + # loss = loss * weighting + # if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + # loss = apply_masked_loss(loss, batch) + # loss = loss.mean([1, 2, 3]) # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc. - loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) + # loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし From 64bd5317dc9cb39d69ab7728f36b03157c9b341f Mon Sep 17 00:00:00 2001 From: Hina Chen Date: Sat, 28 Dec 2024 11:42:15 +0800 Subject: [PATCH 283/348] Split val latents/batch and pick up val latents shape size which equal to training batch. --- train_network.py | 45 +++++++++++++++++++++++++++------------------ 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/train_network.py b/train_network.py index 5fd1b212f..6bce9e964 100644 --- a/train_network.py +++ b/train_network.py @@ -1349,7 +1349,27 @@ def remove_model(old_ckpt_name): with torch.no_grad(): validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) for val_step in tqdm(range(validation_steps), desc="Validation Steps バリデーションテップ"): - batch = next(cyclic_val_dataloader) + + while True: + val_batch = next(cyclic_val_dataloader) + + if "latents" in val_batch and val_batch["latents"] is not None: + val_latents = val_batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + with torch.no_grad(): + # latentに変換 + val_latents = self.encode_images_to_latents(args, accelerator, vae, val_batch["images"].to(vae_dtype)) + val_latents = val_latents.to(dtype=weight_dtype) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(val_latents)): + accelerator.print("NaN found in validation latents, replacing with zeros") + val_latents = torch.nan_to_num(val_latents, 0, out=val_latents) + + val_latents = self.shift_scale_latents(args, val_latents) + + if val_latents.shape == latents.shape: + break timesteps_list = [10, 350, 500, 650, 990] @@ -1357,13 +1377,13 @@ def remove_model(old_ckpt_name): for fixed_timesteps in timesteps_list: with torch.set_grad_enabled(False), accelerator.autocast(): - noise = torch.randn_like(latents, device=latents.device) - b_size = latents.shape[0] + noise = torch.randn_like(val_latents, device=val_latents.device) + b_size = val_latents.shape[0] timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device="cpu") - timesteps = timesteps.long().to(latents.device) + timesteps = timesteps.long().to(val_latents.device) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noisy_latents = noise_scheduler.add_noise(val_latents, noise, timesteps) with accelerator.autocast(): noise_pred = self.call_unet( @@ -1373,27 +1393,16 @@ def remove_model(old_ckpt_name): noisy_latents.requires_grad_(False), timesteps, text_encoder_conds, - batch, + val_batch, weight_dtype, ) if args.v_parameterization: # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) + target = noise_scheduler.get_velocity(val_latents, noise, timesteps) else: target = noise - # huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) - # loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) - # if weighting is not None: - # loss = loss * weighting - # if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): - # loss = apply_masked_loss(loss, batch) - # loss = loss.mean([1, 2, 3]) - - # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc. - # loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = loss.mean([1, 2, 3]) loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) From cb89e0284e1a25b41401861107159e6b943ee387 Mon Sep 17 00:00:00 2001 From: Hina Chen Date: Sat, 28 Dec 2024 11:57:04 +0800 Subject: [PATCH 284/348] Change val latent loss compare --- train_network.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/train_network.py b/train_network.py index 6bce9e964..7276d5dc0 100644 --- a/train_network.py +++ b/train_network.py @@ -1350,6 +1350,8 @@ def remove_model(old_ckpt_name): validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) for val_step in tqdm(range(validation_steps), desc="Validation Steps バリデーションテップ"): + val_latents = None + while True: val_batch = next(cyclic_val_dataloader) @@ -1371,19 +1373,22 @@ def remove_model(old_ckpt_name): if val_latents.shape == latents.shape: break + if val_latents is not None: + del val_latents + timesteps_list = [10, 350, 500, 650, 990] val_loss = 0.0 for fixed_timesteps in timesteps_list: with torch.set_grad_enabled(False), accelerator.autocast(): - noise = torch.randn_like(val_latents, device=val_latents.device) - b_size = val_latents.shape[0] + noise = torch.randn_like(latents, device=latents.device) + b_size = latents.shape[0] timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device="cpu") - timesteps = timesteps.long().to(val_latents.device) + timesteps = timesteps.long().to(latents.device) - noisy_latents = noise_scheduler.add_noise(val_latents, noise, timesteps) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) with accelerator.autocast(): noise_pred = self.call_unet( @@ -1399,7 +1404,7 @@ def remove_model(old_ckpt_name): if args.v_parameterization: # v-parameterization training - target = noise_scheduler.get_velocity(val_latents, noise, timesteps) + target = noise_scheduler.get_velocity(latents, noise, timesteps) else: target = noise From 874353296304c753b452511a412472f8a3e4ba09 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 10 Mar 2024 04:37:16 +0800 Subject: [PATCH 285/348] val --- library/config_util.py | 32 +++++++------ library/train_util.py | 20 ++++++-- train_network.py | 104 +++++++++++++++++++++++++++-------------- 3 files changed, 103 insertions(+), 53 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 1bf7ed955..cb2c5b68f 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -81,23 +81,24 @@ class ControlNetSubsetParams(BaseSubsetParams): @dataclass class BaseDatasetParams: - tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None - max_token_length: int = None - resolution: Optional[Tuple[int, int]] = None - debug_dataset: bool = False - validation_seed: Optional[int] = None - validation_split: float = 0.0 + tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None + max_token_length: int = None + resolution: Optional[Tuple[int, int]] = None + network_multiplier: float = 1.0 + debug_dataset: bool = False + validation_seed: Optional[int] = None + validation_split: float = 0.0 @dataclass class DreamBoothDatasetParams(BaseDatasetParams): - batch_size: int = 1 - enable_bucket: bool = False - min_bucket_reso: int = 256 - max_bucket_reso: int = 1024 - bucket_reso_steps: int = 64 - bucket_no_upscale: bool = False - prior_loss_weight: float = 1.0 - + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + prior_loss_weight: float = 1.0 + @dataclass class FineTuningDatasetParams(BaseDatasetParams): batch_size: int = 1 @@ -203,8 +204,9 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] "max_bucket_reso": int, "min_bucket_reso": int, "validation_seed": int, - "validation_split": float, + "validation_split": float, "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), + "network_multiplier": float, } # options handled by argparse but not handled by user config diff --git a/library/train_util.py b/library/train_util.py index 1979207b0..2364d62b3 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -122,6 +122,20 @@ TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" +def split_train_val(paths, is_train, validation_split, validation_seed): + if validation_seed is not None: + print(f"Using validation seed: {validation_seed}") + prevstate = random.getstate() + random.seed(validation_seed) + random.shuffle(paths) + random.setstate(prevstate) + else: + random.shuffle(paths) + + if is_train: + return paths[0:math.ceil(len(paths) * (1 - validation_split))] + else: + return paths[len(paths) - round(len(paths) * validation_split):] def split_train_val(paths, is_train, validation_split, validation_seed): if validation_seed is not None: @@ -1352,7 +1366,6 @@ def __init__( self.is_train = is_train self.validation_split = validation_split self.validation_seed = validation_seed - self.batch_size = batch_size self.size = min(self.width, self.height) # 短いほう self.prior_loss_weight = prior_loss_weight @@ -1405,10 +1418,9 @@ def load_dreambooth_dir(subset: DreamBoothSubset): return [], [] img_paths = glob_images(subset.image_dir, "*") - if self.validation_split > 0.0: - img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed) - print(f"found directory {subset.image_dir} contains {len(img_paths)} image files") + img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed) + logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う captions = [] diff --git a/train_network.py b/train_network.py index edd3ff944..48885503f 100644 --- a/train_network.py +++ b/train_network.py @@ -130,7 +130,9 @@ def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_cond def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) - def process_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True): + def process_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True, timesteps_list=None): + total_loss = 0.0 + with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device) @@ -167,37 +169,40 @@ def process_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, n args, noise_scheduler, latents ) - # Predict the noise residual - with torch.set_grad_enabled(is_train), accelerator.autocast(): - noise_pred = self.call_unet( - args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype - ) - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise + # Use input timesteps_list or use described timesteps above + timesteps_list = timesteps_list or [timesteps] + for timesteps in timesteps_list: + # Predict the noise residual + with torch.set_grad_enabled(is_train), accelerator.autocast(): + noise_pred = self.call_unet( + args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype + ) - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise - loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight - loss = loss * loss_weights + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) - if args.v_pred_like_loss: - loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) - if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight + loss = loss * loss_weights - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) - return loss + total_loss += loss.mean() # 平均なのでbatch_sizeで割る必要なし + average_loss = total_loss / len(timesteps_list) + return average_loss def train(self, args): session_id = random.randint(0, 2**32) @@ -283,10 +288,10 @@ def train(self, args): train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" if val_dataset_group is not None: - assert ( - val_dataset_group.is_latent_cacheable() - ), "when caching validation latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - + assert ( + val_dataset_group.is_latent_cacheable() + ), "when caching validation latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + self.assert_extra_args(args, train_dataset_group) # acceleratorを準備する @@ -430,6 +435,15 @@ def train(self, args): num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) + + val_dataloader = torch.utils.data.DataLoader( + val_dataset_group if val_dataset_group is not None else [], + shuffle=False, + batch_size=1, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) val_dataloader = torch.utils.data.DataLoader( val_dataset_group if val_dataset_group is not None else [], @@ -798,7 +812,6 @@ def train(self, args): loss_recorder = train_util.LossRecorder() val_loss_recorder = train_util.LossRecorder() - del train_dataset_group # callback for step start @@ -848,7 +861,6 @@ def remove_model(old_ckpt_name): on_step_start(text_encoder, unet) is_train = True loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=train_text_encoder) - accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: params_to_clip = network.get_trainable_params() @@ -900,7 +912,25 @@ def remove_model(old_ckpt_name): if args.logging_dir is not None: logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) accelerator.log(logs, step=global_step) - + + if global_step % 25 == 0: + if len(val_dataloader) > 0: + print("Validating バリデーション処理...") + + with torch.no_grad(): + val_dataloader_iter = iter(val_dataloader) + batch = next(val_dataloader_iter) + is_train = False + loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, timesteps_list=[10, 350, 500, 650, 990]) + + current_loss = loss.detach().item() + val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss) + + if args.logging_dir is not None: + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/validation_current": current_loss} + accelerator.log(logs, step=global_step) + if global_step >= args.max_train_steps: break @@ -912,7 +942,7 @@ def remove_model(old_ckpt_name): with torch.no_grad(): for val_step, batch in enumerate(val_dataloader): is_train = False - loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, timesteps_list=[10, 350, 500, 650, 990]) current_loss = loss.detach().item() val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) @@ -933,6 +963,12 @@ def remove_model(old_ckpt_name): logs = {"loss/epoch_average": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) + if len(val_dataloader) > 0: + if args.logging_dir is not None: + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/validation_epoch_average": avr_loss} + accelerator.log(logs, step=epoch + 1) + accelerator.wait_for_everyone() # 指定エポックごとにモデルを保存 From 449c1c5c502375713e609ad9e00e747b4013063a Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 2 Jan 2025 15:59:20 -0500 Subject: [PATCH 286/348] Adding modified train_util and config_util --- library/config_util.py | 1 - library/train_util.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index cb2c5b68f..727e1a409 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -84,7 +84,6 @@ class BaseDatasetParams: tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None max_token_length: int = None resolution: Optional[Tuple[int, int]] = None - network_multiplier: float = 1.0 debug_dataset: bool = False validation_seed: Optional[int] = None validation_split: float = 0.0 diff --git a/library/train_util.py b/library/train_util.py index 2364d62b3..394337397 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1420,7 +1420,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset): img_paths = glob_images(subset.image_dir, "*") if self.validation_split > 0.0: img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed) - logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") + print(f"found directory {subset.image_dir} contains {len(img_paths)} image files") # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う captions = [] From 7470173044ca5b700bc4723709bd9c012e2216f3 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 3 Jan 2025 01:13:57 -0500 Subject: [PATCH 287/348] Remove defunct code for train_controlnet.py --- train_controlnet.py | 569 -------------------------------------------- 1 file changed, 569 deletions(-) diff --git a/train_controlnet.py b/train_controlnet.py index 09a911a00..365e35c8c 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -6,577 +6,8 @@ logger = logging.getLogger(__name__) -<<<<<<< HEAD -# TODO 他のスクリプトと共通化する -def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): - logs = { - "loss/current": current_loss, - "loss/average": avr_loss, - "lr": lr_scheduler.get_last_lr()[0], - } - - if args.optimizer_type.lower().startswith("DAdapt".lower()): - logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] - - return logs - - -def train(args): - # session_id = random.randint(0, 2**32) - # training_started_at = time.time() - train_util.verify_training_args(args) - train_util.prepare_dataset_args(args, True) - setup_logging(args, reset=True) - - cache_latents = args.cache_latents - use_user_config = args.dataset_config is not None - - if args.seed is None: - args.seed = random.randint(0, 2**32) - set_seed(args.seed) - - tokenizer = train_util.load_tokenizer(args) - - # データセットを準備する - blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) - if use_user_config: - logger.info(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "conditioning_data_dir"] - if any(getattr(args, attr) is not None for attr in ignored): - logger.warning( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) - ) - ) - else: - user_config = { - "datasets": [ - { - "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( - args.train_data_dir, - args.conditioning_data_dir, - args.caption_extension, - ) - } - ] - } - - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) - - if args.debug_dataset: - train_util.debug_dataset(train_dataset_group) - return - if len(train_dataset_group) == 0: - logger.error( - "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" - ) - return - - if cache_latents: - assert ( - train_dataset_group.is_latent_cacheable() - ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - - # acceleratorを準備する - logger.info("prepare accelerator") - accelerator = train_util.prepare_accelerator(args) - is_main_process = accelerator.is_main_process - - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, save_dtype = train_util.prepare_dtype(args) - - # モデルを読み込む - text_encoder, vae, unet, _ = train_util.load_target_model( - args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=True - ) - - # DiffusersのControlNetが使用するデータを準備する - if args.v2: - unet.config = { - "act_fn": "silu", - "attention_head_dim": [5, 10, 20, 20], - "block_out_channels": [320, 640, 1280, 1280], - "center_input_sample": False, - "cross_attention_dim": 1024, - "down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"], - "downsample_padding": 1, - "dual_cross_attention": False, - "flip_sin_to_cos": True, - "freq_shift": 0, - "in_channels": 4, - "layers_per_block": 2, - "mid_block_scale_factor": 1, - "norm_eps": 1e-05, - "norm_num_groups": 32, - "num_class_embeds": None, - "only_cross_attention": False, - "out_channels": 4, - "sample_size": 96, - "up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"], - "use_linear_projection": True, - "upcast_attention": True, - "only_cross_attention": False, - "downsample_padding": 1, - "use_linear_projection": True, - "class_embed_type": None, - "num_class_embeds": None, - "resnet_time_scale_shift": "default", - "projection_class_embeddings_input_dim": None, - } - else: - unet.config = { - "act_fn": "silu", - "attention_head_dim": 8, - "block_out_channels": [320, 640, 1280, 1280], - "center_input_sample": False, - "cross_attention_dim": 768, - "down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"], - "downsample_padding": 1, - "flip_sin_to_cos": True, - "freq_shift": 0, - "in_channels": 4, - "layers_per_block": 2, - "mid_block_scale_factor": 1, - "norm_eps": 1e-05, - "norm_num_groups": 32, - "out_channels": 4, - "sample_size": 64, - "up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"], - "only_cross_attention": False, - "downsample_padding": 1, - "use_linear_projection": False, - "class_embed_type": None, - "num_class_embeds": None, - "upcast_attention": False, - "resnet_time_scale_shift": "default", - "projection_class_embeddings_input_dim": None, - } - unet.config = SimpleNamespace(**unet.config) - - controlnet = ControlNetModel.from_unet(unet) - - if args.controlnet_model_name_or_path: - filename = args.controlnet_model_name_or_path - if os.path.isfile(filename): - if os.path.splitext(filename)[1] == ".safetensors": - state_dict = load_file(filename) - else: - state_dict = torch.load(filename) - state_dict = model_util.convert_controlnet_state_dict_to_diffusers(state_dict) - controlnet.load_state_dict(state_dict) - elif os.path.isdir(filename): - controlnet = ControlNetModel.from_pretrained(filename) - - # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) - - # 学習を準備する - if cache_latents: - vae.to(accelerator.device, dtype=weight_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents( - vae, - args.vae_batch_size, - args.cache_latents_to_disk, - accelerator.is_main_process, - ) - vae.to("cpu") - clean_memory_on_device(accelerator.device) - - accelerator.wait_for_everyone() - - if args.gradient_checkpointing: - controlnet.enable_gradient_checkpointing() - - # 学習に必要なクラスを準備する - accelerator.print("prepare optimizer, data loader etc.") - - trainable_params = controlnet.parameters() - - _, _, optimizer = train_util.get_optimizer(args, trainable_params) - - # dataloaderを準備する - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 - n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers - - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - - # 学習ステップ数を計算する - if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil( - len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps - ) - accelerator.print( - f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" - ) - - # データセット側にも学習ステップを送信 - train_dataset_group.set_max_train_steps(args.max_train_steps) - - # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) - - # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする - if args.full_fp16: - assert ( - args.mixed_precision == "fp16" - ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - accelerator.print("enable full fp16 training.") - controlnet.to(weight_dtype) - - # acceleratorがなんかよろしくやってくれるらしい - controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - controlnet, optimizer, train_dataloader, lr_scheduler - ) - - unet.requires_grad_(False) - text_encoder.requires_grad_(False) - unet.to(accelerator.device) - text_encoder.to(accelerator.device) - - # transform DDP after prepare - controlnet = controlnet.module if isinstance(controlnet, DDP) else controlnet - - controlnet.train() - - if not cache_latents: - vae.requires_grad_(False) - vae.eval() - vae.to(accelerator.device, dtype=weight_dtype) - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - train_util.patch_accelerator_for_fp16_training(accelerator) - - # resumeする - train_util.resume_from_local_or_hf_if_specified(accelerator, args) - - # epoch数を計算する - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): - args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - - # 学習する - # TODO: find a way to handle total batch size when there are multiple datasets - accelerator.print("running training / 学習開始") - accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print( - f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" - ) - # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - - progress_bar = tqdm( - range(args.max_train_steps), - smoothing=0, - disable=not accelerator.is_local_main_process, - desc="steps", - ) - global_step = 0 - - noise_scheduler = DDPMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - num_train_timesteps=1000, - clip_sample=False, - ) - if accelerator.is_main_process: - init_kwargs = {} - if args.wandb_run_name: - init_kwargs["wandb"] = {"name": args.wandb_run_name} - if args.log_tracker_config is not None: - init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers( - "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs - ) - - loss_recorder = train_util.LossRecorder() - del train_dataset_group - - # function for saving/removing - def save_model(ckpt_name, model, force_sync_upload=False): - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, ckpt_name) - - accelerator.print(f"\nsaving checkpoint: {ckpt_file}") - - state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict()) - - if save_dtype is not None: - for key in list(state_dict.keys()): - v = state_dict[key] - v = v.detach().clone().to("cpu").to(save_dtype) - state_dict[key] = v - - if os.path.splitext(ckpt_file)[1] == ".safetensors": - from safetensors.torch import save_file - - save_file(state_dict, ckpt_file) - else: - torch.save(state_dict, ckpt_file) - - if args.huggingface_repo_id is not None: - huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) - - def remove_model(old_ckpt_name): - old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) - if os.path.exists(old_ckpt_file): - accelerator.print(f"removing old checkpoint: {old_ckpt_file}") - os.remove(old_ckpt_file) - - # For --sample_at_first - train_util.sample_images( - accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, controlnet=controlnet - ) - - # training loop - for epoch in range(num_train_epochs): - if is_main_process: - accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") - current_epoch.value = epoch + 1 - - for step, batch in enumerate(train_dataloader): - current_step.value = global_step - with accelerator.accumulate(controlnet): - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] - - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) - - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents, device=latents.device) - if args.noise_offset: - noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) - elif args.multires_noise_iterations: - noise = pyramid_noise_like( - noise, - latents.device, - args.multires_noise_iterations, - args.multires_noise_discount, - ) - - # Sample a random timestep for each image - timesteps = train_util.get_timesteps(args, 0, noise_scheduler.config.num_train_timesteps, b_size) - huber_c = train_util.get_huber_c(args, noise_scheduler, timesteps.item(), latents.device) - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) - - with accelerator.autocast(): - down_block_res_samples, mid_block_res_sample = controlnet( - noisy_latents, - timesteps, - encoder_hidden_states=encoder_hidden_states, - controlnet_cond=controlnet_image, - return_dict=False, - ) - - # Predict the noise residual - noise_pred = unet( - noisy_latents, - timesteps, - encoder_hidden_states, - down_block_additional_residuals=[sample.to(dtype=weight_dtype) for sample in down_block_res_samples], - mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), - ).sample - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights - - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - - accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = controlnet.parameters() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - - train_util.sample_images( - accelerator, - args, - None, - global_step, - accelerator.device, - vae, - tokenizer, - text_encoder, - unet, - controlnet=controlnet, - ) - - # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: - accelerator.wait_for_everyone() - if accelerator.is_main_process: - ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) - save_model( - ckpt_name, - accelerator.unwrap_model(controlnet), - ) - - if args.save_state: - train_util.save_and_remove_state_stepwise(args, accelerator, global_step) - - remove_step_no = train_util.get_remove_step_no(args, global_step) - if remove_step_no is not None: - remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) - remove_model(remove_ckpt_name) - - current_loss = loss.detach().item() - loss_recorder.add(epoch=epoch, step=step, loss=current_loss) - avr_loss: float = loss_recorder.moving_average - logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - - if args.logging_dir is not None: - logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) - accelerator.log(logs, step=global_step) - - if global_step >= args.max_train_steps: - break - - if args.logging_dir is not None: - logs = {"loss/epoch": loss_recorder.moving_average} - accelerator.log(logs, step=epoch + 1) - - accelerator.wait_for_everyone() - - # 指定エポックごとにモデルを保存 - if args.save_every_n_epochs is not None: - saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs - if is_main_process and saving: - ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) - save_model(ckpt_name, accelerator.unwrap_model(controlnet)) - - remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) - if remove_epoch_no is not None: - remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) - remove_model(remove_ckpt_name) - - if args.save_state: - train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) - - train_util.sample_images( - accelerator, - args, - epoch + 1, - global_step, - accelerator.device, - vae, - tokenizer, - text_encoder, - unet, - controlnet=controlnet, - ) - - # end of epoch - if is_main_process: - controlnet = accelerator.unwrap_model(controlnet) - - accelerator.end_training() - - if is_main_process and (args.save_state or args.save_state_on_train_end): - train_util.save_state_on_train_end(args, accelerator) - - # del accelerator # この後メモリを使うのでこれは消す→printで使うので消さずにおく - - if is_main_process: - ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) - save_model(ckpt_name, controlnet, force_sync_upload=True) - - logger.info("model saved.") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - - add_logging_arguments(parser) - train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, False, True, True) - train_util.add_training_arguments(parser, False) - deepspeed_utils.add_deepspeed_arguments(parser) - train_util.add_optimizer_arguments(parser) - config_util.add_config_arguments(parser) - custom_train_functions.add_custom_train_arguments(parser) - - parser.add_argument( - "--save_model_as", - type=str, - default="safetensors", - choices=[None, "ckpt", "pt", "safetensors"], - help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", - ) - parser.add_argument( - "--controlnet_model_name_or_path", - type=str, - default=None, - help="controlnet model name or path / controlnetのモデル名またはパス", - ) - parser.add_argument( - "--conditioning_data_dir", - type=str, - default=None, - help="conditioning data directory / 条件付けデータのディレクトリ", - ) - - return parser - -======= from library import train_util from train_control_net import setup_parser, train ->>>>>>> hina/feature/val-loss if __name__ == "__main__": logger.warning( From 534059dea517d44de387e7d467d64209f9dcfba2 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 3 Jan 2025 01:18:15 -0500 Subject: [PATCH 288/348] Typos and lingering is_train --- library/config_util.py | 2 +- library/train_util.py | 4 ---- train_network.py | 6 +++--- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index a09d2c7ca..418c179dc 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -535,7 +535,7 @@ def print_info(_datasets): shuffle_caption: {subset.shuffle_caption} keep_tokens: {subset.keep_tokens} caption_dropout_rate: {subset.caption_dropout_rate} - caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} + caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs} caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} caption_prefix: {subset.caption_prefix} caption_suffix: {subset.caption_suffix} diff --git a/library/train_util.py b/library/train_util.py index bf1b6731c..220d4702b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2092,7 +2092,6 @@ def __init__( bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset: bool, - is_train: bool, validation_seed: int, validation_split: float, ) -> None: @@ -2312,7 +2311,6 @@ class ControlNetDataset(BaseDataset): def __init__( self, subsets: Sequence[ControlNetSubset], - is_train: bool, batch_size: int, resolution, network_multiplier: float, @@ -2362,7 +2360,6 @@ def __init__( self.dreambooth_dataset_delegate = DreamBoothDataset( db_subsets, - is_train, batch_size, resolution, network_multiplier, @@ -2382,7 +2379,6 @@ def __init__( self.batch_size = batch_size self.num_train_images = self.dreambooth_dataset_delegate.num_train_images self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images - self.is_train = is_train self.validation_split = validation_split self.validation_seed = validation_seed diff --git a/train_network.py b/train_network.py index 99b9717a5..4bcfc0ac7 100644 --- a/train_network.py +++ b/train_network.py @@ -380,11 +380,11 @@ def pick_timesteps_list() -> torch.IntTensor: else: return typing.cast(torch.IntTensor, torch.tensor(timesteps_list).unsqueeze(1).repeat(1, batch_size).to(latents.device)) - choosen_timesteps_list = pick_timesteps_list() + chosen_timesteps_list = pick_timesteps_list() total_loss = torch.zeros((batch_size, 1)).to(latents.device) # Use input timesteps_list or use described timesteps above - for fixed_timestep in choosen_timesteps_list: + for fixed_timestep in chosen_timesteps_list: fixed_timestep = typing.cast(torch.IntTensor, fixed_timestep) # Predict the noise residual @@ -447,7 +447,7 @@ def pick_timesteps_list() -> torch.IntTensor: total_loss += loss - return total_loss / len(choosen_timesteps_list) + return total_loss / len(chosen_timesteps_list) def train(self, args): session_id = random.randint(0, 2**32) From c8c3569df292109fe3be4d209c9f6131afe2ba5f Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 3 Jan 2025 01:26:45 -0500 Subject: [PATCH 289/348] Cleanup order, types, print to logger --- library/config_util.py | 7 +++---- library/train_util.py | 6 +++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 418c179dc..5a4d3aa2d 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -485,7 +485,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) datasets.append(dataset) - val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] + val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] for dataset_blueprint in dataset_group_blueprint.datasets: if dataset_blueprint.params.validation_split <= 0.0: continue @@ -503,7 +503,6 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) val_datasets.append(dataset) - # print info def print_info(_datasets): info = "" for i, dataset in enumerate(_datasets): @@ -565,7 +564,7 @@ def print_info(_datasets): print_info(datasets) if len(val_datasets) > 0: - print("Validation dataset") + logger.info("Validation dataset") print_info(val_datasets) if len(val_datasets) > 0: @@ -610,7 +609,7 @@ def print_info(_datasets): " ", ) - logger.info(f"{info}") + logger.info(info) # make buckets first because it determines the length of dataset # and set the same seed for all datasets diff --git a/library/train_util.py b/library/train_util.py index 220d4702b..782f57e8f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1833,9 +1833,9 @@ def __init__( bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, + debug_dataset: bool, validation_split: float, validation_seed: Optional[int], - debug_dataset, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -2319,9 +2319,9 @@ def __init__( max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, + debug_dataset: bool, validation_split: float, validation_seed: Optional[int], - debug_dataset: float, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -2369,9 +2369,9 @@ def __init__( bucket_reso_steps, bucket_no_upscale, 1.0, + debug_dataset, validation_split, validation_seed, - debug_dataset ) # config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい) From fbfc2753eb7fa57724eb525ee65d851b5e80b8ea Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 3 Jan 2025 01:53:12 -0500 Subject: [PATCH 290/348] Update text for train/reg with repeats --- library/train_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 782f57e8f..77a6a9f9a 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2050,11 +2050,11 @@ def load_dreambooth_dir(subset: DreamBoothSubset): subset.img_count = len(img_paths) self.subsets.append(subset) - logger.info(f"{num_train_images} images with repeating.") + logger.info(f"{num_train_images} train images with repeats.") self.num_train_images = num_train_images - logger.info(f"{num_reg_images} reg images.") + logger.info(f"{num_reg_images} reg images with repeats.") if num_train_images < num_reg_images: logger.warning("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") From 58bfa36d0275d864d5a2d64c51632e808f789ddd Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 3 Jan 2025 02:00:28 -0500 Subject: [PATCH 291/348] Add seed help clarifying info --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 4bcfc0ac7..7d064d210 100644 --- a/train_network.py +++ b/train_network.py @@ -1639,7 +1639,7 @@ def setup_parser() -> argparse.ArgumentParser: "--validation_seed", type=int, default=None, - help="Validation seed / 検証シード" + help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証シード" ) parser.add_argument( "--validation_split", From 6604b36044a83f3531faed508096f3e6bfe48fc9 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 3 Jan 2025 02:04:59 -0500 Subject: [PATCH 292/348] Remove duplicate assignment --- library/train_util.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 77a6a9f9a..3710c865d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -86,8 +86,6 @@ import library.deepspeed_utils as deepspeed_utils from library.utils import setup_logging, pil_resize - - setup_logging() import logging @@ -1841,8 +1839,6 @@ def __init__( assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" - self.validation_split = validation_split - self.validation_seed = validation_seed self.batch_size = batch_size self.size = min(self.width, self.height) # 短いほう self.prior_loss_weight = prior_loss_weight From 0522070d197d92745dbdb408d74c9c3f869bff76 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 3 Jan 2025 15:20:25 -0500 Subject: [PATCH 293/348] Fix training, validation split, revert to using upstream implemenation --- library/config_util.py | 67 +++----------- library/custom_train_functions.py | 6 +- library/strategy_sd.py | 2 +- library/train_util.py | 143 +++++++++++++++++------------- train_network.py | 94 ++++++++++++-------- 5 files changed, 152 insertions(+), 160 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 5a4d3aa2d..63d28c969 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -482,7 +482,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset_klass = FineTuningDataset subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, is_training_dataset=True, **asdict(dataset_blueprint.params)) datasets.append(dataset) val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] @@ -500,16 +500,16 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset_klass = FineTuningDataset subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, is_training_dataset=False, **asdict(dataset_blueprint.params)) val_datasets.append(dataset) - def print_info(_datasets): + def print_info(_datasets, dataset_type: str): info = "" for i, dataset in enumerate(_datasets): is_dreambooth = isinstance(dataset, DreamBoothDataset) is_controlnet = isinstance(dataset, ControlNetDataset) info += dedent(f"""\ - [Dataset {i}] + [{dataset_type} {i}] batch_size: {dataset.batch_size} resolution: {(dataset.width, dataset.height)} enable_bucket: {dataset.enable_bucket} @@ -527,7 +527,7 @@ def print_info(_datasets): for j, subset in enumerate(dataset.subsets): info += indent(dedent(f"""\ - [Subset {j} of Dataset {i}] + [Subset {j} of {dataset_type} {i}] image_dir: "{subset.image_dir}" image_count: {subset.img_count} num_repeats: {subset.num_repeats} @@ -544,8 +544,8 @@ def print_info(_datasets): random_crop: {subset.random_crop} token_warmup_min: {subset.token_warmup_min}, token_warmup_step: {subset.token_warmup_step}, - alpha_mask: {subset.alpha_mask} - custom_attributes: {subset.custom_attributes} + alpha_mask: {subset.alpha_mask} + custom_attributes: {subset.custom_attributes} """), " ") if is_dreambooth: @@ -561,67 +561,22 @@ def print_info(_datasets): logger.info(info) - print_info(datasets) + print_info(datasets, "Dataset") if len(val_datasets) > 0: - logger.info("Validation dataset") - print_info(val_datasets) - - if len(val_datasets) > 0: - info = "" - - for i, dataset in enumerate(val_datasets): - info += dedent( - f"""\ - [Validation Dataset {i}] - batch_size: {dataset.batch_size} - resolution: {(dataset.width, dataset.height)} - enable_bucket: {dataset.enable_bucket} - network_multiplier: {dataset.network_multiplier} - """ - ) - - if dataset.enable_bucket: - info += indent( - dedent( - f"""\ - min_bucket_reso: {dataset.min_bucket_reso} - max_bucket_reso: {dataset.max_bucket_reso} - bucket_reso_steps: {dataset.bucket_reso_steps} - bucket_no_upscale: {dataset.bucket_no_upscale} - \n""" - ), - " ", - ) - else: - info += "\n" - - for j, subset in enumerate(dataset.subsets): - info += indent( - dedent( - f"""\ - [Subset {j} of Validation Dataset {i}] - image_dir: "{subset.image_dir}" - image_count: {subset.img_count} - num_repeats: {subset.num_repeats} - """ - ), - " ", - ) - - logger.info(info) + print_info(val_datasets, "Validation Dataset") # make buckets first because it determines the length of dataset # and set the same seed for all datasets seed = random.randint(0, 2**31) # actual seed is seed + epoch_no for i, dataset in enumerate(datasets): - logger.info(f"[Dataset {i}]") + logger.info(f"[Prepare dataset {i}]") dataset.make_buckets() dataset.set_seed(seed) for i, dataset in enumerate(val_datasets): - logger.info(f"[Validation Dataset {i}]") + logger.info(f"[Prepare validation dataset {i}]") dataset.make_buckets() dataset.set_seed(seed) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 9a7c21a3e..ad3e69ffb 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -455,7 +455,7 @@ def get_weighted_text_embeddings( # https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2 -def pyramid_noise_like(noise, device, iterations=6, discount=0.4): +def pyramid_noise_like(noise, device, iterations=6, discount=0.4) -> torch.FloatTensor: b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant! u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device) for i in range(iterations): @@ -468,7 +468,7 @@ def pyramid_noise_like(noise, device, iterations=6, discount=0.4): # https://www.crosslabs.org//blog/diffusion-with-offset-noise -def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale): +def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale) -> torch.FloatTensor: if noise_offset is None: return noise if adaptive_noise_scale is not None: @@ -484,7 +484,7 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale): return noise -def apply_masked_loss(loss, batch): +def apply_masked_loss(loss, batch) -> torch.FloatTensor: if "conditioning_images" in batch: # conditioning image is -1 to 1. we need to convert it to 0 to 1 mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel diff --git a/library/strategy_sd.py b/library/strategy_sd.py index d0a3a68bf..a44fc4092 100644 --- a/library/strategy_sd.py +++ b/library/strategy_sd.py @@ -40,7 +40,7 @@ def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: text = [text] if isinstance(text, str) else text return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)] - def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]: + def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: text = [text] if isinstance(text, str) else text tokens_list = [] weights_list = [] diff --git a/library/train_util.py b/library/train_util.py index 3710c865d..0f16a4f31 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -146,7 +146,15 @@ TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz" -def split_train_val(paths: List[str], is_train: bool, validation_split: float, validation_seed: int) -> List[str]: +def split_train_val(paths: List[str], is_training_dataset: bool, validation_split: float, validation_seed: int) -> List[str]: + """ + Split the dataset into train and validation + + Shuffle the dataset based on the validation_seed or the current random seed. + For example if the split of 0.2 of 100 images. + [0:79] = 80 training images + [80:] = 20 validation images + """ if validation_seed is not None: print(f"Using validation seed: {validation_seed}") prevstate = random.getstate() @@ -156,9 +164,12 @@ def split_train_val(paths: List[str], is_train: bool, validation_split: float, v else: random.shuffle(paths) - if is_train: + # Split the dataset between training and validation + if is_training_dataset: + # Training dataset we split to the first part return paths[0:math.ceil(len(paths) * (1 - validation_split))] else: + # Validation dataset we split to the second part return paths[len(paths) - round(len(paths) * validation_split):] @@ -1822,6 +1833,7 @@ class DreamBoothDataset(BaseDataset): def __init__( self, subsets: Sequence[DreamBoothSubset], + is_training_dataset: bool, batch_size: int, resolution, network_multiplier: float, @@ -1843,6 +1855,7 @@ def __init__( self.size = min(self.width, self.height) # 短いほう self.prior_loss_weight = prior_loss_weight self.latents_cache = None + self.is_training_dataset = is_training_dataset self.validation_seed = validation_seed self.validation_split = validation_split @@ -1952,6 +1965,9 @@ def load_dreambooth_dir(subset: DreamBoothSubset): size_set_count += 1 logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}") + if self.validation_split > 0.0: + img_paths = split_train_val(img_paths, self.is_training_dataset, self.validation_split, self.validation_seed) + logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") if use_cached_info_for_subset: @@ -2046,7 +2062,8 @@ def load_dreambooth_dir(subset: DreamBoothSubset): subset.img_count = len(img_paths) self.subsets.append(subset) - logger.info(f"{num_train_images} train images with repeats.") + images_split_name = "train" if self.is_training_dataset else "validation" + logger.info(f"{num_train_images} {images_split_name} images with repeats.") self.num_train_images = num_train_images @@ -2411,8 +2428,12 @@ def __init__( conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths] # normalize path extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair]) - assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}" - assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}" + assert ( + len(missing_imgs) == 0 + ), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}" + assert ( + len(extra_imgs) == 0 + ), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}" self.conditioning_image_transforms = IMAGE_TRANSFORMS @@ -4586,7 +4607,6 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar config_args = argparse.Namespace(**ignore_nesting_dict) args = parser.parse_args(namespace=config_args) args.config_file = os.path.splitext(args.config_file)[0] - logger.info(args.config_file) return args @@ -5880,55 +5900,35 @@ def save_sd_model_on_train_end_common( huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) -def get_random_timesteps(args, min_timestep: int, max_timestep: int, batch_size: int, device: torch.device) -> torch.IntTensor: - """ - Get a random timestep between the min and max timesteps - Can error (NotImplementedError) if the loss type is not supported - """ - # TODO: if a huber loss is selected, it will use constant timesteps for each batch - # as. In the future there may be a smarter way - if args.loss_type == "huber" or args.loss_type == "smooth_l1": - timesteps = torch.randint(min_timestep, max_timestep, (1,), device="cpu") - timesteps = timesteps.repeat(batch_size).to(device) - elif args.loss_type == "l2": - timesteps = torch.randint(min_timestep, max_timestep, (batch_size,), device=device) - else: - raise NotImplementedError(f"Unknown loss type {args.loss_type}") - - return typing.cast(torch.IntTensor, timesteps) - +def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device = torch.device("cpu")) -> torch.IntTensor: + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device) + return timesteps -def get_huber_c(args, noise_scheduler: DDPMScheduler, timesteps: torch.IntTensor) -> Optional[float]: - """ - Calculate the Huber convolution (huber_c) value - Huber loss is a loss function used in robust regression, that is less sensitive - to outliers in data than the squared error loss. - https://en.wikipedia.org/wiki/Huber_loss - """ - if args.loss_type == "huber" or args.loss_type == "smooth_l1": - if args.huber_schedule == "exponential": - alpha = -math.log(args.huber_c) / noise_scheduler.config.get('num_train_timesteps', 1000) - huber_c = math.exp(-alpha * timesteps.item()) - elif args.huber_schedule == "snr": - if not hasattr(noise_scheduler, "alphas_cumprod"): - raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") - alphas_cumprod = noise_scheduler.alphas_cumprod.index_select(0, timesteps) - sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 - huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c - elif args.huber_schedule == "constant": - huber_c = args.huber_c - else: - raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") - elif args.loss_type == "l2": +def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]: + if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"): return None + + b_size = timesteps.shape[0] + if args.huber_schedule == "exponential": + alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps + result = torch.exp(-alpha * timesteps) * args.huber_scale + elif args.huber_schedule == "snr": + if not hasattr(noise_scheduler, "alphas_cumprod"): + raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") + alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) + sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 + result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c + result = result.to(timesteps.device) + elif args.huber_schedule == "constant": + result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device) else: - raise NotImplementedError(f"Unknown loss type {args.loss_type}") + raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") - return huber_c + return result -def modify_noise(args, noise: torch.Tensor, latents: torch.Tensor): +def modify_noise(args, noise: torch.Tensor, latents: torch.Tensor) -> torch.FloatTensor: """ Apply noise modifications like noise offset and multires noise """ @@ -5964,27 +5964,44 @@ def make_random_timesteps(args, noise_scheduler: DDPMScheduler, batch_size: int, max_timestep = noise_scheduler.config.get('num_train_timesteps', 1000) if args.max_timestep is None else args.max_timestep # Sample a random timestep for each image - timesteps = get_random_timesteps(args, min_timestep, max_timestep, batch_size, device) + timesteps = get_timesteps(min_timestep, max_timestep, batch_size, device) return timesteps -def get_noise_noisy_latents_and_timesteps(args, noise_scheduler: DDPMScheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor, Optional[float]]: - """ - Unified noise, noisy_latents, timesteps and huber loss convolution calculations - """ - batch_size = latents.shape[0] +def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]: + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + if args.noise_offset_random_strength: + noise_offset = torch.rand(1, device=latents.device) * args.noise_offset + else: + noise_offset = args.noise_offset + noise = custom_train_functions.apply_noise_offset(latents, noise, noise_offset, args.adaptive_noise_scale) + if args.multires_noise_iterations: + noise = custom_train_functions.pyramid_noise_like( + noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount + ) + + # Sample a random timestep for each image + b_size = latents.shape[0] min_timestep = 0 if args.min_timestep is None else args.min_timestep - max_timestep = noise_scheduler.config.get("num_train_timesteps", 1000) if args.max_timestep is None else args.max_timestep + max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep - # A random timestep for each image in the batch - timesteps = get_random_timesteps(args, min_timestep, max_timestep, batch_size, latents.device) - huber_c = get_huber_c(args, noise_scheduler, timesteps) + timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device) - noise = make_noise(args, latents) - noisy_latents = get_noisy_latents(args, noise, noise_scheduler, latents, timesteps) + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + if args.ip_noise_gamma: + if args.ip_noise_gamma_random_strength: + strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma + else: + strength = args.ip_noise_gamma + noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps) + else: + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - return noise, noisy_latents, timesteps, huber_c + return noise, noisy_latents, timesteps def get_noisy_latents(args, noise: torch.FloatTensor, noise_scheduler: DDPMScheduler, latents: torch.FloatTensor, timesteps: torch.IntTensor) -> torch.FloatTensor: @@ -6015,6 +6032,8 @@ def conditional_loss( elif loss_type == "l1": loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction) elif loss_type == "huber": + if huber_c is None: + raise NotImplementedError("huber_c not implemented correctly") huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": @@ -6022,6 +6041,8 @@ def conditional_loss( elif reduction == "sum": loss = torch.sum(loss) elif loss_type == "smooth_l1": + if huber_c is None: + raise NotImplementedError("huber_c not implemented correctly") huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": diff --git a/train_network.py b/train_network.py index 7d064d210..f870734fd 100644 --- a/train_network.py +++ b/train_network.py @@ -205,10 +205,10 @@ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) return noise_scheduler - def encode_images_to_latents(self, args, accelerator, vae, images): + def encode_images_to_latents(self, args, vae: AutoencoderKL, images: torch.FloatTensor) -> torch.FloatTensor: return vae.encode(images).latent_dist.sample() - def shift_scale_latents(self, args, latents): + def shift_scale_latents(self, args, latents: torch.FloatTensor) -> torch.FloatTensor: return latents * self.vae_scale_factor def get_noise_pred_and_target( @@ -280,7 +280,7 @@ def get_noise_pred_and_target( return noise_pred, target, timesteps, None - def post_process_loss(self, loss, args, timesteps, noise_scheduler): + def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor: if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: @@ -317,20 +317,21 @@ def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, # endregion - def process_batch(self, batch, tokenizers, text_encoders, unet, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True, timesteps_list: Optional[List[Number]]=None) -> torch.Tensor: + def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True, timesteps_list: Optional[List[Number]]=None) -> torch.Tensor: with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: - latents: torch.Tensor = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device)) + latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device)) else: # latentに変換 - latents: torch.Tensor = typing.cast(torch.FloatTensor, typing.cast(AutoencoderKLOutput, vae.encode(batch["images"].to(accelerator.device, dtype=vae_dtype))).latent_dist.sample()) + latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype)) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): accelerator.print("NaN found in latents, replacing with zeros") - latents = typing.cast(torch.FloatTensor, torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)) - latents = typing.cast(torch.FloatTensor, latents * self.vae_scale_factor) + latents = typing.cast(torch.FloatTensor, torch.nan_to_num(latents, 0, out=latents)) + + latents = self.shift_scale_latents(args, latents) text_encoder_conds = [] @@ -384,22 +385,36 @@ def pick_timesteps_list() -> torch.IntTensor: total_loss = torch.zeros((batch_size, 1)).to(latents.device) # Use input timesteps_list or use described timesteps above - for fixed_timestep in chosen_timesteps_list: - fixed_timestep = typing.cast(torch.IntTensor, fixed_timestep) + for fixed_timesteps in chosen_timesteps_list: + fixed_timesteps = typing.cast(torch.IntTensor, fixed_timesteps) # Predict the noise residual # and add noise to the latents # with noise offset and/or multires noise if specified - noisy_latents = train_util.get_noisy_latents(args, noise, noise_scheduler, latents, fixed_timestep) + noisy_latents = train_util.get_noisy_latents(args, noise, noise_scheduler, latents, fixed_timesteps) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + for x in noisy_latents: + x.requires_grad_(True) + for t in text_encoder_conds: + t.requires_grad_(True) with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): noise_pred = self.call_unet( - args, accelerator, unet, noisy_latents.requires_grad_(train_unet), fixed_timestep, text_encoder_conds, batch, weight_dtype + args, + accelerator, + unet, + noisy_latents.requires_grad_(train_unet), + fixed_timesteps, + text_encoder_conds, + batch, + weight_dtype, ) if args.v_parameterization: # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, fixed_timestep) + target = noise_scheduler.get_velocity(latents, noise, fixed_timesteps) else: target = noise @@ -418,7 +433,7 @@ def pick_timesteps_list() -> torch.IntTensor: accelerator, unet, noisy_latents, - timesteps, + fixed_timesteps, text_encoder_conds, batch, weight_dtype, @@ -427,7 +442,8 @@ def pick_timesteps_list() -> torch.IntTensor: network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + huber_c = train_util.get_huber_threshold_if_needed(args, fixed_timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) loss = loss.mean([1, 2, 3]) # 平均なのでbatch_sizeで割る必要なし if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): @@ -436,14 +452,7 @@ def pick_timesteps_list() -> torch.IntTensor: loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight loss = loss * loss_weights - if args.min_snr_gamma: - loss = apply_snr_weight(loss, fixed_timestep, noise_scheduler, args.min_snr_gamma) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, fixed_timestep, noise_scheduler) - if args.v_pred_like_loss: - loss = add_v_prediction_like_loss(loss, fixed_timestep, noise_scheduler, args.v_pred_like_loss) - if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, fixed_timestep, noise_scheduler) + loss = self.post_process_loss(loss, args, fixed_timesteps, noise_scheduler) total_loss += loss @@ -526,8 +535,12 @@ def train(self, args): collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) if args.debug_dataset: - train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly + train_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly train_util.debug_dataset(train_dataset_group) + + if val_dataset_group is not None: + val_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly + train_util.debug_dataset(val_dataset_group) return if len(train_dataset_group) == 0: logger.error( @@ -753,10 +766,6 @@ def train(self, args): # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) - # Not for sure here. - # if val_dataset_group is not None: - # val_dataset_group.set_max_train_steps(args.max_train_steps) - # lr schedulerを用意する lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) @@ -1304,7 +1313,7 @@ def remove_model(old_ckpt_name): clean_memory_on_device(accelerator.device) for epoch in range(epoch_to_start, num_train_epochs): - accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n") current_epoch.value = epoch + 1 metadata["ss_epoch"] = str(epoch + 1) @@ -1324,7 +1333,7 @@ def remove_model(old_ckpt_name): continue with accelerator.accumulate(training_model): - loss = self.process_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet) + loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet) accelerator.backward(loss) if accelerator.sync_gradients: self.all_reduce_network(accelerator, network) # sync DDP grad manually @@ -1384,7 +1393,8 @@ def remove_model(old_ckpt_name): logs = self.generate_step_logs( args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm ) - accelerator.log(logs, step=global_step) + # accelerator.log(logs, step=global_step) + accelerator.log(logs) # VALIDATION PER STEP should_validate = (args.validation_every_n_step is not None @@ -1401,7 +1411,7 @@ def remove_model(old_ckpt_name): if val_step >= validation_steps: break - loss = self.process_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) + loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item()) val_progress_bar.update(1) @@ -1409,10 +1419,12 @@ def remove_model(old_ckpt_name): if is_tracking: logs = {"loss/current_val_loss": loss.detach().item()} - accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) + # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) + accelerator.log(logs) logs = {"loss/average_val_loss": val_loss_recorder.moving_average} - accelerator.log(logs, step=global_step) + # accelerator.log(logs, step=global_step) + accelerator.log(logs) if global_step >= args.max_train_steps: break @@ -1427,7 +1439,7 @@ def remove_model(old_ckpt_name): ) for val_step, batch in enumerate(val_dataloader): - loss = self.process_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) + loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) current_loss = loss.detach().item() val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) @@ -1437,22 +1449,26 @@ def remove_model(old_ckpt_name): if is_tracking: avr_loss: float = val_loss_recorder.moving_average logs = {"loss/validation_current": current_loss} - accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) + # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) + accelerator.log(logs) if is_tracking: avr_loss: float = val_loss_recorder.moving_average logs = {"loss/validation_average": avr_loss} - accelerator.log(logs, step=epoch + 1) + # accelerator.log(logs, step=epoch + 1) + accelerator.log(logs) # END OF EPOCH if is_tracking: logs = {"loss/epoch_average": loss_recorder.moving_average} - accelerator.log(logs, step=epoch + 1) + # accelerator.log(logs, step=epoch + 1) + accelerator.log(logs) if len(val_dataloader) > 0 and is_tracking: avr_loss: float = val_loss_recorder.moving_average logs = {"loss/validation_epoch_average": avr_loss} - accelerator.log(logs, step=epoch + 1) + # accelerator.log(logs, step=epoch + 1) + accelerator.log(logs) accelerator.wait_for_everyone() From 695f38962ce279adfee3fabb3479b84b1076b4e8 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 3 Jan 2025 15:25:12 -0500 Subject: [PATCH 294/348] Move get_huber_threshold_if_needed --- library/train_util.py | 44 ++++++++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 0f16a4f31..0907a8c03 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5905,27 +5905,6 @@ def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: tor return timesteps -def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]: - if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"): - return None - - b_size = timesteps.shape[0] - if args.huber_schedule == "exponential": - alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps - result = torch.exp(-alpha * timesteps) * args.huber_scale - elif args.huber_schedule == "snr": - if not hasattr(noise_scheduler, "alphas_cumprod"): - raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") - alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) - sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 - result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c - result = result.to(timesteps.device) - elif args.huber_schedule == "constant": - result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device) - else: - raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") - - return result def modify_noise(args, noise: torch.Tensor, latents: torch.Tensor) -> torch.FloatTensor: @@ -6004,6 +5983,29 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch. return noise, noisy_latents, timesteps +def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]: + if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"): + return None + + b_size = timesteps.shape[0] + if args.huber_schedule == "exponential": + alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps + result = torch.exp(-alpha * timesteps) * args.huber_scale + elif args.huber_schedule == "snr": + if not hasattr(noise_scheduler, "alphas_cumprod"): + raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") + alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) + sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 + result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c + result = result.to(timesteps.device) + elif args.huber_schedule == "constant": + result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device) + else: + raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") + + return result + + def get_noisy_latents(args, noise: torch.FloatTensor, noise_scheduler: DDPMScheduler, latents: torch.FloatTensor, timesteps: torch.IntTensor) -> torch.FloatTensor: """ Add noise to the latents according to the noise magnitude at each timestep From 1f9ba40b8b70fd08e6b87a70727d5e789666a925 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 3 Jan 2025 15:32:07 -0500 Subject: [PATCH 295/348] Add step break for validation epoch. Remove unused variable --- train_network.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index f870734fd..ce34f26d3 100644 --- a/train_network.py +++ b/train_network.py @@ -1439,6 +1439,9 @@ def remove_model(old_ckpt_name): ) for val_step, batch in enumerate(val_dataloader): + if val_step >= validation_steps: + break + loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) current_loss = loss.detach().item() @@ -1447,7 +1450,6 @@ def remove_model(old_ckpt_name): val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average }) if is_tracking: - avr_loss: float = val_loss_recorder.moving_average logs = {"loss/validation_current": current_loss} # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) accelerator.log(logs) From 1c0ae306e551ede5bd162819debb4d80a7fe620b Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 3 Jan 2025 15:43:02 -0500 Subject: [PATCH 296/348] Add missing functions for training batch --- train_network.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index ce34f26d3..377ddf48e 100644 --- a/train_network.py +++ b/train_network.py @@ -318,7 +318,7 @@ def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, # endregion def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True, timesteps_list: Optional[List[Number]]=None) -> torch.Tensor: - + with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device)) @@ -1333,6 +1333,11 @@ def remove_model(old_ckpt_name): continue with accelerator.accumulate(training_model): + on_step_start_for_network(text_encoder, unet) + + # temporary, for batch processing + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet) accelerator.backward(loss) if accelerator.sync_gradients: From bbf6bbd5ea27231066cec98b8bf2a65f162cb18f Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 6 Jan 2025 10:48:38 -0500 Subject: [PATCH 297/348] Use self.get_noise_pred_and_target and drop fixed timesteps --- flux_train_network.py | 7 ++- sd3_train_network.py | 3 +- train_network.py | 116 ++++++++++++------------------------------ 3 files changed, 40 insertions(+), 86 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index 75e975bae..b3aebecc7 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -339,6 +339,7 @@ def get_noise_pred_and_target( network, weight_dtype, train_unet, + is_train=True ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) @@ -375,7 +376,7 @@ def get_noise_pred_and_target( def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): # if not args.split_mode: # normal forward - with accelerator.autocast(): + with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = unet( img=img, @@ -420,7 +421,9 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t intermediate_txt.requires_grad_(True) vec.requires_grad_(True) pe.requires_grad_(True) - model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) + + with torch.set_grad_enabled(is_train and train_unet): + model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) """ return model_pred diff --git a/sd3_train_network.py b/sd3_train_network.py index fb7711bda..c7417802d 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -312,6 +312,7 @@ def get_noise_pred_and_target( network, weight_dtype, train_unet, + is_train=True ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) @@ -339,7 +340,7 @@ def get_noise_pred_and_target( t5_attn_mask = None # call model - with accelerator.autocast(): + with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): # TODO support attention mask model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled) diff --git a/train_network.py b/train_network.py index 377ddf48e..61e6369ae 100644 --- a/train_network.py +++ b/train_network.py @@ -223,6 +223,7 @@ def get_noise_pred_and_target( network, weight_dtype, train_unet, + is_train=True ): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified @@ -236,7 +237,7 @@ def get_noise_pred_and_target( t.requires_grad_(True) # Predict the noise residual - with accelerator.autocast(): + with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): noise_pred = self.call_unet( args, accelerator, @@ -317,7 +318,7 @@ def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, # endregion - def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True, timesteps_list: Optional[List[Number]]=None) -> torch.Tensor: + def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True) -> torch.Tensor: with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: @@ -372,91 +373,40 @@ def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: Au batch_size = latents.shape[0] - # Sample noise, - noise = train_util.make_noise(args, latents) - def pick_timesteps_list() -> torch.IntTensor: - if timesteps_list is None or timesteps_list == []: - return typing.cast(torch.IntTensor, train_util.make_random_timesteps(args, noise_scheduler, batch_size, latents.device).unsqueeze(1)) - else: - return typing.cast(torch.IntTensor, torch.tensor(timesteps_list).unsqueeze(1).repeat(1, batch_size).to(latents.device)) - - chosen_timesteps_list = pick_timesteps_list() - total_loss = torch.zeros((batch_size, 1)).to(latents.device) - - # Use input timesteps_list or use described timesteps above - for fixed_timesteps in chosen_timesteps_list: - fixed_timesteps = typing.cast(torch.IntTensor, fixed_timesteps) - - # Predict the noise residual - # and add noise to the latents - # with noise offset and/or multires noise if specified - noisy_latents = train_util.get_noisy_latents(args, noise, noise_scheduler, latents, fixed_timesteps) - - # ensure the hidden state will require grad - if args.gradient_checkpointing: - for x in noisy_latents: - x.requires_grad_(True) - for t in text_encoder_conds: - t.requires_grad_(True) - - with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): - noise_pred = self.call_unet( - args, - accelerator, - unet, - noisy_latents.requires_grad_(train_unet), - fixed_timesteps, - text_encoder_conds, - batch, - weight_dtype, - ) - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, fixed_timesteps) - else: - target = noise - - # differential output preservation - if "custom_attributes" in batch: - diff_output_pr_indices = [] - for i, custom_attributes in enumerate(batch["custom_attributes"]): - if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: - diff_output_pr_indices.append(i) - - if len(diff_output_pr_indices) > 0: - network.set_multiplier(0.0) - with torch.no_grad(), accelerator.autocast(): - noise_pred_prior = self.call_unet( - args, - accelerator, - unet, - noisy_latents, - fixed_timesteps, - text_encoder_conds, - batch, - weight_dtype, - indices=diff_output_pr_indices, - ) - network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step - target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) - - huber_c = train_util.get_huber_threshold_if_needed(args, fixed_timesteps, noise_scheduler) - loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) - loss = loss.mean([1, 2, 3]) # 平均なのでbatch_sizeで割る必要なし + # Predict the noise residual + # and add noise to the latents + # with noise offset and/or multires noise if specified - if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): - loss = apply_masked_loss(loss, batch) + # sample noise, call unet, get target + noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target( + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet, + network, + weight_dtype, + train_unet, + is_train=is_train + ) - loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight - loss = loss * loss_weights + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) - loss = self.post_process_loss(loss, args, fixed_timesteps, noise_scheduler) + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights - total_loss += loss + loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) - return total_loss / len(chosen_timesteps_list) + return loss.mean() def train(self, args): session_id = random.randint(0, 2**32) @@ -1416,7 +1366,7 @@ def remove_model(old_ckpt_name): if val_step >= validation_steps: break - loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) + loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False) val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item()) val_progress_bar.update(1) @@ -1447,7 +1397,7 @@ def remove_model(old_ckpt_name): if val_step >= validation_steps: break - loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) + loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False) current_loss = loss.detach().item() val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) From f4840ef29ef67878d7c7ccec92bdce89c3b61c6d Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 6 Jan 2025 10:52:07 -0500 Subject: [PATCH 298/348] Revert train_db.py --- train_db.py | 121 ++-------------------------------------------------- 1 file changed, 3 insertions(+), 118 deletions(-) diff --git a/train_db.py b/train_db.py index 398489ffe..ad21f8d1b 100644 --- a/train_db.py +++ b/train_db.py @@ -2,6 +2,7 @@ # XXX dropped option: fine_tune import argparse +import itertools import math import os from multiprocessing import Value @@ -41,73 +42,11 @@ setup_logging() import logging -import itertools logger = logging.getLogger(__name__) # perlin_noise, -def process_val_batch(*training_models, batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args): - total_loss = 0.0 - timesteps_list = [10, 350, 500, 650, 990] - - with accelerator.accumulate(*training_models): - with torch.no_grad(): - # latentに変換 - if cache_latents: - latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) - else: - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] - - with torch.set_grad_enabled(False), accelerator.autocast(): - if args.weighted_captions: - encoder_hidden_states = get_weighted_text_embeddings( - tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) - else: - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype - ) - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - - for fixed_timesteps in timesteps_list: - with torch.set_grad_enabled(False), accelerator.autocast(): - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - noise = torch.randn_like(latents, device=latents.device) - b_size = latents.shape[0] - timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - # Predict the noise residual - with accelerator.autocast(): - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - if args.masked_loss: - loss = apply_masked_loss(loss, batch) - loss = loss.mean([1, 2, 3]) - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - total_loss += loss - - average_loss = total_loss / len(timesteps_list) - return average_loss def train(args): train_util.verify_training_args(args) @@ -150,10 +89,9 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: train_dataset_group = train_util.load_arbitrary_dataset(args) - val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -274,15 +212,6 @@ def train(args): num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) - val_dataloader = torch.utils.data.DataLoader( - val_dataset_group if val_dataset_group is not None else [], - shuffle=False, - batch_size=1, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - cyclic_val_dataloader = itertools.cycle(val_dataloader) # 学習ステップ数を計算する if args.max_train_epochs is not None: @@ -393,8 +322,6 @@ def train(args): accelerator.log({}, step=0) loss_recorder = train_util.LossRecorder() - val_loss_recorder = train_util.LossRecorder() - for epoch in range(num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 @@ -525,25 +452,6 @@ def train(args): avr_loss: float = loss_recorder.moving_average logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - if len(val_dataloader) > 0: - if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: - accelerator.print("Validating バリデーション処理...") - total_loss = 0.0 - with torch.no_grad(): - validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - total_loss += loss.detach().item() - current_loss = total_loss / validation_steps - val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) - - if args.logging_dir is not None: - logs = {"loss/current_val_loss": current_loss} - accelerator.log(logs, step=global_step) - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/average_val_loss": avr_loss} - accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break @@ -634,30 +542,7 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) - parser.add_argument( - "--validation_seed", - type=int, - default=None, - help="Validation seed" - ) - parser.add_argument( - "--validation_split", - type=float, - default=0.0, - help="Split for validation images out of the training dataset" - ) - parser.add_argument( - "--validation_every_n_step", - type=int, - default=None, - help="Number of train steps for counting validation loss. By default, validation per train epoch is performed" - ) - parser.add_argument( - "--max_validation_steps", - type=int, - default=None, - help="Number of max validation steps for counting validation loss. By default, validation will run entire validation dataset" - ) + return parser From 1c63e7cc4979b528417b5bfe181e0a9ac119209c Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 6 Jan 2025 11:07:47 -0500 Subject: [PATCH 299/348] Cleanup unused code and formatting --- train_network.py | 85 +++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 70 insertions(+), 15 deletions(-) diff --git a/train_network.py b/train_network.py index 61e6369ae..5a80d825d 100644 --- a/train_network.py +++ b/train_network.py @@ -318,8 +318,27 @@ def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, # endregion - def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True) -> torch.Tensor: - + def process_batch( + self, + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, + tokenize_strategy: strategy_sd.SdTokenizeStrategy, + is_train=True, + train_text_encoder=True, + train_unet=True + ) -> torch.Tensor: + """ + Process a batch for the network + """ with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device)) @@ -334,7 +353,6 @@ def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: Au latents = self.shift_scale_latents(args, latents) - text_encoder_conds = [] text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: @@ -371,13 +389,6 @@ def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: Au if encoded_text_encoder_conds[i] is not None: text_encoder_conds[i] = encoded_text_encoder_conds[i] - batch_size = latents.shape[0] - - - # Predict the noise residual - # and add noise to the latents - # with noise offset and/or multires noise if specified - # sample noise, call unet, get target noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target( args, @@ -1288,7 +1299,23 @@ def remove_model(old_ckpt_name): # temporary, for batch processing self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) - loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet) + loss = self.process_batch(batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=True, + train_text_encoder=train_text_encoder, + train_unet=train_unet + ) + accelerator.backward(loss) if accelerator.sync_gradients: self.all_reduce_network(accelerator, network) # sync DDP grad manually @@ -1366,12 +1393,26 @@ def remove_model(old_ckpt_name): if val_step >= validation_steps: break - loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False) - + loss = self.process_batch( + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=False + ) + val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item()) val_progress_bar.update(1) val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average }) - + if is_tracking: logs = {"loss/current_val_loss": loss.detach().item()} # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) @@ -1397,7 +1438,21 @@ def remove_model(old_ckpt_name): if val_step >= validation_steps: break - loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False) + loss = self.process_batch( + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=False + ) current_loss = loss.detach().item() val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) From c64d1a22fc4ff25625873e50d63d480b297301c6 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 6 Jan 2025 11:30:21 -0500 Subject: [PATCH 300/348] Add validate_every_n_epochs, change name validate_every_n_steps --- train_network.py | 69 ++++++++++++++++++++++++++++++------------------ 1 file changed, 44 insertions(+), 25 deletions(-) diff --git a/train_network.py b/train_network.py index 5a80d825d..f3c8d8c96 100644 --- a/train_network.py +++ b/train_network.py @@ -1199,7 +1199,8 @@ def load_model_hook(models, input_dir): ) loss_recorder = train_util.LossRecorder() - val_loss_recorder = train_util.LossRecorder() + val_step_loss_recorder = train_util.LossRecorder() + val_epoch_loss_recorder = train_util.LossRecorder() del train_dataset_group if val_dataset_group is not None: @@ -1299,7 +1300,8 @@ def remove_model(old_ckpt_name): # temporary, for batch processing self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) - loss = self.process_batch(batch, + loss = self.process_batch( + batch, text_encoders, unet, network, @@ -1373,15 +1375,25 @@ def remove_model(old_ckpt_name): if is_tracking: logs = self.generate_step_logs( - args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm + args, + current_loss, + avr_loss, + lr_scheduler, + lr_descriptions, + optimizer, + keys_scaled, + mean_norm, + maximum_norm ) # accelerator.log(logs, step=global_step) accelerator.log(logs) # VALIDATION PER STEP - should_validate = (args.validation_every_n_step is not None - and global_step % args.validation_every_n_step == 0) - if validation_steps > 0 and should_validate: + should_validate_epoch = ( + args.validate_every_n_steps is not None + and global_step % args.validate_every_n_steps == 0 + ) + if validation_steps > 0 and should_validate_epoch: accelerator.print("Validating バリデーション処理...") val_progress_bar = tqdm( @@ -1409,16 +1421,17 @@ def remove_model(old_ckpt_name): is_train=False ) - val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item()) + current_loss = loss.detach().item() + val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) val_progress_bar.update(1) - val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average }) + val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average }) if is_tracking: - logs = {"loss/current_val_loss": loss.detach().item()} + logs = {"loss/step_validation_current": current_loss} # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) accelerator.log(logs) - logs = {"loss/average_val_loss": val_loss_recorder.moving_average} + logs = {"loss/step_validation_average": val_step_loss_recorder.moving_average} # accelerator.log(logs, step=global_step) accelerator.log(logs) @@ -1426,12 +1439,18 @@ def remove_model(old_ckpt_name): break # VALIDATION EPOCH - if len(val_dataloader) > 0: + should_validate_epoch = ( + (epoch + 1) % args.validate_every_n_epochs == 0 + if args.validate_every_n_epochs is not None + else False + ) + + if should_validate_epoch and len(val_dataloader) > 0: accelerator.print("Validating バリデーション処理...") val_progress_bar = tqdm( range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, - desc="validation steps" + desc="epoch validation steps" ) for val_step, batch in enumerate(val_dataloader): @@ -1455,18 +1474,18 @@ def remove_model(old_ckpt_name): ) current_loss = loss.detach().item() - val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) + val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) val_progress_bar.update(1) - val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average }) + val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average }) if is_tracking: - logs = {"loss/validation_current": current_loss} + logs = {"loss/epoch_validation_current": current_loss} # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) accelerator.log(logs) if is_tracking: - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/validation_average": avr_loss} + avr_loss: float = val_epoch_loss_recorder.moving_average + logs = {"loss/epoch_validation_average": avr_loss} # accelerator.log(logs, step=epoch + 1) accelerator.log(logs) @@ -1475,12 +1494,6 @@ def remove_model(old_ckpt_name): logs = {"loss/epoch_average": loss_recorder.moving_average} # accelerator.log(logs, step=epoch + 1) accelerator.log(logs) - - if len(val_dataloader) > 0 and is_tracking: - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/validation_epoch_average": avr_loss} - # accelerator.log(logs, step=epoch + 1) - accelerator.log(logs) accelerator.wait_for_everyone() @@ -1676,10 +1689,16 @@ def setup_parser() -> argparse.ArgumentParser: help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合" ) parser.add_argument( - "--validation_every_n_step", + "--validate_every_n_steps", + type=int, + default=None, + help="Run validation dataset every N steps" + ) + parser.add_argument( + "--validate_every_n_epochs", type=int, default=None, - help="Number of train steps for counting validation loss. By default, validation per train epoch is performed / 学習エポックごとに検証を行う場合はNoneを指定する" + help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available" ) parser.add_argument( "--max_validation_steps", From f8850296c83ef2091bf1cb0f6e9ba462adfd9045 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 6 Jan 2025 11:34:10 -0500 Subject: [PATCH 301/348] Fix validate epoch, cleanup imports --- train_network.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/train_network.py b/train_network.py index f3c8d8c96..11bba71e8 100644 --- a/train_network.py +++ b/train_network.py @@ -3,15 +3,13 @@ import math import os import typing -from typing import List, Optional, Union +from typing import Any, List import sys import random import time import json from multiprocessing import Value -from typing import Any, List import toml -import itertools from tqdm import tqdm @@ -23,8 +21,8 @@ from accelerate.utils import set_seed from accelerate import Accelerator -from diffusers import DDPMScheduler, AutoencoderKL -from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers import DDPMScheduler +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from library import deepspeed_utils, model_util, strategy_base, strategy_sd import library.train_util as train_util @@ -49,7 +47,6 @@ setup_logging() import logging -import itertools logger = logging.getLogger(__name__) @@ -1442,7 +1439,7 @@ def remove_model(old_ckpt_name): should_validate_epoch = ( (epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None - else False + else True ) if should_validate_epoch and len(val_dataloader) > 0: From fcb2ff010cf2e42c50b3745a17317f2d4b4319d9 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 6 Jan 2025 11:39:32 -0500 Subject: [PATCH 302/348] Clean up some validation help documentation --- train_network.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/train_network.py b/train_network.py index 11bba71e8..af180c455 100644 --- a/train_network.py +++ b/train_network.py @@ -1677,7 +1677,7 @@ def setup_parser() -> argparse.ArgumentParser: "--validation_seed", type=int, default=None, - help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証シード" + help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する" ) parser.add_argument( "--validation_split", @@ -1689,19 +1689,19 @@ def setup_parser() -> argparse.ArgumentParser: "--validate_every_n_steps", type=int, default=None, - help="Run validation dataset every N steps" + help="Run validation on validation dataset every N steps if a validation dataset is available / 検証データセットが利用可能な場合は、Nステップごとに検証データセットの検証を実行します" ) parser.add_argument( "--validate_every_n_epochs", type=int, default=None, - help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available" + help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます" ) parser.add_argument( "--max_validation_steps", type=int, default=None, - help="Number of max validation steps for counting validation loss. By default, validation will run entire validation dataset / 検証データセット全体を検証する場合はNoneを指定する" + help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します" ) return parser From 742bee9738e9d190a39f5a36adf4515fa415e9b7 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 6 Jan 2025 17:34:23 -0500 Subject: [PATCH 303/348] Set validation steps in multiple lines for readability --- train_network.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/train_network.py b/train_network.py index af180c455..d0596fcae 100644 --- a/train_network.py +++ b/train_network.py @@ -1251,7 +1251,11 @@ def remove_model(old_ckpt_name): # log empty object to commit the sample images to wandb accelerator.log({}, step=0) - validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + validation_steps = ( + min(args.max_validation_steps, len(val_dataloader)) + if args.max_validation_steps is not None + else len(val_dataloader) + ) # training loop if initial_step > 0: # only if skip_until_initial_step is specified @@ -1689,7 +1693,7 @@ def setup_parser() -> argparse.ArgumentParser: "--validate_every_n_steps", type=int, default=None, - help="Run validation on validation dataset every N steps if a validation dataset is available / 検証データセットが利用可能な場合は、Nステップごとに検証データセットの検証を実行します" + help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます" ) parser.add_argument( "--validate_every_n_epochs", From 1231f5114ccd6a0a26a53da82b89083299ccc333 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 7 Jan 2025 22:31:41 -0500 Subject: [PATCH 304/348] Remove unused train_util code, fix accelerate.log for wandb, add init_trackers library code --- library/train_util.py | 70 ++++++++++++++++--------------------------- train_network.py | 66 ++++++++++++++++++++-------------------- 2 files changed, 59 insertions(+), 77 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 0907a8c03..b8894752e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5900,51 +5900,9 @@ def save_sd_model_on_train_end_common( huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) -def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device = torch.device("cpu")) -> torch.IntTensor: +def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device = torch.device("cpu")) -> torch.Tensor: timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device) - return timesteps - - - - -def modify_noise(args, noise: torch.Tensor, latents: torch.Tensor) -> torch.FloatTensor: - """ - Apply noise modifications like noise offset and multires noise - """ - if args.noise_offset: - if args.noise_offset_random_strength: - noise_offset = torch.rand(1, device=latents.device) * args.noise_offset - else: - noise_offset = args.noise_offset - noise = custom_train_functions.apply_noise_offset(latents, noise, noise_offset, args.adaptive_noise_scale) - if args.multires_noise_iterations: - noise = custom_train_functions.pyramid_noise_like( - noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount - ) - return noise - - -def make_noise(args, latents: torch.Tensor) -> torch.FloatTensor: - """ - Make a noise tensor to denoise and apply noise modifications (noise offset, multires noise). See `modify_noise` - """ - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents, device=latents.device) - noise = modify_noise(args, noise, latents) - - return typing.cast(torch.FloatTensor, noise) - - -def make_random_timesteps(args, noise_scheduler: DDPMScheduler, batch_size: int, device: torch.device) -> torch.IntTensor: - """ - From args, produce random timesteps for each image in the batch - """ - min_timestep = 0 if args.min_timestep is None else args.min_timestep - max_timestep = noise_scheduler.config.get('num_train_timesteps', 1000) if args.max_timestep is None else args.max_timestep - - # Sample a random timestep for each image - timesteps = get_timesteps(min_timestep, max_timestep, batch_size, device) - + timesteps = timesteps.long().to(device) return timesteps @@ -6457,6 +6415,30 @@ def sample_image_inference( wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption +def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str): + """ + Initialize experiment trackers with tracker specific behaviors + """ + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + default_tracker_name if args.log_tracker_name is None else args.log_tracker_name, + config=get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + if "wandb" in [tracker.name for tracker in accelerator.trackers]: + import wandb + wandb_tracker = accelerator.get_tracker("wandb", unwrap=True) + + # Define specific metrics to handle validation and epochs "steps" + wandb_tracker.define_metric("epoch", hidden=True) + wandb_tracker.define_metric("val_step", hidden=True) + # endregion diff --git a/train_network.py b/train_network.py index d0596fcae..199f589b0 100644 --- a/train_network.py +++ b/train_network.py @@ -327,8 +327,8 @@ def process_batch( weight_dtype, accelerator, args, - text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, - tokenize_strategy: strategy_sd.SdTokenizeStrategy, + text_encoding_strategy: strategy_base.TextEncodingStrategy, + tokenize_strategy: strategy_base.TokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True @@ -1183,17 +1183,7 @@ def load_model_hook(models, input_dir): noise_scheduler = self.get_noise_scheduler(args, accelerator.device) - if accelerator.is_main_process: - init_kwargs = {} - if args.wandb_run_name: - init_kwargs["wandb"] = {"name": args.wandb_run_name} - if args.log_tracker_config is not None: - init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers( - "network_train" if args.log_tracker_name is None else args.log_tracker_name, - config=train_util.get_sanitized_config_or_none(args), - init_kwargs=init_kwargs, - ) + train_util.init_trackers(accelerator, args, "network_train") loss_recorder = train_util.LossRecorder() val_step_loss_recorder = train_util.LossRecorder() @@ -1386,15 +1376,14 @@ def remove_model(old_ckpt_name): mean_norm, maximum_norm ) - # accelerator.log(logs, step=global_step) - accelerator.log(logs) + accelerator.log(logs, step=global_step) # VALIDATION PER STEP - should_validate_epoch = ( + should_validate_step = ( args.validate_every_n_steps is not None and global_step % args.validate_every_n_steps == 0 ) - if validation_steps > 0 and should_validate_epoch: + if validation_steps > 0 and should_validate_step: accelerator.print("Validating バリデーション処理...") val_progress_bar = tqdm( @@ -1406,6 +1395,9 @@ def remove_model(old_ckpt_name): if val_step >= validation_steps: break + # temporary, for batch processing + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + loss = self.process_batch( batch, text_encoders, @@ -1428,18 +1420,22 @@ def remove_model(old_ckpt_name): val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average }) if is_tracking: - logs = {"loss/step_validation_current": current_loss} - # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) - accelerator.log(logs) + logs = { + "loss/validation/step/current": current_loss, + "val_step": (epoch * validation_steps) + val_step, + } + accelerator.log(logs, step=global_step) - logs = {"loss/step_validation_average": val_step_loss_recorder.moving_average} - # accelerator.log(logs, step=global_step) - accelerator.log(logs) + if is_tracking: + logs = { + "loss/validation/step/average": val_step_loss_recorder.moving_average, + } + accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break - # VALIDATION EPOCH + # EPOCH VALIDATION should_validate_epoch = ( (epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None @@ -1458,6 +1454,9 @@ def remove_model(old_ckpt_name): if val_step >= validation_steps: break + # temporary, for batch processing + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + loss = self.process_batch( batch, text_encoders, @@ -1480,21 +1479,22 @@ def remove_model(old_ckpt_name): val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average }) if is_tracking: - logs = {"loss/epoch_validation_current": current_loss} - # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) - accelerator.log(logs) + logs = { + "loss/validation/epoch_current": current_loss, + "epoch": epoch + 1, + "val_step": (epoch * validation_steps) + val_step + } + accelerator.log(logs, step=global_step) if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average - logs = {"loss/epoch_validation_average": avr_loss} - # accelerator.log(logs, step=epoch + 1) - accelerator.log(logs) + logs = {"loss/validation/epoch_average": avr_loss, "epoch": epoch + 1} + accelerator.log(logs, step=global_step) # END OF EPOCH if is_tracking: - logs = {"loss/epoch_average": loss_recorder.moving_average} - # accelerator.log(logs, step=epoch + 1) - accelerator.log(logs) + logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1} + accelerator.log(logs, step=global_step) accelerator.wait_for_everyone() From 556f3f1696eadcc16ee77425243b732a84c7a2aa Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 8 Jan 2025 13:41:15 -0500 Subject: [PATCH 305/348] Fix documentation, remove unused function, fix bucket reso for sd1.5, fix multiple datasets --- library/config_util.py | 6 +++--- library/train_util.py | 25 ++++--------------------- train_network.py | 5 +---- 3 files changed, 8 insertions(+), 28 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 63d28c969..de1e154a1 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -481,9 +481,9 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu subset_klass = FineTuningSubset dataset_klass = FineTuningDataset - subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, is_training_dataset=True, **asdict(dataset_blueprint.params)) - datasets.append(dataset) + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + dataset = dataset_klass(subsets=subsets, is_training_dataset=True, **asdict(dataset_blueprint.params)) + datasets.append(dataset) val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] for dataset_blueprint in dataset_group_blueprint.datasets: diff --git a/library/train_util.py b/library/train_util.py index b8894752e..62aae37ef 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -152,11 +152,11 @@ def split_train_val(paths: List[str], is_training_dataset: bool, validation_spli Shuffle the dataset based on the validation_seed or the current random seed. For example if the split of 0.2 of 100 images. - [0:79] = 80 training images + [0:80] = 80 training images [80:] = 20 validation images """ if validation_seed is not None: - print(f"Using validation seed: {validation_seed}") + logging.info(f"Using validation seed: {validation_seed}") prevstate = random.getstate() random.seed(validation_seed) random.shuffle(paths) @@ -5900,8 +5900,8 @@ def save_sd_model_on_train_end_common( huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) -def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device = torch.device("cpu")) -> torch.Tensor: - timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device) +def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor: + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") timesteps = timesteps.long().to(device) return timesteps @@ -5964,23 +5964,6 @@ def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler return result -def get_noisy_latents(args, noise: torch.FloatTensor, noise_scheduler: DDPMScheduler, latents: torch.FloatTensor, timesteps: torch.IntTensor) -> torch.FloatTensor: - """ - Add noise to the latents according to the noise magnitude at each timestep - (this is the forward diffusion process) - """ - if args.ip_noise_gamma: - if args.ip_noise_gamma_random_strength: - strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma - else: - strength = args.ip_noise_gamma - noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps) - else: - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - return noisy_latents - - def conditional_loss( model_pred: torch.Tensor, target: torch.Tensor, loss_type: str, reduction: str, huber_c: Optional[torch.Tensor] = None ): diff --git a/train_network.py b/train_network.py index 199f589b0..7dbd12e88 100644 --- a/train_network.py +++ b/train_network.py @@ -125,10 +125,7 @@ def generate_step_logs( return logs def assert_extra_args(self, args, train_dataset_group): - # train_dataset_group.verify_bucket_reso_steps(64) - # TODO: Number of bucket reso steps may differ for each model, so a static number won't work - # and prevents models like SD1.5 with 64 - pass + train_dataset_group.verify_bucket_reso_steps(32) def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) From 9fde0d797282c0cb9fcea01682e2e6e2eece47bc Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 8 Jan 2025 18:38:20 -0500 Subject: [PATCH 306/348] Handle tuple return from generate_dataset_group_by_blueprint --- fine_tune.py | 4 ++-- flux_train.py | 3 ++- flux_train_control_net.py | 4 ++-- library/config_util.py | 2 +- sd3_train.py | 3 ++- sdxl_train.py | 3 ++- sdxl_train_control_net.py | 2 +- sdxl_train_control_net_lllite.py | 2 +- sdxl_train_control_net_lllite_old.py | 2 +- tools/cache_latents.py | 3 ++- tools/cache_text_encoder_outputs.py | 3 ++- train_control_net.py | 2 +- train_db.py | 3 ++- train_textual_inversion.py | 3 ++- train_textual_inversion_XTI.py | 2 +- 15 files changed, 24 insertions(+), 17 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 176087065..6be2f98ca 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -91,9 +91,9 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args) + train_dataset_group, val_dataset_group = train_util.load_arbitrary_dataset(args) current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/flux_train.py b/flux_train.py index fced3bef9..6f98adea8 100644 --- a/flux_train.py +++ b/flux_train.py @@ -138,9 +138,10 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 9d36a41d3..54dec2a77 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -126,9 +126,9 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args) + train_dataset_group, val_dataset_group = train_util.load_arbitrary_dataset(args) current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/library/config_util.py b/library/config_util.py index de1e154a1..834d6bfaf 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -467,7 +467,7 @@ def search_value(key: str, fallbacks: Sequence[dict], default_value=None): return default_value -def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint): +def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint) -> Tuple[DatasetGroup, Optional[DatasetGroup]]: datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] for dataset_blueprint in dataset_group_blueprint.datasets: diff --git a/sd3_train.py b/sd3_train.py index 120455e7b..3bff6a50f 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -149,9 +149,10 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/sdxl_train.py b/sdxl_train.py index b9d529243..a60f6df63 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -176,9 +176,10 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index ffbf03cab..c6e8136f7 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -114,7 +114,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 365059b75..00e51a673 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -123,7 +123,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 5b372befc..63457cc61 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -103,7 +103,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/tools/cache_latents.py b/tools/cache_latents.py index c034f949a..515ece98d 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -116,10 +116,11 @@ def cache_to_disk(args: argparse.Namespace) -> None: } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary dataset class train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None # acceleratorを準備する logger.info("prepare accelerator") diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py index 5888b8e3d..00459658e 100644 --- a/tools/cache_text_encoder_outputs.py +++ b/tools/cache_text_encoder_outputs.py @@ -103,10 +103,11 @@ def cache_to_disk(args: argparse.Namespace) -> None: } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary dataset class train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None # acceleratorを準備する logger.info("prepare accelerator") diff --git a/train_control_net.py b/train_control_net.py index 177d2b11f..ba016ac5d 100644 --- a/train_control_net.py +++ b/train_control_net.py @@ -100,7 +100,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/train_db.py b/train_db.py index ad21f8d1b..edd674034 100644 --- a/train_db.py +++ b/train_db.py @@ -89,9 +89,10 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 65da4859b..113f35997 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -320,9 +320,10 @@ def train(self, args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None self.assert_extra_args(args, train_dataset_group) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 2a2b42310..6ff97d03f 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -239,7 +239,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings) current_epoch = Value("i", 0) current_step = Value("i", 0) From 1e61392cf2f601e1c66aaede6846ef70f599c34f Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 8 Jan 2025 18:43:26 -0500 Subject: [PATCH 307/348] Revert bucket_reso_steps to correct 64 --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 7dbd12e88..7e9f12659 100644 --- a/train_network.py +++ b/train_network.py @@ -125,7 +125,7 @@ def generate_step_logs( return logs def assert_extra_args(self, args, train_dataset_group): - train_dataset_group.verify_bucket_reso_steps(32) + train_dataset_group.verify_bucket_reso_steps(64) def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) From d6f158ddf6a3631df7db10ac97453b12de8eadbe Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 8 Jan 2025 18:48:05 -0500 Subject: [PATCH 308/348] Fix incorrect destructoring for load_abritrary_dataset --- fine_tune.py | 3 ++- flux_train_control_net.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 6be2f98ca..e1ed47496 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -93,7 +93,8 @@ def train(args): blueprint = blueprint_generator.generate(user_config, args) train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group, val_dataset_group = train_util.load_arbitrary_dataset(args) + train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 54dec2a77..cecd00019 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -128,7 +128,8 @@ def train(args): blueprint = blueprint_generator.generate(user_config, args) train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group, val_dataset_group = train_util.load_arbitrary_dataset(args) + train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) From 264167fa1636c79f106c63c3cdb67b6bee80aceb Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 9 Jan 2025 12:43:58 -0500 Subject: [PATCH 309/348] Apply is_training_dataset only to DreamBoothDataset. Add validation_split check and warning --- library/config_util.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 834d6bfaf..a2e07dc6c 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -471,36 +471,49 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] for dataset_blueprint in dataset_group_blueprint.datasets: + extra_dataset_params = {} + if dataset_blueprint.is_controlnet: subset_klass = ControlNetSubset dataset_klass = ControlNetDataset elif dataset_blueprint.is_dreambooth: subset_klass = DreamBoothSubset dataset_klass = DreamBoothDataset + # DreamBooth datasets support splitting training and validation datasets + extra_dataset_params = {"is_training_dataset": True} else: subset_klass = FineTuningSubset dataset_klass = FineTuningDataset subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, is_training_dataset=True, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params), **extra_dataset_params) datasets.append(dataset) val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] for dataset_blueprint in dataset_group_blueprint.datasets: - if dataset_blueprint.params.validation_split <= 0.0: + if dataset_blueprint.params.validation_split < 0.0 or dataset_blueprint.params.validation_split > 1.0: + logging.warning(f"Dataset param `validation_split` ({dataset_blueprint.params.validation_split}) is not a valid number between 0.0 and 1.0, skipping validation split...") + continue + + # if the dataset isn't setting a validation split, there is no current validation dataset + if dataset_blueprint.params.validation_split == 0.0: continue + + extra_dataset_params = {} if dataset_blueprint.is_controlnet: subset_klass = ControlNetSubset dataset_klass = ControlNetDataset elif dataset_blueprint.is_dreambooth: subset_klass = DreamBoothSubset dataset_klass = DreamBoothDataset + # DreamBooth datasets support splitting training and validation datasets + extra_dataset_params = {"is_training_dataset": False} else: subset_klass = FineTuningSubset dataset_klass = FineTuningDataset subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, is_training_dataset=False, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params), **extra_dataset_params) val_datasets.append(dataset) def print_info(_datasets, dataset_type: str): From 4c61adc9965df6861ae3705c96143f4299074744 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 12 Jan 2025 13:18:26 -0500 Subject: [PATCH 310/348] Add divergence to logs Divergence is the difference between training and validation to allow a clear value to indicate the difference between the two in the logs. --- train_network.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/train_network.py b/train_network.py index 7e9f12659..5ed92b7e2 100644 --- a/train_network.py +++ b/train_network.py @@ -1418,14 +1418,16 @@ def remove_model(old_ckpt_name): if is_tracking: logs = { - "loss/validation/step/current": current_loss, + "loss/validation/step_current": current_loss, "val_step": (epoch * validation_steps) + val_step, } accelerator.log(logs, step=global_step) if is_tracking: + loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average logs = { - "loss/validation/step/average": val_step_loss_recorder.moving_average, + "loss/validation/step_average": val_step_loss_recorder.moving_average, + "loss/validation/step_divergence": loss_validation_divergence, } accelerator.log(logs, step=global_step) @@ -1485,7 +1487,12 @@ def remove_model(old_ckpt_name): if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average - logs = {"loss/validation/epoch_average": avr_loss, "epoch": epoch + 1} + loss_validation_divergence = val_step_loss_recorder.moving_average - avr_loss + logs = { + "loss/validation/epoch_average": avr_loss, + "loss/validation/epoch_divergence": loss_validation_divergence, + "epoch": epoch + 1 + } accelerator.log(logs, step=global_step) # END OF EPOCH From 2bbb40ce51d5be3ce8c3e1990d30455201f9e852 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 12 Jan 2025 14:29:50 -0500 Subject: [PATCH 311/348] Fix regularization images with validation Adding metadata recording for validation arguments Add comments about the validation split for clarity of intention --- library/train_util.py | 33 +++++++++++++++++++++++++++++++-- train_network.py | 7 +++++++ 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 62aae37ef..6d3a772bb 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -146,7 +146,12 @@ TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz" -def split_train_val(paths: List[str], is_training_dataset: bool, validation_split: float, validation_seed: int) -> List[str]: +def split_train_val( + paths: List[str], + is_training_dataset: bool, + validation_split: float, + validation_seed: int | None +) -> List[str]: """ Split the dataset into train and validation @@ -1830,6 +1835,9 @@ def get_item_for_caching(self, bucket, bucket_batch_size, image_index): class DreamBoothDataset(BaseDataset): IMAGE_INFO_CACHE_FILE = "metadata_cache.json" + # The is_training_dataset defines the type of dataset, training or validation + # if is_training_dataset is True -> training dataset + # if is_training_dataset is False -> validation dataset def __init__( self, subsets: Sequence[DreamBoothSubset], @@ -1965,8 +1973,29 @@ def load_dreambooth_dir(subset: DreamBoothSubset): size_set_count += 1 logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}") + # We want to create a training and validation split. This should be improved in the future + # to allow a clearer distinction between training and validation. This can be seen as a + # short-term solution to limit what is necessary to implement validation datasets + # + # We split the dataset for the subset based on if we are doing a validation split + # The self.is_training_dataset defines the type of dataset, training or validation + # if self.is_training_dataset is True -> training dataset + # if self.is_training_dataset is False -> validation dataset if self.validation_split > 0.0: - img_paths = split_train_val(img_paths, self.is_training_dataset, self.validation_split, self.validation_seed) + # For regularization images we do not want to split this dataset. + if subset.is_reg is True: + # Skip any validation dataset for regularization images + if self.is_training_dataset is False: + img_paths = [] + # Otherwise the img_paths remain as original img_paths and no split + # required for training images dataset of regularization images + else: + img_paths = split_train_val( + img_paths, + self.is_training_dataset, + self.validation_split, + self.validation_seed + ) logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") diff --git a/train_network.py b/train_network.py index 5ed92b7e2..605dbc60c 100644 --- a/train_network.py +++ b/train_network.py @@ -898,6 +898,7 @@ def load_model_hook(models, input_dir): accelerator.print("running training / 学習開始") accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print(f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}") accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") @@ -917,6 +918,7 @@ def load_model_hook(models, input_dir): "ss_text_encoder_lr": text_encoder_lr, "ss_unet_lr": args.unet_lr, "ss_num_train_images": train_dataset_group.num_train_images, + "ss_num_validation_images": val_dataset_group.num_train_images if val_dataset_group is not None else 0, "ss_num_reg_images": train_dataset_group.num_reg_images, "ss_num_batches_per_epoch": len(train_dataloader), "ss_num_epochs": num_train_epochs, @@ -964,6 +966,11 @@ def load_model_hook(models, input_dir): "ss_huber_c": args.huber_c, "ss_fp8_base": bool(args.fp8_base), "ss_fp8_base_unet": bool(args.fp8_base_unet), + "ss_validation_seed": args.validation_seed, + "ss_validation_split": args.validation_split, + "ss_max_validation_steps": args.max_validation_steps, + "ss_validate_every_n_epochs": args.validate_every_n_epochs, + "ss_validate_every_n_steps": args.validate_every_n_steps, } self.update_metadata(metadata, args) # architecture specific metadata From 0456858992909ca0b821ec1b2ca40fa633113224 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 12 Jan 2025 14:47:49 -0500 Subject: [PATCH 312/348] Fix validate_every_n_steps always running first step --- train_network.py | 1 + 1 file changed, 1 insertion(+) diff --git a/train_network.py b/train_network.py index 605dbc60c..75e36dca9 100644 --- a/train_network.py +++ b/train_network.py @@ -1385,6 +1385,7 @@ def remove_model(old_ckpt_name): # VALIDATION PER STEP should_validate_step = ( args.validate_every_n_steps is not None + and global_step != 0 # Skip first step and global_step % args.validate_every_n_steps == 0 ) if validation_steps > 0 and should_validate_step: From ee9265cf2678df5c9dfa6c1148d20fb738a9e6ce Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 12 Jan 2025 14:56:35 -0500 Subject: [PATCH 313/348] Fix validate_every_n_steps for gradient accumulation --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 75e36dca9..2f3203c94 100644 --- a/train_network.py +++ b/train_network.py @@ -1388,7 +1388,7 @@ def remove_model(old_ckpt_name): and global_step != 0 # Skip first step and global_step % args.validate_every_n_steps == 0 ) - if validation_steps > 0 and should_validate_step: + if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: accelerator.print("Validating バリデーション処理...") val_progress_bar = tqdm( From 25929dd0d733144859008479c374968102e5d3a3 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 12 Jan 2025 15:38:57 -0500 Subject: [PATCH 314/348] Remove Validating... print to fix output layout --- train_network.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/train_network.py b/train_network.py index 2f3203c94..e7d93a108 100644 --- a/train_network.py +++ b/train_network.py @@ -1389,8 +1389,6 @@ def remove_model(old_ckpt_name): and global_step % args.validate_every_n_steps == 0 ) if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: - accelerator.print("Validating バリデーション処理...") - val_progress_bar = tqdm( range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, @@ -1450,7 +1448,6 @@ def remove_model(old_ckpt_name): ) if should_validate_epoch and len(val_dataloader) > 0: - accelerator.print("Validating バリデーション処理...") val_progress_bar = tqdm( range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, From b489082495ba6779385f282797227799413715f5 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 12 Jan 2025 16:42:04 -0500 Subject: [PATCH 315/348] Disable repeats for validation datasets --- library/train_util.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 6d3a772bb..4d143c373 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2055,9 +2055,10 @@ def load_dreambooth_dir(subset: DreamBoothSubset): num_reg_images = 0 reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = [] for subset in subsets: - if subset.num_repeats < 1: + num_repeats = subset.num_repeats if self.is_training_dataset else 1 + if num_repeats < 1: logger.warning( - f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}" + f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {num_repeats}" ) continue @@ -2075,12 +2076,12 @@ def load_dreambooth_dir(subset: DreamBoothSubset): continue if subset.is_reg: - num_reg_images += subset.num_repeats * len(img_paths) + num_reg_images += num_repeats * len(img_paths) else: - num_train_images += subset.num_repeats * len(img_paths) + num_train_images += num_repeats * len(img_paths) for img_path, caption, size in zip(img_paths, captions, sizes): - info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path) + info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path) if size is not None: info.image_size = size if subset.is_reg: From c04e5dfe92250a4790dc5f6e092cd85809a4e81d Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 23 Jan 2025 09:57:24 -0500 Subject: [PATCH 316/348] Fix loss recorder on 0. Fix validation for cached runs. Assert on validation dataset --- flux_train_network.py | 8 +++++--- library/train_util.py | 8 +++++++- requirements.txt | 1 + sd3_train_network.py | 11 ++++++++--- sdxl_train_network.py | 8 +++++--- sdxl_train_textual_inversion.py | 5 +++-- train_network.py | 16 +++++++++++----- train_textual_inversion.py | 9 ++++++--- 8 files changed, 46 insertions(+), 20 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index b3aebecc7..5cd1b9d51 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -2,7 +2,7 @@ import copy import math import random -from typing import Any, Optional +from typing import Any, Optional, Union import torch from accelerate import Accelerator @@ -36,8 +36,8 @@ def __init__(self): self.is_schnell: Optional[bool] = None self.is_swapping_blocks: bool = False - def assert_extra_args(self, args, train_dataset_group): - super().assert_extra_args(args, train_dataset_group) + def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): + super().assert_extra_args(args, train_dataset_group, val_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) if args.fp8_base_unet: @@ -80,6 +80,8 @@ def assert_extra_args(self, args, train_dataset_group): args.blocks_to_swap = 18 # 18 is safe for most cases train_dataset_group.verify_bucket_reso_steps(32) # TODO check this + if val_dataset_group is not None: + val_dataset_group.verify_bucket_reso_steps(32) # TODO check this def load_target_model(self, args, weight_dtype, accelerator): # currently offload to cpu for some models diff --git a/library/train_util.py b/library/train_util.py index 4d143c373..56fea4a8c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2893,6 +2893,9 @@ def __getitem__(self, idx): """ raise NotImplementedError + def get_resolutions(self) -> List[Tuple[int, int]]: + return [] + def load_arbitrary_dataset(args, tokenizer=None) -> MinimalDataset: module = ".".join(args.dataset_class.split(".")[:-1]) @@ -6520,4 +6523,7 @@ def add(self, *, epoch: int, step: int, loss: float) -> None: @property def moving_average(self) -> float: - return self.loss_total / len(self.loss_list) + losses = len(self.loss_list) + if losses == 0: + return 0 + return self.loss_total / losses diff --git a/requirements.txt b/requirements.txt index e0091749a..de39f5887 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,6 +20,7 @@ voluptuous==0.13.1 huggingface-hub==0.24.5 # for Image utils imagesize==1.4.1 +numpy<=2.0 # for BLIP captioning # requests==2.28.2 # timm==0.6.12 diff --git a/sd3_train_network.py b/sd3_train_network.py index c7417802d..dcf497f53 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -2,7 +2,7 @@ import copy import math import random -from typing import Any, Optional +from typing import Any, Optional, Union import torch from accelerate import Accelerator @@ -26,7 +26,7 @@ def __init__(self): super().__init__() self.sample_prompts_te_outputs = None - def assert_extra_args(self, args, train_dataset_group: train_util.DatasetGroup): + def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): # super().assert_extra_args(args, train_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) @@ -56,9 +56,14 @@ def assert_extra_args(self, args, train_dataset_group: train_util.DatasetGroup): ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" train_dataset_group.verify_bucket_reso_steps(32) # TODO check this + if val_dataset_group is not None: + val_dataset_group.verify_bucket_reso_steps(32) # TODO check this # enumerate resolutions from dataset for positional embeddings - self.resolutions = train_dataset_group.get_resolutions() + resolutions = train_dataset_group.get_resolutions() + if val_dataset_group is not None: + resolutions = resolutions + val_dataset_group.get_resolutions() + self.resolutions = resolutions def load_target_model(self, args, weight_dtype, accelerator): # currently offload to cpu for some models diff --git a/sdxl_train_network.py b/sdxl_train_network.py index d45df6e05..eb09831ec 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -1,5 +1,5 @@ import argparse -from typing import List, Optional +from typing import List, Optional, Union import torch from accelerate import Accelerator @@ -23,8 +23,8 @@ def __init__(self): self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR self.is_sdxl = True - def assert_extra_args(self, args, train_dataset_group): - super().assert_extra_args(args, train_dataset_group) + def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): + super().assert_extra_args(args, train_dataset_group, val_dataset_group) sdxl_train_util.verify_sdxl_training_args(args) if args.cache_text_encoder_outputs: @@ -37,6 +37,8 @@ def assert_extra_args(self, args, train_dataset_group): ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" train_dataset_group.verify_bucket_reso_steps(32) + if val_dataset_group is not None: + val_dataset_group.verify_bucket_reso_steps(32) def load_target_model(self, args, weight_dtype, accelerator): ( diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index 821a69558..bf56faf34 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -18,11 +18,12 @@ def __init__(self): self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR self.is_sdxl = True - def assert_extra_args(self, args, train_dataset_group): - super().assert_extra_args(args, train_dataset_group) + def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): + super().assert_extra_args(args, train_dataset_group, val_dataset_group) sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False) train_dataset_group.verify_bucket_reso_steps(32) + val_dataset_group.verify_bucket_reso_steps(32) def load_target_model(self, args, weight_dtype, accelerator): ( diff --git a/train_network.py b/train_network.py index e7d93a108..2c3bb2aae 100644 --- a/train_network.py +++ b/train_network.py @@ -3,7 +3,7 @@ import math import os import typing -from typing import Any, List +from typing import Any, List, Union, Optional import sys import random import time @@ -124,8 +124,10 @@ def generate_step_logs( return logs - def assert_extra_args(self, args, train_dataset_group): + def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): train_dataset_group.verify_bucket_reso_steps(64) + if val_dataset_group is not None: + val_dataset_group.verify_bucket_reso_steps(64) def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) @@ -512,7 +514,7 @@ def train(self, args): val_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - self.assert_extra_args(args, train_dataset_group) # may change some args + self.assert_extra_args(args, train_dataset_group, val_dataset_group) # may change some args # acceleratorを準備する logger.info("preparing accelerator") @@ -1414,7 +1416,9 @@ def remove_model(old_ckpt_name): args, text_encoding_strategy, tokenize_strategy, - is_train=False + is_train=False, + train_text_encoder=False, + train_unet=False ) current_loss = loss.detach().item() @@ -1474,7 +1478,9 @@ def remove_model(old_ckpt_name): args, text_encoding_strategy, tokenize_strategy, - is_train=False + is_train=False, + train_text_encoder=False, + train_unet=False ) current_loss = loss.detach().item() diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 113f35997..0c6568b08 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -2,7 +2,7 @@ import math import os from multiprocessing import Value -from typing import Any, List +from typing import Any, List, Optional, Union import toml from tqdm import tqdm @@ -99,9 +99,12 @@ def __init__(self): self.vae_scale_factor = 0.18215 self.is_sdxl = False - def assert_extra_args(self, args, train_dataset_group): + def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): train_dataset_group.verify_bucket_reso_steps(64) + if val_dataset_group is not None: + val_dataset_group.verify_bucket_reso_steps(64) + def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), [text_encoder], vae, unet @@ -325,7 +328,7 @@ def train(self, args): train_dataset_group = train_util.load_arbitrary_dataset(args) val_dataset_group = None - self.assert_extra_args(args, train_dataset_group) + self.assert_extra_args(args, train_dataset_group, val_dataset_group) current_epoch = Value("i", 0) current_step = Value("i", 0) From 58b82a576e32c2157e476840339ddafa98222dfc Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 26 Jan 2025 21:21:21 +0900 Subject: [PATCH 317/348] Fix to work with validation dataset --- library/train_util.py | 1 + sdxl_train_textual_inversion.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 56fea4a8c..37ed0a994 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2403,6 +2403,7 @@ def __init__( self.dreambooth_dataset_delegate = DreamBoothDataset( db_subsets, + True, batch_size, resolution, network_multiplier, diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index bf56faf34..982007601 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -1,5 +1,6 @@ import argparse import os +from typing import Optional, Union import regex @@ -23,7 +24,8 @@ def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetG sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False) train_dataset_group.verify_bucket_reso_steps(32) - val_dataset_group.verify_bucket_reso_steps(32) + if val_dataset_group is not None: + val_dataset_group.verify_bucket_reso_steps(32) def load_target_model(self, args, weight_dtype, accelerator): ( From e8529613d8a06ce91d3b304bccf85a172b1b4b31 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 26 Jan 2025 21:27:22 +0900 Subject: [PATCH 318/348] README.md: Update recent updates section to include validation loss support for training scripts --- README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.md b/README.md index 4dff15440..053354103 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,13 @@ The command to install PyTorch is as follows: ### Recent Updates +Jan 25, 2025: + +- `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py` now support validation loss. PR [#1864](https://github.com/kohya-ss/sd-scripts/pull/1864) Thank you to rockerBOO! + - For details on how to set it up, please refer to the PR. The documentation will be updated as needed. + - It will be added to other scripts as well. + - As a current limitation, validation loss is not supported when `--block_to_swap` is specified. + Dec 15, 2024: - RAdamScheduleFree optimizer is supported. PR [#1830](https://github.com/kohya-ss/sd-scripts/pull/1830) Thanks to nhamanasu! From 59b3b94faf827e3a7f01829fed0232d89dec9e33 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 26 Jan 2025 21:52:58 +0900 Subject: [PATCH 319/348] README.md: Update limitation for validation loss support to include schedule-free optimizer --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 053354103..4bbd7617e 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ Jan 25, 2025: - `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py` now support validation loss. PR [#1864](https://github.com/kohya-ss/sd-scripts/pull/1864) Thank you to rockerBOO! - For details on how to set it up, please refer to the PR. The documentation will be updated as needed. - It will be added to other scripts as well. - - As a current limitation, validation loss is not supported when `--block_to_swap` is specified. + - As a current limitation, validation loss is not supported when `--block_to_swap` is specified, or when schedule-free optimizer is used. Dec 15, 2024: From 532f5c58a6e83a3400f82103f5854ff3f63d77d7 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Jan 2025 20:50:42 +0900 Subject: [PATCH 320/348] formatting --- train_network.py | 229 ++++++++++++++++++++++------------------------- 1 file changed, 108 insertions(+), 121 deletions(-) diff --git a/train_network.py b/train_network.py index 2c3bb2aae..cc54be7cc 100644 --- a/train_network.py +++ b/train_network.py @@ -100,9 +100,7 @@ def generate_step_logs( if ( args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None ): # tracking d*lr value of unet. - logs["lr/d*lr"] = ( - optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] - ) + logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] else: idx = 0 if not args.network_train_unet_only: @@ -115,16 +113,17 @@ def generate_step_logs( logs[f"lr/d*lr/group{i}"] = ( lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] ) - if ( - args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None - ): - logs[f"lr/d*lr/group{i}"] = ( - optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] - ) + if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None: + logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] return logs - def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): + def assert_extra_args( + self, + args, + train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], + val_dataset_group: Optional[train_util.DatasetGroup], + ): train_dataset_group.verify_bucket_reso_steps(64) if val_dataset_group is not None: val_dataset_group.verify_bucket_reso_steps(64) @@ -219,7 +218,7 @@ def get_noise_pred_and_target( network, weight_dtype, train_unet, - is_train=True + is_train=True, ): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified @@ -315,22 +314,22 @@ def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, # endregion def process_batch( - self, - batch, - text_encoders, - unet, - network, - vae, - noise_scheduler, - vae_dtype, - weight_dtype, - accelerator, - args, - text_encoding_strategy: strategy_base.TextEncodingStrategy, - tokenize_strategy: strategy_base.TokenizeStrategy, - is_train=True, - train_text_encoder=True, - train_unet=True + self, + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy: strategy_base.TextEncodingStrategy, + tokenize_strategy: strategy_base.TokenizeStrategy, + is_train=True, + train_text_encoder=True, + train_unet=True, ) -> torch.Tensor: """ Process a batch for the network @@ -397,7 +396,7 @@ def process_batch( network, weight_dtype, train_unet, - is_train=is_train + is_train=is_train, ) huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) @@ -484,7 +483,7 @@ def train(self, args): else: # use arbitrary dataset class train_dataset_group = train_util.load_arbitrary_dataset(args) - val_dataset_group = None # placeholder until validation dataset supported for arbitrary + val_dataset_group = None # placeholder until validation dataset supported for arbitrary current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -701,7 +700,7 @@ def train(self, args): num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) - + val_dataloader = torch.utils.data.DataLoader( val_dataset_group if val_dataset_group is not None else [], shuffle=False, @@ -900,7 +899,9 @@ def load_model_hook(models, input_dir): accelerator.print("running training / 学習開始") accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - accelerator.print(f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}") + accelerator.print( + f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}" + ) accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") @@ -968,11 +969,11 @@ def load_model_hook(models, input_dir): "ss_huber_c": args.huber_c, "ss_fp8_base": bool(args.fp8_base), "ss_fp8_base_unet": bool(args.fp8_base_unet), - "ss_validation_seed": args.validation_seed, - "ss_validation_split": args.validation_split, - "ss_max_validation_steps": args.max_validation_steps, - "ss_validate_every_n_epochs": args.validate_every_n_epochs, - "ss_validate_every_n_steps": args.validate_every_n_steps, + "ss_validation_seed": args.validation_seed, + "ss_validation_split": args.validation_split, + "ss_max_validation_steps": args.max_validation_steps, + "ss_validate_every_n_epochs": args.validate_every_n_epochs, + "ss_validate_every_n_steps": args.validate_every_n_steps, } self.update_metadata(metadata, args) # architecture specific metadata @@ -1248,9 +1249,7 @@ def remove_model(old_ckpt_name): accelerator.log({}, step=0) validation_steps = ( - min(args.max_validation_steps, len(val_dataloader)) - if args.max_validation_steps is not None - else len(val_dataloader) + min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) ) # training loop @@ -1298,21 +1297,21 @@ def remove_model(old_ckpt_name): self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) loss = self.process_batch( - batch, - text_encoders, - unet, - network, - vae, - noise_scheduler, - vae_dtype, - weight_dtype, - accelerator, - args, - text_encoding_strategy, - tokenize_strategy, - is_train=True, - train_text_encoder=train_text_encoder, - train_unet=train_unet + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=True, + train_text_encoder=train_text_encoder, + train_unet=train_unet, ) accelerator.backward(loss) @@ -1369,32 +1368,21 @@ def remove_model(old_ckpt_name): if args.scale_weight_norms: progress_bar.set_postfix(**{**max_mean_logs, **logs}) - if is_tracking: logs = self.generate_step_logs( - args, - current_loss, - avr_loss, - lr_scheduler, - lr_descriptions, - optimizer, - keys_scaled, - mean_norm, - maximum_norm + args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm ) accelerator.log(logs, step=global_step) # VALIDATION PER STEP should_validate_step = ( - args.validate_every_n_steps is not None - and global_step != 0 # Skip first step + args.validate_every_n_steps is not None + and global_step != 0 # Skip first step and global_step % args.validate_every_n_steps == 0 ) if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: val_progress_bar = tqdm( - range(validation_steps), smoothing=0, - disable=not accelerator.is_local_main_process, - desc="validation steps" + range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="validation steps" ) for val_step, batch in enumerate(val_dataloader): if val_step >= validation_steps: @@ -1404,27 +1392,27 @@ def remove_model(old_ckpt_name): self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) loss = self.process_batch( - batch, - text_encoders, - unet, - network, - vae, - noise_scheduler, - vae_dtype, - weight_dtype, - accelerator, - args, - text_encoding_strategy, - tokenize_strategy, + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, is_train=False, - train_text_encoder=False, - train_unet=False + train_text_encoder=False, + train_unet=False, ) current_loss = loss.detach().item() val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) val_progress_bar.update(1) - val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average }) + val_progress_bar.set_postfix({"val_avg_loss": val_step_loss_recorder.moving_average}) if is_tracking: logs = { @@ -1436,26 +1424,25 @@ def remove_model(old_ckpt_name): if is_tracking: loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average logs = { - "loss/validation/step_average": val_step_loss_recorder.moving_average, - "loss/validation/step_divergence": loss_validation_divergence, + "loss/validation/step_average": val_step_loss_recorder.moving_average, + "loss/validation/step_divergence": loss_validation_divergence, } accelerator.log(logs, step=global_step) - + if global_step >= args.max_train_steps: break # EPOCH VALIDATION should_validate_epoch = ( - (epoch + 1) % args.validate_every_n_epochs == 0 - if args.validate_every_n_epochs is not None - else True + (epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None else True ) if should_validate_epoch and len(val_dataloader) > 0: val_progress_bar = tqdm( - range(validation_steps), smoothing=0, - disable=not accelerator.is_local_main_process, - desc="epoch validation steps" + range(validation_steps), + smoothing=0, + disable=not accelerator.is_local_main_process, + desc="epoch validation steps", ) for val_step, batch in enumerate(val_dataloader): @@ -1466,43 +1453,43 @@ def remove_model(old_ckpt_name): self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) loss = self.process_batch( - batch, - text_encoders, - unet, - network, - vae, - noise_scheduler, - vae_dtype, - weight_dtype, - accelerator, - args, - text_encoding_strategy, - tokenize_strategy, + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, is_train=False, - train_text_encoder=False, - train_unet=False + train_text_encoder=False, + train_unet=False, ) current_loss = loss.detach().item() val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) val_progress_bar.update(1) - val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average }) + val_progress_bar.set_postfix({"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average}) if is_tracking: logs = { - "loss/validation/epoch_current": current_loss, - "epoch": epoch + 1, - "val_step": (epoch * validation_steps) + val_step + "loss/validation/epoch_current": current_loss, + "epoch": epoch + 1, + "val_step": (epoch * validation_steps) + val_step, } accelerator.log(logs, step=global_step) if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average - loss_validation_divergence = val_step_loss_recorder.moving_average - avr_loss + loss_validation_divergence = val_step_loss_recorder.moving_average - avr_loss logs = { - "loss/validation/epoch_average": avr_loss, - "loss/validation/epoch_divergence": loss_validation_divergence, - "epoch": epoch + 1 + "loss/validation/epoch_average": avr_loss, + "loss/validation/epoch_divergence": loss_validation_divergence, + "epoch": epoch + 1, } accelerator.log(logs, step=global_step) @@ -1510,7 +1497,7 @@ def remove_model(old_ckpt_name): if is_tracking: logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1} accelerator.log(logs, step=global_step) - + accelerator.wait_for_everyone() # 指定エポックごとにモデルを保存 @@ -1696,31 +1683,31 @@ def setup_parser() -> argparse.ArgumentParser: "--validation_seed", type=int, default=None, - help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する" + help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する", ) parser.add_argument( "--validation_split", type=float, default=0.0, - help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合" + help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合", ) parser.add_argument( "--validate_every_n_steps", type=int, default=None, - help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます" + help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます", ) parser.add_argument( "--validate_every_n_epochs", type=int, default=None, - help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます" + help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます", ) parser.add_argument( "--max_validation_steps", type=int, default=None, - help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します" + help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します", ) return parser From 86a2f3fd262e52b3249d9f5508efe4774f1fa3ed Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Jan 2025 21:10:52 +0900 Subject: [PATCH 321/348] Fix gradient handling when Text Encoders are trained --- flux_train_network.py | 43 ++----------------------------------------- sd3_train_network.py | 2 +- train_network.py | 10 +++++----- 3 files changed, 8 insertions(+), 47 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index 5cd1b9d51..475bd751b 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -376,9 +376,8 @@ def get_noise_pred_and_target( t5_attn_mask = None def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): - # if not args.split_mode: - # normal forward - with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): + # grad is enabled even if unet is not in train mode, because Text Encoder is in train mode + with torch.set_grad_enabled(is_train), accelerator.autocast(): # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = unet( img=img, @@ -390,44 +389,6 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t guidance=guidance_vec, txt_attention_mask=t5_attn_mask, ) - """ - else: - # split forward to reduce memory usage - assert network.train_blocks == "single", "train_blocks must be single for split mode" - with accelerator.autocast(): - # move flux lower to cpu, and then move flux upper to gpu - unet.to("cpu") - clean_memory_on_device(accelerator.device) - self.flux_upper.to(accelerator.device) - - # upper model does not require grad - with torch.no_grad(): - intermediate_img, intermediate_txt, vec, pe = self.flux_upper( - img=packed_noisy_model_input, - img_ids=img_ids, - txt=t5_out, - txt_ids=txt_ids, - y=l_pooled, - timesteps=timesteps / 1000, - guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, - ) - - # move flux upper back to cpu, and then move flux lower to gpu - self.flux_upper.to("cpu") - clean_memory_on_device(accelerator.device) - unet.to(accelerator.device) - - # lower model requires grad - intermediate_img.requires_grad_(True) - intermediate_txt.requires_grad_(True) - vec.requires_grad_(True) - pe.requires_grad_(True) - - with torch.set_grad_enabled(is_train and train_unet): - model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) - """ - return model_pred model_pred = call_dit( diff --git a/sd3_train_network.py b/sd3_train_network.py index dcf497f53..2f4579492 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -345,7 +345,7 @@ def get_noise_pred_and_target( t5_attn_mask = None # call model - with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): + with torch.set_grad_enabled(is_train), accelerator.autocast(): # TODO support attention mask model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled) diff --git a/train_network.py b/train_network.py index cc54be7cc..6f1652fd9 100644 --- a/train_network.py +++ b/train_network.py @@ -232,7 +232,7 @@ def get_noise_pred_and_target( t.requires_grad_(True) # Predict the noise residual - with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): + with torch.set_grad_enabled(is_train), accelerator.autocast(): noise_pred = self.call_unet( args, accelerator, @@ -1405,8 +1405,8 @@ def remove_model(old_ckpt_name): text_encoding_strategy, tokenize_strategy, is_train=False, - train_text_encoder=False, - train_unet=False, + train_text_encoder=train_text_encoder, # this is needed for validation because Text Encoders must be called if train_text_encoder is True + train_unet=train_unet, ) current_loss = loss.detach().item() @@ -1466,8 +1466,8 @@ def remove_model(old_ckpt_name): text_encoding_strategy, tokenize_strategy, is_train=False, - train_text_encoder=False, - train_unet=False, + train_text_encoder=train_text_encoder, + train_unet=train_unet, ) current_loss = loss.detach().item() From b6a309321675b5d0a59b776ffb4d0ecdd3d28ec2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Jan 2025 21:22:11 +0900 Subject: [PATCH 322/348] call optimizer eval/train fn before/after validation --- train_network.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/train_network.py b/train_network.py index 6f1652fd9..e735c582d 100644 --- a/train_network.py +++ b/train_network.py @@ -1381,6 +1381,8 @@ def remove_model(old_ckpt_name): and global_step % args.validate_every_n_steps == 0 ) if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: + optimizer_eval_fn() + val_progress_bar = tqdm( range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="validation steps" ) @@ -1429,6 +1431,8 @@ def remove_model(old_ckpt_name): } accelerator.log(logs, step=global_step) + optimizer_train_fn() + if global_step >= args.max_train_steps: break @@ -1438,6 +1442,8 @@ def remove_model(old_ckpt_name): ) if should_validate_epoch and len(val_dataloader) > 0: + optimizer_eval_fn() + val_progress_bar = tqdm( range(validation_steps), smoothing=0, @@ -1493,6 +1499,8 @@ def remove_model(old_ckpt_name): } accelerator.log(logs, step=global_step) + optimizer_train_fn() + # END OF EPOCH if is_tracking: logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1} From 29f31d005f12a08650389164fa9c60504928d451 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Jan 2025 21:35:43 +0900 Subject: [PATCH 323/348] add network.train()/eval() for validation --- train_network.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index e735c582d..9b8036f8b 100644 --- a/train_network.py +++ b/train_network.py @@ -1276,7 +1276,7 @@ def remove_model(old_ckpt_name): metadata["ss_epoch"] = str(epoch + 1) - accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) + accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) # network.train() is called here # TRAINING skipped_dataloader = None @@ -1382,6 +1382,7 @@ def remove_model(old_ckpt_name): ) if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: optimizer_eval_fn() + accelerator.unwrap_model(network).eval() val_progress_bar = tqdm( range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="validation steps" @@ -1432,6 +1433,7 @@ def remove_model(old_ckpt_name): accelerator.log(logs, step=global_step) optimizer_train_fn() + accelerator.unwrap_model(network).train() if global_step >= args.max_train_steps: break @@ -1443,6 +1445,7 @@ def remove_model(old_ckpt_name): if should_validate_epoch and len(val_dataloader) > 0: optimizer_eval_fn() + accelerator.unwrap_model(network).eval() val_progress_bar = tqdm( range(validation_steps), @@ -1500,6 +1503,7 @@ def remove_model(old_ckpt_name): accelerator.log(logs, step=global_step) optimizer_train_fn() + accelerator.unwrap_model(network).train() # END OF EPOCH if is_tracking: From 0750859133eec7858052cd3f79106113fa786e94 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Jan 2025 21:56:59 +0900 Subject: [PATCH 324/348] validation: Implement timestep-based validation processing --- sd3_train_network.py | 1 + train_network.py | 167 +++++++++++++++++++++++++------------------ 2 files changed, 100 insertions(+), 68 deletions(-) diff --git a/sd3_train_network.py b/sd3_train_network.py index 2f4579492..d4f131252 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -446,6 +446,7 @@ def forward(hidden_states): prepare_fp8(text_encoder, weight_dtype) def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + # TODO consider validation # drop cached text encoder outputs text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: diff --git a/train_network.py b/train_network.py index 9b8036f8b..a63e9d1e9 100644 --- a/train_network.py +++ b/train_network.py @@ -9,6 +9,7 @@ import time import json from multiprocessing import Value +import numpy as np import toml from tqdm import tqdm @@ -1248,10 +1249,6 @@ def remove_model(old_ckpt_name): # log empty object to commit the sample images to wandb accelerator.log({}, step=0) - validation_steps = ( - min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - ) - # training loop if initial_step > 0: # only if skip_until_initial_step is specified for skip_epoch in range(epoch_to_start): # skip epochs @@ -1270,6 +1267,17 @@ def remove_model(old_ckpt_name): clean_memory_on_device(accelerator.device) + validation_steps = ( + min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + ) + NUM_VALIDATION_TIMESTEPS = 4 # 200, 400, 600, 800 TODO make this configurable + min_timestep = 0 if args.min_timestep is None else args.min_timestep + max_timestep = noise_scheduler.num_train_timesteps if args.max_timestep is None else args.max_timestep + validation_timesteps = np.linspace(min_timestep, max_timestep, (NUM_VALIDATION_TIMESTEPS + 2), dtype=int)[1:-1] + validation_total_steps = validation_steps * len(validation_timesteps) + original_args_min_timestep = args.min_timestep + original_args_max_timestep = args.max_timestep + for epoch in range(epoch_to_start, num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n") current_epoch.value = epoch + 1 @@ -1385,44 +1393,55 @@ def remove_model(old_ckpt_name): accelerator.unwrap_model(network).eval() val_progress_bar = tqdm( - range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="validation steps" + range(validation_total_steps), + smoothing=0, + disable=not accelerator.is_local_main_process, + desc="validation steps", ) + val_ts_step = 0 for val_step, batch in enumerate(val_dataloader): if val_step >= validation_steps: break - # temporary, for batch processing - self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) - - loss = self.process_batch( - batch, - text_encoders, - unet, - network, - vae, - noise_scheduler, - vae_dtype, - weight_dtype, - accelerator, - args, - text_encoding_strategy, - tokenize_strategy, - is_train=False, - train_text_encoder=train_text_encoder, # this is needed for validation because Text Encoders must be called if train_text_encoder is True - train_unet=train_unet, - ) - - current_loss = loss.detach().item() - val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) - val_progress_bar.update(1) - val_progress_bar.set_postfix({"val_avg_loss": val_step_loss_recorder.moving_average}) - - if is_tracking: - logs = { - "loss/validation/step_current": current_loss, - "val_step": (epoch * validation_steps) + val_step, - } - accelerator.log(logs, step=global_step) + for timestep in validation_timesteps: + # temporary, for batch processing + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + + args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep + + loss = self.process_batch( + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=False, + train_text_encoder=train_text_encoder, # this is needed for validation because Text Encoders must be called if train_text_encoder is True + train_unet=train_unet, + ) + + current_loss = loss.detach().item() + val_step_loss_recorder.add(epoch=epoch, step=val_ts_step, loss=current_loss) + val_progress_bar.update(1) + val_progress_bar.set_postfix( + {"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep} + ) + + if is_tracking: + logs = { + "loss/validation/step_current": current_loss, + "val_step": (epoch * validation_total_steps) + val_ts_step, + } + accelerator.log(logs, step=global_step) + + val_ts_step += 1 if is_tracking: loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average @@ -1432,6 +1451,8 @@ def remove_model(old_ckpt_name): } accelerator.log(logs, step=global_step) + args.min_timestep = original_args_min_timestep + args.max_timestep = original_args_max_timestep optimizer_train_fn() accelerator.unwrap_model(network).train() @@ -1448,49 +1469,57 @@ def remove_model(old_ckpt_name): accelerator.unwrap_model(network).eval() val_progress_bar = tqdm( - range(validation_steps), + range(validation_total_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="epoch validation steps", ) + val_ts_step = 0 for val_step, batch in enumerate(val_dataloader): if val_step >= validation_steps: break - # temporary, for batch processing - self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + for timestep in validation_timesteps: + args.min_timestep = args.max_timestep = timestep - loss = self.process_batch( - batch, - text_encoders, - unet, - network, - vae, - noise_scheduler, - vae_dtype, - weight_dtype, - accelerator, - args, - text_encoding_strategy, - tokenize_strategy, - is_train=False, - train_text_encoder=train_text_encoder, - train_unet=train_unet, - ) + # temporary, for batch processing + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + + loss = self.process_batch( + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=False, + train_text_encoder=train_text_encoder, + train_unet=train_unet, + ) - current_loss = loss.detach().item() - val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) - val_progress_bar.update(1) - val_progress_bar.set_postfix({"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average}) + current_loss = loss.detach().item() + val_epoch_loss_recorder.add(epoch=epoch, step=val_ts_step, loss=current_loss) + val_progress_bar.update(1) + val_progress_bar.set_postfix( + {"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep} + ) - if is_tracking: - logs = { - "loss/validation/epoch_current": current_loss, - "epoch": epoch + 1, - "val_step": (epoch * validation_steps) + val_step, - } - accelerator.log(logs, step=global_step) + if is_tracking: + logs = { + "loss/validation/epoch_current": current_loss, + "epoch": epoch + 1, + "val_step": (epoch * validation_total_steps) + val_ts_step, + } + accelerator.log(logs, step=global_step) + + val_ts_step += 1 if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average @@ -1502,6 +1531,8 @@ def remove_model(old_ckpt_name): } accelerator.log(logs, step=global_step) + args.min_timestep = original_args_min_timestep + args.max_timestep = original_args_max_timestep optimizer_train_fn() accelerator.unwrap_model(network).train() From 0778dd9b1df0d6aa33287ded3ce4195f3d03251b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Jan 2025 22:03:42 +0900 Subject: [PATCH 325/348] fix Text Encoder only LoRA training --- flux_train_network.py | 2 +- sd3_train_network.py | 2 +- train_network.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index 5cd1b9d51..ae4b62f5c 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -378,7 +378,7 @@ def get_noise_pred_and_target( def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): # if not args.split_mode: # normal forward - with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): + with torch.set_grad_enabled(is_train), accelerator.autocast(): # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = unet( img=img, diff --git a/sd3_train_network.py b/sd3_train_network.py index dcf497f53..2f4579492 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -345,7 +345,7 @@ def get_noise_pred_and_target( t5_attn_mask = None # call model - with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): + with torch.set_grad_enabled(is_train), accelerator.autocast(): # TODO support attention mask model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled) diff --git a/train_network.py b/train_network.py index 2c3bb2aae..c3879531d 100644 --- a/train_network.py +++ b/train_network.py @@ -233,7 +233,7 @@ def get_noise_pred_and_target( t.requires_grad_(True) # Predict the noise residual - with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): + with torch.set_grad_enabled(is_train), accelerator.autocast(): noise_pred = self.call_unet( args, accelerator, From 45ec02b2a8b5eb5af8f5b4877381dc4dcc596cb9 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Jan 2025 22:10:38 +0900 Subject: [PATCH 326/348] use same noise for every validation --- flux_train_network.py | 1 - train_network.py | 6 ++++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/flux_train_network.py b/flux_train_network.py index aab025735..475bd751b 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -377,7 +377,6 @@ def get_noise_pred_and_target( def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): # grad is enabled even if unet is not in train mode, because Text Encoder is in train mode - with torch.set_grad_enabled(is_train), accelerator.autocast(): # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = unet( diff --git a/train_network.py b/train_network.py index a63e9d1e9..f0deb67ab 100644 --- a/train_network.py +++ b/train_network.py @@ -1391,6 +1391,8 @@ def remove_model(old_ckpt_name): if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: optimizer_eval_fn() accelerator.unwrap_model(network).eval() + rng_state = torch.get_rng_state() + torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed) val_progress_bar = tqdm( range(validation_total_steps), @@ -1451,6 +1453,7 @@ def remove_model(old_ckpt_name): } accelerator.log(logs, step=global_step) + torch.set_rng_state(rng_state) args.min_timestep = original_args_min_timestep args.max_timestep = original_args_max_timestep optimizer_train_fn() @@ -1467,6 +1470,8 @@ def remove_model(old_ckpt_name): if should_validate_epoch and len(val_dataloader) > 0: optimizer_eval_fn() accelerator.unwrap_model(network).eval() + rng_state = torch.get_rng_state() + torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed) val_progress_bar = tqdm( range(validation_total_steps), @@ -1531,6 +1536,7 @@ def remove_model(old_ckpt_name): } accelerator.log(logs, step=global_step) + torch.set_rng_state(rng_state) args.min_timestep = original_args_min_timestep args.max_timestep = original_args_max_timestep optimizer_train_fn() From de830b89416f0671d7a1364a9262fa850c0669df Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 29 Jan 2025 00:02:45 -0500 Subject: [PATCH 327/348] Move progress bar to account for sampling image first --- train_network.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/train_network.py b/train_network.py index c3879531d..2deb736d6 100644 --- a/train_network.py +++ b/train_network.py @@ -1163,10 +1163,6 @@ def load_model_hook(models, input_dir): args.max_train_steps > initial_step ), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}" - progress_bar = tqdm( - range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps" - ) - epoch_to_start = 0 if initial_step > 0: if args.skip_until_initial_step: @@ -1271,6 +1267,10 @@ def remove_model(old_ckpt_name): clean_memory_on_device(accelerator.device) + progress_bar = tqdm( + range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps" + ) + for epoch in range(epoch_to_start, num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n") current_epoch.value = epoch + 1 From c5b803ce94bd70812e6979ac7b986a769659b14e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 4 Feb 2025 21:59:09 +0900 Subject: [PATCH 328/348] rng state management: Implement functions to get and set RNG states for consistent validation --- train_network.py | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/train_network.py b/train_network.py index f0deb67ab..b3c7ff524 100644 --- a/train_network.py +++ b/train_network.py @@ -1278,6 +1278,31 @@ def remove_model(old_ckpt_name): original_args_min_timestep = args.min_timestep original_args_max_timestep = args.max_timestep + def get_rng_state() -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]: + cpu_rng_state = torch.get_rng_state() + if accelerator.device.type == "cuda": + gpu_rng_state = torch.cuda.get_rng_state() + elif accelerator.device.type == "xpu": + gpu_rng_state = torch.xpu.get_rng_state() + elif accelerator.device.type == "mps": + gpu_rng_state = torch.cuda.get_rng_state() + else: + gpu_rng_state = None + python_rng_state = random.getstate() + return (cpu_rng_state, gpu_rng_state, python_rng_state) + + def set_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]): + cpu_rng_state, gpu_rng_state, python_rng_state = rng_states + torch.set_rng_state(cpu_rng_state) + if gpu_rng_state is not None: + if accelerator.device.type == "cuda": + torch.cuda.set_rng_state(gpu_rng_state) + elif accelerator.device.type == "xpu": + torch.xpu.set_rng_state(gpu_rng_state) + elif accelerator.device.type == "mps": + torch.cuda.set_rng_state(gpu_rng_state) + random.setstate(python_rng_state) + for epoch in range(epoch_to_start, num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n") current_epoch.value = epoch + 1 @@ -1391,7 +1416,7 @@ def remove_model(old_ckpt_name): if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: optimizer_eval_fn() accelerator.unwrap_model(network).eval() - rng_state = torch.get_rng_state() + rng_states = get_rng_state() torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed) val_progress_bar = tqdm( @@ -1453,7 +1478,7 @@ def remove_model(old_ckpt_name): } accelerator.log(logs, step=global_step) - torch.set_rng_state(rng_state) + set_rng_state(rng_states) args.min_timestep = original_args_min_timestep args.max_timestep = original_args_max_timestep optimizer_train_fn() @@ -1470,7 +1495,7 @@ def remove_model(old_ckpt_name): if should_validate_epoch and len(val_dataloader) > 0: optimizer_eval_fn() accelerator.unwrap_model(network).eval() - rng_state = torch.get_rng_state() + rng_states = get_rng_state() torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed) val_progress_bar = tqdm( @@ -1536,7 +1561,7 @@ def remove_model(old_ckpt_name): } accelerator.log(logs, step=global_step) - torch.set_rng_state(rng_state) + set_rng_state(rng_states) args.min_timestep = original_args_min_timestep args.max_timestep = original_args_max_timestep optimizer_train_fn() From a24db1d532a95cc9dd91aba25a06b8eb58db5cff Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 4 Feb 2025 22:02:42 +0900 Subject: [PATCH 329/348] fix: validation timestep generation fails on SD/SDXL training --- library/train_util.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 37ed0a994..01fa64674 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5935,7 +5935,10 @@ def save_sd_model_on_train_end_common( def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor: - timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") + if min_timestep < max_timestep: + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") + else: + timesteps = torch.full((b_size,), max_timestep, device="cpu") timesteps = timesteps.long().to(device) return timesteps From 0911683717e439676bba758a5f7a29356984966c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 9 Feb 2025 20:53:49 +0900 Subject: [PATCH 330/348] set python random state --- train_network.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/train_network.py b/train_network.py index b3c7ff524..083e5993d 100644 --- a/train_network.py +++ b/train_network.py @@ -1278,7 +1278,7 @@ def remove_model(old_ckpt_name): original_args_min_timestep = args.min_timestep original_args_max_timestep = args.max_timestep - def get_rng_state() -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]: + def switch_rng_state(seed:int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]: cpu_rng_state = torch.get_rng_state() if accelerator.device.type == "cuda": gpu_rng_state = torch.cuda.get_rng_state() @@ -1289,9 +1289,13 @@ def get_rng_state() -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple else: gpu_rng_state = None python_rng_state = random.getstate() + + torch.manual_seed(seed) + random.seed(seed) + return (cpu_rng_state, gpu_rng_state, python_rng_state) - def set_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]): + def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]): cpu_rng_state, gpu_rng_state, python_rng_state = rng_states torch.set_rng_state(cpu_rng_state) if gpu_rng_state is not None: @@ -1416,8 +1420,7 @@ def set_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor] if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: optimizer_eval_fn() accelerator.unwrap_model(network).eval() - rng_states = get_rng_state() - torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed) + rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed) val_progress_bar = tqdm( range(validation_total_steps), @@ -1478,7 +1481,7 @@ def set_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor] } accelerator.log(logs, step=global_step) - set_rng_state(rng_states) + restore_rng_state(rng_states) args.min_timestep = original_args_min_timestep args.max_timestep = original_args_max_timestep optimizer_train_fn() @@ -1495,8 +1498,7 @@ def set_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor] if should_validate_epoch and len(val_dataloader) > 0: optimizer_eval_fn() accelerator.unwrap_model(network).eval() - rng_states = get_rng_state() - torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed) + rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed) val_progress_bar = tqdm( range(validation_total_steps), @@ -1561,7 +1563,7 @@ def set_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor] } accelerator.log(logs, step=global_step) - set_rng_state(rng_states) + restore_rng_state(rng_states) args.min_timestep = original_args_min_timestep args.max_timestep = original_args_max_timestep optimizer_train_fn() From 344845b42941b48956dce94d614fbf32e900c70e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 9 Feb 2025 21:25:40 +0900 Subject: [PATCH 331/348] fix: validation with block swap --- flux_train_network.py | 14 ++++++++++++-- sd3_train_network.py | 19 ++++++++++++++----- train_network.py | 18 +++++++++++------- 3 files changed, 37 insertions(+), 14 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index 475bd751b..e97dfc5b8 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -36,7 +36,12 @@ def __init__(self): self.is_schnell: Optional[bool] = None self.is_swapping_blocks: bool = False - def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): + def assert_extra_args( + self, + args, + train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], + val_dataset_group: Optional[train_util.DatasetGroup], + ): super().assert_extra_args(args, train_dataset_group, val_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) @@ -341,7 +346,7 @@ def get_noise_pred_and_target( network, weight_dtype, train_unet, - is_train=True + is_train=True, ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) @@ -507,6 +512,11 @@ def forward(hidden_states): text_encoder.to(te_weight_dtype) # fp8 prepare_fp8(text_encoder, weight_dtype) + def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + if self.is_swapping_blocks: + # prepare for next forward: because backward pass is not called, we need to prepare it here + accelerator.unwrap_model(unet).prepare_block_swap_before_forward() + def prepare_unet_with_accelerator( self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module ) -> torch.nn.Module: diff --git a/sd3_train_network.py b/sd3_train_network.py index d4f131252..216d93c58 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -26,7 +26,12 @@ def __init__(self): super().__init__() self.sample_prompts_te_outputs = None - def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): + def assert_extra_args( + self, + args, + train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], + val_dataset_group: Optional[train_util.DatasetGroup], + ): # super().assert_extra_args(args, train_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) @@ -317,7 +322,7 @@ def get_noise_pred_and_target( network, weight_dtype, train_unet, - is_train=True + is_train=True, ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) @@ -445,15 +450,19 @@ def forward(hidden_states): text_encoder.to(te_weight_dtype) # fp8 prepare_fp8(text_encoder, weight_dtype) - def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): - # TODO consider validation - # drop cached text encoder outputs + def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True): + # drop cached text encoder outputs: in validation, we drop cached outputs deterministically by fixed seed text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) batch["text_encoder_outputs_list"] = text_encoder_outputs_list + def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + if self.is_swapping_blocks: + # prepare for next forward: because backward pass is not called, we need to prepare it here + accelerator.unwrap_model(unet).prepare_block_swap_before_forward() + def prepare_unet_with_accelerator( self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module ) -> torch.nn.Module: diff --git a/train_network.py b/train_network.py index 083e5993d..49013c708 100644 --- a/train_network.py +++ b/train_network.py @@ -309,7 +309,10 @@ def prepare_unet_with_accelerator( ) -> torch.nn.Module: return accelerator.prepare(unet) - def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train: bool = True): + pass + + def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): pass # endregion @@ -1278,7 +1281,7 @@ def remove_model(old_ckpt_name): original_args_min_timestep = args.min_timestep original_args_max_timestep = args.max_timestep - def switch_rng_state(seed:int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]: + def switch_rng_state(seed: int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]: cpu_rng_state = torch.get_rng_state() if accelerator.device.type == "cuda": gpu_rng_state = torch.cuda.get_rng_state() @@ -1330,8 +1333,8 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen with accelerator.accumulate(training_model): on_step_start_for_network(text_encoder, unet) - # temporary, for batch processing - self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + # preprocess batch for each model + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True) loss = self.process_batch( batch, @@ -1434,8 +1437,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen break for timestep in validation_timesteps: - # temporary, for batch processing - self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False) args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep @@ -1471,6 +1473,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen } accelerator.log(logs, step=global_step) + self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) val_ts_step += 1 if is_tracking: @@ -1516,7 +1519,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen args.min_timestep = args.max_timestep = timestep # temporary, for batch processing - self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False) loss = self.process_batch( batch, @@ -1551,6 +1554,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen } accelerator.log(logs, step=global_step) + self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) val_ts_step += 1 if is_tracking: From 177203818a024329efa74640a588674323363373 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 11 Feb 2025 21:42:46 +0900 Subject: [PATCH 332/348] fix: unpause training progress bar after vaidation --- train_network.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/train_network.py b/train_network.py index 49013c708..8bfb19258 100644 --- a/train_network.py +++ b/train_network.py @@ -1489,6 +1489,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen args.max_timestep = original_args_max_timestep optimizer_train_fn() accelerator.unwrap_model(network).train() + progress_bar.unpause() if global_step >= args.max_train_steps: break @@ -1572,6 +1573,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen args.max_timestep = original_args_max_timestep optimizer_train_fn() accelerator.unwrap_model(network).train() + progress_bar.unpause() # END OF EPOCH if is_tracking: From cd80752175c663ede2cb7995da652ed5f5f7f749 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 11 Feb 2025 21:42:58 +0900 Subject: [PATCH 333/348] fix: remove unused parameter 'accelerator' from encode_images_to_latents method --- flux_train_network.py | 2 +- sd3_train_network.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index e97dfc5b8..def441559 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -328,7 +328,7 @@ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) return noise_scheduler - def encode_images_to_latents(self, args, accelerator, vae, images): + def encode_images_to_latents(self, args, vae, images): return vae.encode(images) def shift_scale_latents(self, args, latents): diff --git a/sd3_train_network.py b/sd3_train_network.py index 216d93c58..cdb7aa4e3 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -304,7 +304,7 @@ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.training_shift) return noise_scheduler - def encode_images_to_latents(self, args, accelerator, vae, images): + def encode_images_to_latents(self, args, vae, images): return vae.encode(images) def shift_scale_latents(self, args, latents): From 76b761943b5166f496aa1cb8ffbcc2d04469346a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 11 Feb 2025 21:53:57 +0900 Subject: [PATCH 334/348] fix: simplify validation step condition in NetworkTrainer --- train_network.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/train_network.py b/train_network.py index 8bfb19258..99c58f49f 100644 --- a/train_network.py +++ b/train_network.py @@ -1414,12 +1414,9 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen ) accelerator.log(logs, step=global_step) - # VALIDATION PER STEP - should_validate_step = ( - args.validate_every_n_steps is not None - and global_step != 0 # Skip first step - and global_step % args.validate_every_n_steps == 0 - ) + # VALIDATION PER STEP: global_step is already incremented + # for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ... + should_validate_step = args.validate_every_n_steps is not None and global_step % args.validate_every_n_steps == 0 if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: optimizer_eval_fn() accelerator.unwrap_model(network).eval() From ab88b431b0c903f7a60ae59e22fbb8a7cf9d78a1 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 14 Feb 2025 11:14:38 -0500 Subject: [PATCH 335/348] Fix validation epoch divergence --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index c3879531d..b5f92e06b 100644 --- a/train_network.py +++ b/train_network.py @@ -1498,7 +1498,7 @@ def remove_model(old_ckpt_name): if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average - loss_validation_divergence = val_step_loss_recorder.moving_average - avr_loss + loss_validation_divergence = val_epoch_loss_recorder.moving_average - avr_loss logs = { "loss/validation/epoch_average": avr_loss, "loss/validation/epoch_divergence": loss_validation_divergence, From 4671e237781dcfe9a16e90f5343afd57586a1df6 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 16 Feb 2025 01:42:44 -0500 Subject: [PATCH 336/348] Fix validation epoch loss to check epoch average --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index b5f92e06b..674f1cb66 100644 --- a/train_network.py +++ b/train_network.py @@ -1498,7 +1498,7 @@ def remove_model(old_ckpt_name): if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average - loss_validation_divergence = val_epoch_loss_recorder.moving_average - avr_loss + loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average logs = { "loss/validation/epoch_average": avr_loss, "loss/validation/epoch_divergence": loss_validation_divergence, From 3c7496ae3f2736a8283a881f49698d3e8f3a4291 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 16 Feb 2025 22:18:14 -0500 Subject: [PATCH 337/348] Fix sizes for validation split --- library/train_util.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 37ed0a994..6c782ea1c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -148,10 +148,11 @@ def split_train_val( paths: List[str], + sizes: List[Optional[Tuple[int, int]]], is_training_dataset: bool, validation_split: float, validation_seed: int | None -) -> List[str]: +) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]: """ Split the dataset into train and validation @@ -172,10 +173,12 @@ def split_train_val( # Split the dataset between training and validation if is_training_dataset: # Training dataset we split to the first part - return paths[0:math.ceil(len(paths) * (1 - validation_split))] + split = math.ceil(len(paths) * (1 - validation_split)) + return paths[0:split], sizes[0:split] else: # Validation dataset we split to the second part - return paths[len(paths) - round(len(paths) * validation_split):] + split = len(paths) - round(len(paths) * validation_split) + return paths[split:], sizes[split:] class ImageInfo: @@ -1931,12 +1934,12 @@ def load_dreambooth_dir(subset: DreamBoothSubset): with open(info_cache_file, "r", encoding="utf-8") as f: metas = json.load(f) img_paths = list(metas.keys()) - sizes = [meta["resolution"] for meta in metas.values()] + sizes: List[Optional[Tuple[int, int]]] = [meta["resolution"] for meta in metas.values()] # we may need to check image size and existence of image files, but it takes time, so user should check it before training else: img_paths = glob_images(subset.image_dir, "*") - sizes = [None] * len(img_paths) + sizes: List[Optional[Tuple[int, int]]] = [None] * len(img_paths) # new caching: get image size from cache files strategy = LatentsCachingStrategy.get_strategy() @@ -1969,7 +1972,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset): w, h = None, None if w is not None and h is not None: - sizes[i] = [w, h] + sizes[i] = (w, h) size_set_count += 1 logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}") @@ -1990,8 +1993,9 @@ def load_dreambooth_dir(subset: DreamBoothSubset): # Otherwise the img_paths remain as original img_paths and no split # required for training images dataset of regularization images else: - img_paths = split_train_val( + img_paths, sizes = split_train_val( img_paths, + sizes, self.is_training_dataset, self.validation_split, self.validation_seed From f3a010978c0e4b88c4839b3a81400b8973f52158 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 16 Feb 2025 22:28:34 -0500 Subject: [PATCH 338/348] Clear sizes for validation reg images to be consistent --- library/train_util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/library/train_util.py b/library/train_util.py index 6c782ea1c..39b4af856 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1990,6 +1990,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset): # Skip any validation dataset for regularization images if self.is_training_dataset is False: img_paths = [] + sizes = [] # Otherwise the img_paths remain as original img_paths and no split # required for training images dataset of regularization images else: From 9436b410617f22716eac64f7c604c8f53fa8c1a8 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 17 Feb 2025 14:28:41 -0500 Subject: [PATCH 339/348] Fix validation split and add test --- library/train_util.py | 8 ++++++-- tests/test_validation.py | 17 +++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) create mode 100644 tests/test_validation.py diff --git a/library/train_util.py b/library/train_util.py index 39b4af856..b23290663 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -161,15 +161,19 @@ def split_train_val( [0:80] = 80 training images [80:] = 20 validation images """ + dataset = list(zip(paths, sizes)) if validation_seed is not None: logging.info(f"Using validation seed: {validation_seed}") prevstate = random.getstate() random.seed(validation_seed) - random.shuffle(paths) + random.shuffle(dataset) random.setstate(prevstate) else: - random.shuffle(paths) + random.shuffle(dataset) + paths, sizes = zip(*dataset) + paths = list(paths) + sizes = list(sizes) # Split the dataset between training and validation if is_training_dataset: # Training dataset we split to the first part diff --git a/tests/test_validation.py b/tests/test_validation.py new file mode 100644 index 000000000..f80686d8c --- /dev/null +++ b/tests/test_validation.py @@ -0,0 +1,17 @@ +from library.train_util import split_train_val + + +def test_split_train_val(): + paths = ["path1", "path2", "path3", "path4", "path5", "path6", "path7"] + sizes = [(1, 1), (2, 2), None, (4, 4), (5, 5), (6, 6), None] + result_paths, result_sizes = split_train_val(paths, sizes, True, 0.2, 1234) + assert result_paths == ["path2", "path3", "path6", "path5", "path1", "path4"], result_paths + assert result_sizes == [(2, 2), None, (6, 6), (5, 5), (1, 1), (4, 4)], result_sizes + + result_paths, result_sizes = split_train_val(paths, sizes, False, 0.2, 1234) + assert result_paths == ["path7"], result_paths + assert result_sizes == [None], result_sizes + + +if __name__ == "__main__": + test_split_train_val() From 4a369961346ca153a370728247449978d8a33415 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 18 Feb 2025 22:05:08 +0900 Subject: [PATCH 340/348] modify log step calculation --- train_network.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/train_network.py b/train_network.py index 47c4bb56e..93558da45 100644 --- a/train_network.py +++ b/train_network.py @@ -1464,11 +1464,10 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen ) if is_tracking: - logs = { - "loss/validation/step_current": current_loss, - "val_step": (epoch * validation_total_steps) + val_ts_step, - } - accelerator.log(logs, step=global_step) + logs = {"loss/validation/step_current": current_loss} + accelerator.log( + logs, step=global_step + val_ts_step + ) # a bit weird to log with global_step + val_ts_step self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) val_ts_step += 1 @@ -1545,25 +1544,20 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen ) if is_tracking: - logs = { - "loss/validation/epoch_current": current_loss, - "epoch": epoch + 1, - "val_step": (epoch * validation_total_steps) + val_ts_step, - } - accelerator.log(logs, step=global_step) + logs = {"loss/validation/epoch_current": current_loss} + accelerator.log(logs, step=global_step + val_ts_step) self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) val_ts_step += 1 if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average - loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average + loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average logs = { "loss/validation/epoch_average": avr_loss, "loss/validation/epoch_divergence": loss_validation_divergence, - "epoch": epoch + 1, } - accelerator.log(logs, step=global_step) + accelerator.log(logs, step=epoch + 1) restore_rng_state(rng_states) args.min_timestep = original_args_min_timestep @@ -1574,8 +1568,8 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen # END OF EPOCH if is_tracking: - logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1} - accelerator.log(logs, step=global_step) + logs = {"loss/epoch_average": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) accelerator.wait_for_everyone() From efb2a128cd0d2c6340a21bf544e77853a20b3453 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 21 Feb 2025 22:07:35 +0900 Subject: [PATCH 341/348] fix wandb val logging --- library/train_util.py | 57 +++++++++++++++------------------ train_network.py | 73 ++++++++++++++++++++++++++++++++----------- 2 files changed, 80 insertions(+), 50 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 258701982..1f591c422 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -13,17 +13,7 @@ import shutil import time import typing -from typing import ( - Any, - Callable, - Dict, - List, - NamedTuple, - Optional, - Sequence, - Tuple, - Union -) +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState import glob import math @@ -146,12 +136,13 @@ TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz" + def split_train_val( - paths: List[str], + paths: List[str], sizes: List[Optional[Tuple[int, int]]], - is_training_dataset: bool, - validation_split: float, - validation_seed: int | None + is_training_dataset: bool, + validation_split: float, + validation_seed: int | None, ) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]: """ Split the dataset into train and validation @@ -1842,7 +1833,7 @@ def get_item_for_caching(self, bucket, bucket_batch_size, image_index): class DreamBoothDataset(BaseDataset): IMAGE_INFO_CACHE_FILE = "metadata_cache.json" - # The is_training_dataset defines the type of dataset, training or validation + # The is_training_dataset defines the type of dataset, training or validation # if is_training_dataset is True -> training dataset # if is_training_dataset is False -> validation dataset def __init__( @@ -1981,29 +1972,25 @@ def load_dreambooth_dir(subset: DreamBoothSubset): logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}") # We want to create a training and validation split. This should be improved in the future - # to allow a clearer distinction between training and validation. This can be seen as a + # to allow a clearer distinction between training and validation. This can be seen as a # short-term solution to limit what is necessary to implement validation datasets - # + # # We split the dataset for the subset based on if we are doing a validation split - # The self.is_training_dataset defines the type of dataset, training or validation + # The self.is_training_dataset defines the type of dataset, training or validation # if self.is_training_dataset is True -> training dataset # if self.is_training_dataset is False -> validation dataset if self.validation_split > 0.0: - # For regularization images we do not want to split this dataset. + # For regularization images we do not want to split this dataset. if subset.is_reg is True: # Skip any validation dataset for regularization images if self.is_training_dataset is False: img_paths = [] sizes = [] - # Otherwise the img_paths remain as original img_paths and no split + # Otherwise the img_paths remain as original img_paths and no split # required for training images dataset of regularization images else: img_paths, sizes = split_train_val( - img_paths, - sizes, - self.is_training_dataset, - self.validation_split, - self.validation_seed + img_paths, sizes, self.is_training_dataset, self.validation_split, self.validation_seed ) logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") @@ -2373,7 +2360,7 @@ def __init__( bucket_no_upscale: bool, debug_dataset: bool, validation_split: float, - validation_seed: Optional[int], + validation_seed: Optional[int], ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -2431,9 +2418,9 @@ def __init__( self.image_data = self.dreambooth_dataset_delegate.image_data self.batch_size = batch_size self.num_train_images = self.dreambooth_dataset_delegate.num_train_images - self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images + self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images self.validation_split = validation_split - self.validation_seed = validation_seed + self.validation_seed = validation_seed # assert all conditioning data exists missing_imgs = [] @@ -5952,7 +5939,9 @@ def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: tor return timesteps -def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]: +def get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents: torch.FloatTensor +) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]: # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) if args.noise_offset: @@ -6444,7 +6433,7 @@ def sample_image_inference( wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption -def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str): +def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str): """ Initialize experiment trackers with tracker specific behaviors """ @@ -6461,13 +6450,17 @@ def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tr ) if "wandb" in [tracker.name for tracker in accelerator.trackers]: - import wandb + import wandb + wandb_tracker = accelerator.get_tracker("wandb", unwrap=True) # Define specific metrics to handle validation and epochs "steps" wandb_tracker.define_metric("epoch", hidden=True) wandb_tracker.define_metric("val_step", hidden=True) + wandb_tracker.define_metric("global_step", hidden=True) + + # endregion diff --git a/train_network.py b/train_network.py index 93558da45..ab5483deb 100644 --- a/train_network.py +++ b/train_network.py @@ -119,6 +119,45 @@ def generate_step_logs( return logs + def step_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int): + self.accelerator_logging(accelerator, logs, global_step, global_step, epoch) + + def epoch_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int): + self.accelerator_logging(accelerator, logs, epoch, global_step, epoch) + + def val_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int, val_step: int): + self.accelerator_logging(accelerator, logs, global_step + val_step, global_step, epoch, val_step) + + def accelerator_logging( + self, accelerator: Accelerator, logs: dict, step_value: int, global_step: int, epoch: int, val_step: Optional[int] = None + ): + """ + step_value is for tensorboard, other values are for wandb + """ + tensorboard_tracker = None + wandb_tracker = None + other_trackers = [] + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + tensorboard_tracker = accelerator.get_tracker("tensorboard") + elif tracker.name == "wandb": + wandb_tracker = accelerator.get_tracker("wandb") + else: + other_trackers.append(accelerator.get_tracker(tracker.name)) + + if tensorboard_tracker is not None: + tensorboard_tracker.log(logs, step=step_value) + + if wandb_tracker is not None: + logs["global_step"] = global_step + logs["epoch"] = epoch + if val_step is not None: + logs["val_step"] = val_step + wandb_tracker.log(logs) + + for tracker in other_trackers: + tracker.log(logs, step=step_value) + def assert_extra_args( self, args, @@ -1412,7 +1451,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen logs = self.generate_step_logs( args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm ) - accelerator.log(logs, step=global_step) + self.step_logging(accelerator, logs, global_step, epoch + 1) # VALIDATION PER STEP: global_step is already incremented # for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ... @@ -1428,7 +1467,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen disable=not accelerator.is_local_main_process, desc="validation steps", ) - val_ts_step = 0 + val_timesteps_step = 0 for val_step, batch in enumerate(val_dataloader): if val_step >= validation_steps: break @@ -1457,20 +1496,18 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen ) current_loss = loss.detach().item() - val_step_loss_recorder.add(epoch=epoch, step=val_ts_step, loss=current_loss) + val_step_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss) val_progress_bar.update(1) val_progress_bar.set_postfix( {"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep} ) - if is_tracking: - logs = {"loss/validation/step_current": current_loss} - accelerator.log( - logs, step=global_step + val_ts_step - ) # a bit weird to log with global_step + val_ts_step + # if is_tracking: + # logs = {f"loss/validation/step_current_{timestep}": current_loss} + # self.val_logging(accelerator, logs, global_step, epoch + 1, val_step) self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) - val_ts_step += 1 + val_timesteps_step += 1 if is_tracking: loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average @@ -1478,7 +1515,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen "loss/validation/step_average": val_step_loss_recorder.moving_average, "loss/validation/step_divergence": loss_validation_divergence, } - accelerator.log(logs, step=global_step) + self.step_logging(accelerator, logs, global_step, epoch=epoch + 1) restore_rng_state(rng_states) args.min_timestep = original_args_min_timestep @@ -1507,7 +1544,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen desc="epoch validation steps", ) - val_ts_step = 0 + val_timesteps_step = 0 for val_step, batch in enumerate(val_dataloader): if val_step >= validation_steps: break @@ -1537,18 +1574,18 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen ) current_loss = loss.detach().item() - val_epoch_loss_recorder.add(epoch=epoch, step=val_ts_step, loss=current_loss) + val_epoch_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss) val_progress_bar.update(1) val_progress_bar.set_postfix( {"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep} ) - if is_tracking: - logs = {"loss/validation/epoch_current": current_loss} - accelerator.log(logs, step=global_step + val_ts_step) + # if is_tracking: + # logs = {f"loss/validation/epoch_current_{timestep}": current_loss} + # self.val_logging(accelerator, logs, global_step, epoch + 1, val_step) self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) - val_ts_step += 1 + val_timesteps_step += 1 if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average @@ -1557,7 +1594,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen "loss/validation/epoch_average": avr_loss, "loss/validation/epoch_divergence": loss_validation_divergence, } - accelerator.log(logs, step=epoch + 1) + self.epoch_logging(accelerator, logs, global_step, epoch + 1) restore_rng_state(rng_states) args.min_timestep = original_args_min_timestep @@ -1569,7 +1606,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen # END OF EPOCH if is_tracking: logs = {"loss/epoch_average": loss_recorder.moving_average} - accelerator.log(logs, step=epoch + 1) + self.epoch_logging(accelerator, logs, global_step, epoch + 1) accelerator.wait_for_everyone() From f4a004786500d80e1b47728d216aed9d76869a9d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 26 Feb 2025 20:50:44 +0900 Subject: [PATCH 342/348] feat: support metadata loading in MemoryEfficientSafeOpen --- library/utils.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/library/utils.py b/library/utils.py index 07079c6d9..4df8bd328 100644 --- a/library/utils.py +++ b/library/utils.py @@ -261,11 +261,10 @@ def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]: class MemoryEfficientSafeOpen: - # does not support metadata loading def __init__(self, filename): self.filename = filename - self.header, self.header_size = self._read_header() self.file = open(filename, "rb") + self.header, self.header_size = self._read_header() def __enter__(self): return self @@ -276,6 +275,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): def keys(self): return [k for k in self.header.keys() if k != "__metadata__"] + def metadata(self) -> Dict[str, str]: + return self.header.get("__metadata__", {}) + def get_tensor(self, key): if key not in self.header: raise KeyError(f"Tensor '{key}' not found in the file") @@ -293,10 +295,9 @@ def get_tensor(self, key): return self._deserialize_tensor(tensor_bytes, metadata) def _read_header(self): - with open(self.filename, "rb") as f: - header_size = struct.unpack(" Date: Wed, 26 Feb 2025 20:50:58 +0900 Subject: [PATCH 343/348] feat: add script to merge multiple safetensors files into a single file for SD3 --- tools/merge_sd3_safetensors.py | 139 +++++++++++++++++++++++++++++++++ 1 file changed, 139 insertions(+) create mode 100644 tools/merge_sd3_safetensors.py diff --git a/tools/merge_sd3_safetensors.py b/tools/merge_sd3_safetensors.py new file mode 100644 index 000000000..bef7c9b90 --- /dev/null +++ b/tools/merge_sd3_safetensors.py @@ -0,0 +1,139 @@ +import argparse +import os +import gc +from typing import Dict, Optional, Union +import torch +from safetensors.torch import safe_open + +from library.utils import setup_logging +from library.utils import load_safetensors, mem_eff_save_file + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def merge_safetensors( + dit_path: str, + vae_path: Optional[str] = None, + clip_l_path: Optional[str] = None, + clip_g_path: Optional[str] = None, + t5xxl_path: Optional[str] = None, + output_path: str = "merged_model.safetensors", + device: str = "cpu", +): + """ + Merge multiple safetensors files into a single file + + Args: + dit_path: Path to the DiT/MMDiT model + vae_path: Path to the VAE model + clip_l_path: Path to the CLIP-L model + clip_g_path: Path to the CLIP-G model + t5xxl_path: Path to the T5-XXL model + output_path: Path to save the merged model + device: Device to load tensors to + """ + logger.info("Starting to merge safetensors files...") + + # 1. Get DiT metadata if available + metadata = None + try: + with safe_open(dit_path, framework="pt") as f: + metadata = f.metadata() # may be None + if metadata: + logger.info(f"Found metadata in DiT model: {metadata}") + except Exception as e: + logger.warning(f"Failed to read metadata from DiT model: {e}") + + # 2. Create empty merged state dict + merged_state_dict = {} + + # 3. Load and merge each model with memory management + + # DiT/MMDiT - prefix: model.diffusion_model. + logger.info(f"Loading DiT model from {dit_path}") + dit_state_dict = load_safetensors(dit_path, device=device, disable_mmap=True) + logger.info(f"Adding DiT model with {len(dit_state_dict)} keys") + for key, value in dit_state_dict.items(): + merged_state_dict[f"model.diffusion_model.{key}"] = value + # Free memory + del dit_state_dict + gc.collect() + + # VAE - prefix: first_stage_model. + if vae_path: + logger.info(f"Loading VAE model from {vae_path}") + vae_state_dict = load_safetensors(vae_path, device=device, disable_mmap=True) + logger.info(f"Adding VAE model with {len(vae_state_dict)} keys") + for key, value in vae_state_dict.items(): + merged_state_dict[f"first_stage_model.{key}"] = value + # Free memory + del vae_state_dict + gc.collect() + + # CLIP-L - prefix: text_encoders.clip_l. + if clip_l_path: + logger.info(f"Loading CLIP-L model from {clip_l_path}") + clip_l_state_dict = load_safetensors(clip_l_path, device=device, disable_mmap=True) + logger.info(f"Adding CLIP-L model with {len(clip_l_state_dict)} keys") + for key, value in clip_l_state_dict.items(): + merged_state_dict[f"text_encoders.clip_l.{key}"] = value + # Free memory + del clip_l_state_dict + gc.collect() + + # CLIP-G - prefix: text_encoders.clip_g. + if clip_g_path: + logger.info(f"Loading CLIP-G model from {clip_g_path}") + clip_g_state_dict = load_safetensors(clip_g_path, device=device, disable_mmap=True) + logger.info(f"Adding CLIP-G model with {len(clip_g_state_dict)} keys") + for key, value in clip_g_state_dict.items(): + merged_state_dict[f"text_encoders.clip_g.{key}"] = value + # Free memory + del clip_g_state_dict + gc.collect() + + # T5-XXL - prefix: text_encoders.t5xxl. + if t5xxl_path: + logger.info(f"Loading T5-XXL model from {t5xxl_path}") + t5xxl_state_dict = load_safetensors(t5xxl_path, device=device, disable_mmap=True) + logger.info(f"Adding T5-XXL model with {len(t5xxl_state_dict)} keys") + for key, value in t5xxl_state_dict.items(): + merged_state_dict[f"text_encoders.t5xxl.{key}"] = value + # Free memory + del t5xxl_state_dict + gc.collect() + + # 4. Save merged state dict + logger.info(f"Saving merged model to {output_path} with {len(merged_state_dict)} keys total") + mem_eff_save_file(merged_state_dict, output_path, metadata) + logger.info("Successfully merged safetensors files") + + +def main(): + parser = argparse.ArgumentParser(description="Merge Stable Diffusion 3.5 model components into a single safetensors file") + parser.add_argument("--dit", required=True, help="Path to the DiT/MMDiT model") + parser.add_argument("--vae", help="Path to the VAE model") + parser.add_argument("--clip_l", help="Path to the CLIP-L model") + parser.add_argument("--clip_g", help="Path to the CLIP-G model") + parser.add_argument("--t5xxl", help="Path to the T5-XXL model") + parser.add_argument("--output", default="merged_model.safetensors", help="Path to save the merged model") + parser.add_argument("--device", default="cpu", help="Device to load tensors to") + + args = parser.parse_args() + + merge_safetensors( + dit_path=args.dit, + vae_path=args.vae, + clip_l_path=args.clip_l, + clip_g_path=args.clip_g, + t5xxl_path=args.t5xxl, + output_path=args.output, + device=args.device, + ) + + +if __name__ == "__main__": + main() From ae409e83c939f2c4a997cfb1679bd7cd364baf7e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 26 Feb 2025 20:56:32 +0900 Subject: [PATCH 344/348] fix: FLUX/SD3 network training not working without caching latents closes #1954 --- flux_train_network.py | 11 ++++++++--- sd3_train_network.py | 11 ++++++++--- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index ae4b62f5c..26503df1f 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -36,7 +36,12 @@ def __init__(self): self.is_schnell: Optional[bool] = None self.is_swapping_blocks: bool = False - def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): + def assert_extra_args( + self, + args, + train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], + val_dataset_group: Optional[train_util.DatasetGroup], + ): super().assert_extra_args(args, train_dataset_group, val_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) @@ -323,7 +328,7 @@ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) return noise_scheduler - def encode_images_to_latents(self, args, accelerator, vae, images): + def encode_images_to_latents(self, args, vae, images): return vae.encode(images) def shift_scale_latents(self, args, latents): @@ -341,7 +346,7 @@ def get_noise_pred_and_target( network, weight_dtype, train_unet, - is_train=True + is_train=True, ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) diff --git a/sd3_train_network.py b/sd3_train_network.py index 2f4579492..9438bc7bc 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -26,7 +26,12 @@ def __init__(self): super().__init__() self.sample_prompts_te_outputs = None - def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): + def assert_extra_args( + self, + args, + train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], + val_dataset_group: Optional[train_util.DatasetGroup], + ): # super().assert_extra_args(args, train_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) @@ -299,7 +304,7 @@ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.training_shift) return noise_scheduler - def encode_images_to_latents(self, args, accelerator, vae, images): + def encode_images_to_latents(self, args, vae, images): return vae.encode(images) def shift_scale_latents(self, args, latents): @@ -317,7 +322,7 @@ def get_noise_pred_and_target( network, weight_dtype, train_unet, - is_train=True + is_train=True, ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) From 3d79239be4b20d67faed67c47f693396342e3af4 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 26 Feb 2025 21:21:04 +0900 Subject: [PATCH 345/348] docs: update README to include recent improvements in validation loss calculation --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index 4bbd7617e..3c6993075 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,11 @@ The command to install PyTorch is as follows: ### Recent Updates +Feb 26, 2025: + +- Improve the validation loss calculation in `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py`. PR [#1903](https://github.com/kohya-ss/sd-scripts/pull/1903) + - The validation loss uses the fixed timestep sampling and the fixed random seed. This is to ensure that the validation loss is not fluctuated by the random values. + Jan 25, 2025: - `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py` now support validation loss. PR [#1864](https://github.com/kohya-ss/sd-scripts/pull/1864) Thank you to rockerBOO! From 734333d0c9eec3f20582c9c16f6d148cb1ec2596 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 28 Feb 2025 23:52:29 +0900 Subject: [PATCH 346/348] feat: enhance merging logic for safetensors models to handle key prefixes correctly --- tools/merge_sd3_safetensors.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/tools/merge_sd3_safetensors.py b/tools/merge_sd3_safetensors.py index bef7c9b90..960cf6e77 100644 --- a/tools/merge_sd3_safetensors.py +++ b/tools/merge_sd3_safetensors.py @@ -53,22 +53,30 @@ def merge_safetensors( # 3. Load and merge each model with memory management # DiT/MMDiT - prefix: model.diffusion_model. + # This state dict may have VAE keys. logger.info(f"Loading DiT model from {dit_path}") dit_state_dict = load_safetensors(dit_path, device=device, disable_mmap=True) logger.info(f"Adding DiT model with {len(dit_state_dict)} keys") for key, value in dit_state_dict.items(): - merged_state_dict[f"model.diffusion_model.{key}"] = value + if key.startswith("model.diffusion_model.") or key.startswith("first_stage_model."): + merged_state_dict[key] = value + else: + merged_state_dict[f"model.diffusion_model.{key}"] = value # Free memory del dit_state_dict gc.collect() # VAE - prefix: first_stage_model. + # May be omitted if VAE is already included in DiT model. if vae_path: logger.info(f"Loading VAE model from {vae_path}") vae_state_dict = load_safetensors(vae_path, device=device, disable_mmap=True) logger.info(f"Adding VAE model with {len(vae_state_dict)} keys") for key, value in vae_state_dict.items(): - merged_state_dict[f"first_stage_model.{key}"] = value + if key.startswith("first_stage_model."): + merged_state_dict[key] = value + else: + merged_state_dict[f"first_stage_model.{key}"] = value # Free memory del vae_state_dict gc.collect() @@ -79,7 +87,10 @@ def merge_safetensors( clip_l_state_dict = load_safetensors(clip_l_path, device=device, disable_mmap=True) logger.info(f"Adding CLIP-L model with {len(clip_l_state_dict)} keys") for key, value in clip_l_state_dict.items(): - merged_state_dict[f"text_encoders.clip_l.{key}"] = value + if key.startswith("text_encoders.clip_l.transformer."): + merged_state_dict[key] = value + else: + merged_state_dict[f"text_encoders.clip_l.transformer.{key}"] = value # Free memory del clip_l_state_dict gc.collect() @@ -90,7 +101,10 @@ def merge_safetensors( clip_g_state_dict = load_safetensors(clip_g_path, device=device, disable_mmap=True) logger.info(f"Adding CLIP-G model with {len(clip_g_state_dict)} keys") for key, value in clip_g_state_dict.items(): - merged_state_dict[f"text_encoders.clip_g.{key}"] = value + if key.startswith("text_encoders.clip_g.transformer."): + merged_state_dict[key] = value + else: + merged_state_dict[f"text_encoders.clip_g.transformer.{key}"] = value # Free memory del clip_g_state_dict gc.collect() @@ -101,7 +115,10 @@ def merge_safetensors( t5xxl_state_dict = load_safetensors(t5xxl_path, device=device, disable_mmap=True) logger.info(f"Adding T5-XXL model with {len(t5xxl_state_dict)} keys") for key, value in t5xxl_state_dict.items(): - merged_state_dict[f"text_encoders.t5xxl.{key}"] = value + if key.startswith("text_encoders.t5xxl.transformer."): + merged_state_dict[key] = value + else: + merged_state_dict[f"text_encoders.t5xxl.transformer.{key}"] = value # Free memory del t5xxl_state_dict gc.collect() @@ -115,7 +132,7 @@ def merge_safetensors( def main(): parser = argparse.ArgumentParser(description="Merge Stable Diffusion 3.5 model components into a single safetensors file") parser.add_argument("--dit", required=True, help="Path to the DiT/MMDiT model") - parser.add_argument("--vae", help="Path to the VAE model") + parser.add_argument("--vae", help="Path to the VAE model. May be omitted if VAE is included in DiT model") parser.add_argument("--clip_l", help="Path to the CLIP-L model") parser.add_argument("--clip_g", help="Path to the CLIP-G model") parser.add_argument("--t5xxl", help="Path to the T5-XXL model") From ba5251168a91f608de9fe9e365a2f889e4bb6cf8 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 1 Mar 2025 10:31:39 +0900 Subject: [PATCH 347/348] fix: save tensors as is dtype, add save_precision option --- tools/merge_sd3_safetensors.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/tools/merge_sd3_safetensors.py b/tools/merge_sd3_safetensors.py index 960cf6e77..6bc1003ec 100644 --- a/tools/merge_sd3_safetensors.py +++ b/tools/merge_sd3_safetensors.py @@ -6,7 +6,7 @@ from safetensors.torch import safe_open from library.utils import setup_logging -from library.utils import load_safetensors, mem_eff_save_file +from library.utils import load_safetensors, mem_eff_save_file, str_to_dtype setup_logging() import logging @@ -22,6 +22,7 @@ def merge_safetensors( t5xxl_path: Optional[str] = None, output_path: str = "merged_model.safetensors", device: str = "cpu", + save_precision: Optional[str] = None, ): """ Merge multiple safetensors files into a single file @@ -34,9 +35,16 @@ def merge_safetensors( t5xxl_path: Path to the T5-XXL model output_path: Path to save the merged model device: Device to load tensors to + save_precision: Target dtype for model weights (e.g. 'fp16', 'bf16') """ logger.info("Starting to merge safetensors files...") + # Convert save_precision string to torch dtype if specified + if save_precision: + target_dtype = str_to_dtype(save_precision) + else: + target_dtype = None + # 1. Get DiT metadata if available metadata = None try: @@ -55,7 +63,7 @@ def merge_safetensors( # DiT/MMDiT - prefix: model.diffusion_model. # This state dict may have VAE keys. logger.info(f"Loading DiT model from {dit_path}") - dit_state_dict = load_safetensors(dit_path, device=device, disable_mmap=True) + dit_state_dict = load_safetensors(dit_path, device=device, disable_mmap=True, dtype=target_dtype) logger.info(f"Adding DiT model with {len(dit_state_dict)} keys") for key, value in dit_state_dict.items(): if key.startswith("model.diffusion_model.") or key.startswith("first_stage_model."): @@ -70,7 +78,7 @@ def merge_safetensors( # May be omitted if VAE is already included in DiT model. if vae_path: logger.info(f"Loading VAE model from {vae_path}") - vae_state_dict = load_safetensors(vae_path, device=device, disable_mmap=True) + vae_state_dict = load_safetensors(vae_path, device=device, disable_mmap=True, dtype=target_dtype) logger.info(f"Adding VAE model with {len(vae_state_dict)} keys") for key, value in vae_state_dict.items(): if key.startswith("first_stage_model."): @@ -84,7 +92,7 @@ def merge_safetensors( # CLIP-L - prefix: text_encoders.clip_l. if clip_l_path: logger.info(f"Loading CLIP-L model from {clip_l_path}") - clip_l_state_dict = load_safetensors(clip_l_path, device=device, disable_mmap=True) + clip_l_state_dict = load_safetensors(clip_l_path, device=device, disable_mmap=True, dtype=target_dtype) logger.info(f"Adding CLIP-L model with {len(clip_l_state_dict)} keys") for key, value in clip_l_state_dict.items(): if key.startswith("text_encoders.clip_l.transformer."): @@ -98,7 +106,7 @@ def merge_safetensors( # CLIP-G - prefix: text_encoders.clip_g. if clip_g_path: logger.info(f"Loading CLIP-G model from {clip_g_path}") - clip_g_state_dict = load_safetensors(clip_g_path, device=device, disable_mmap=True) + clip_g_state_dict = load_safetensors(clip_g_path, device=device, disable_mmap=True, dtype=target_dtype) logger.info(f"Adding CLIP-G model with {len(clip_g_state_dict)} keys") for key, value in clip_g_state_dict.items(): if key.startswith("text_encoders.clip_g.transformer."): @@ -112,7 +120,7 @@ def merge_safetensors( # T5-XXL - prefix: text_encoders.t5xxl. if t5xxl_path: logger.info(f"Loading T5-XXL model from {t5xxl_path}") - t5xxl_state_dict = load_safetensors(t5xxl_path, device=device, disable_mmap=True) + t5xxl_state_dict = load_safetensors(t5xxl_path, device=device, disable_mmap=True, dtype=target_dtype) logger.info(f"Adding T5-XXL model with {len(t5xxl_state_dict)} keys") for key, value in t5xxl_state_dict.items(): if key.startswith("text_encoders.t5xxl.transformer."): @@ -138,6 +146,7 @@ def main(): parser.add_argument("--t5xxl", help="Path to the T5-XXL model") parser.add_argument("--output", default="merged_model.safetensors", help="Path to save the merged model") parser.add_argument("--device", default="cpu", help="Device to load tensors to") + parser.add_argument("--save_precision", type=str, help="Precision to save the model in (e.g., 'fp16', 'bf16', 'float16', etc.)") args = parser.parse_args() @@ -149,6 +158,7 @@ def main(): t5xxl_path=args.t5xxl, output_path=args.output, device=args.device, + save_precision=args.save_precision, ) From aa2bde7ece17be16083acfe9645bb4e21718fb2c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 5 Mar 2025 23:24:52 +0900 Subject: [PATCH 348/348] docs: add utility script for merging SD3 weights into a single .safetensors file --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 3c6993075..426eaed82 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,10 @@ The command to install PyTorch is as follows: ### Recent Updates +Mar 6, 2025: + +- Added a utility script to merge the weights of SD3's DiT, VAE (optional), CLIP-L, CLIP-G, and T5XXL into a single .safetensors file. Run `tools/merge_sd3_safetensors.py`. See `--help` for usage. PR [#1960](https://github.com/kohya-ss/sd-scripts/pull/1960) + Feb 26, 2025: - Improve the validation loss calculation in `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py`. PR [#1903](https://github.com/kohya-ss/sd-scripts/pull/1903)