Skip to content

Commit

Permalink
support masked loss in sdxl_train ref #589
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Feb 27, 2024
1 parent 4a5546d commit a9b64ff
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,9 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum

### Masked loss

`train_network.py` and `sdxl_train_network.py` now support the masked loss. `--masked_loss` option is added.
`train_network.py`, `sdxl_train_network.py` and `sdxl_train.py` now support the masked loss. `--masked_loss` option is added.

NOTE: `train_network.py` and `sdxl_train.py` are not tested yet.

ControlNet dataset is used to specify the mask. The mask images should be the RGB images. The pixel value 255 in R channel is treated as the mask (the loss is calculated only for the pixels with the mask), and 0 is treated as the non-mask. See details for the dataset specification in the [LLLite documentation](./docs/train_lllite_README.md#preparing-the-dataset).

Expand Down
20 changes: 19 additions & 1 deletion sdxl_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import torch
from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()

from accelerate.utils import set_seed
Expand Down Expand Up @@ -124,7 +125,7 @@ def train(args):

# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
Expand Down Expand Up @@ -579,6 +580,16 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
):
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")

if args.masked_loss:
# mask image is -1 to 1. we need to convert it to 0 to 1
mask_image = batch["conditioning_images"].to(dtype=weight_dtype)[:, 0].unsqueeze(1) # use R channel

# resize to the same size as the loss
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
mask_image = mask_image / 2 + 0.5
loss = loss * mask_image

loss = loss.mean([1, 2, 3])

if args.min_snr_gamma:
Expand Down Expand Up @@ -780,6 +791,13 @@ def setup_parser() -> argparse.ArgumentParser:
+ f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値",
)

# TODO common masked_loss argument
parser.add_argument(
"--masked_loss",
action="store_true",
help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要",
)

return parser


Expand Down

0 comments on commit a9b64ff

Please sign in to comment.