From a9b64ffba8efbb0991a094e38b1f5d5c56680caf Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 27 Feb 2024 21:43:55 +0900 Subject: [PATCH] support masked loss in sdxl_train ref #589 --- README.md | 4 +++- sdxl_train.py | 20 +++++++++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 9cc79cc09..354983c38 100644 --- a/README.md +++ b/README.md @@ -251,7 +251,9 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum ### Masked loss -`train_network.py` and `sdxl_train_network.py` now support the masked loss. `--masked_loss` option is added. +`train_network.py`, `sdxl_train_network.py`, and `sdxl_train.py` now support masked loss. The `--masked_loss` option is added. + +NOTE: `train_network.py` and `sdxl_train.py` have not been tested yet. ControlNet dataset is used to specify the mask. The mask images should be the RGB images. The pixel value 255 in R channel is treated as the mask (the loss is calculated only for the pixels with the mask), and 0 is treated as the non-mask. See details for the dataset specification in the [LLLite documentation](./docs/train_lllite_README.md#preparing-the-dataset). 
diff --git a/sdxl_train.py b/sdxl_train.py index e0df263d6..448a160f6 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -11,6 +11,7 @@ import torch from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() from accelerate.utils import set_seed @@ -124,7 +125,7 @@ def train(args): # データセットを準備する if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) if args.dataset_config is not None: logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) @@ -579,6 +580,16 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): ): # do not mean over batch dimension for snr weight or scale v-pred loss loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + + if args.masked_loss: + # mask image is -1 to 1. we need to convert it to 0 to 1 + mask_image = batch["conditioning_images"].to(dtype=weight_dtype)[:, 0].unsqueeze(1) # use R channel + + # resize to the same size as the loss + mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area") + mask_image = mask_image / 2 + 0.5 + loss = loss * mask_image + loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: @@ -780,6 +791,13 @@ def setup_parser() -> argparse.ArgumentParser: + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値", ) + # TODO common masked_loss argument + parser.add_argument( + "--masked_loss", + action="store_true", + help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要", + ) + return parser