diff --git a/README.md b/README.md index 0cecc5676..5282c1f69 100644 --- a/README.md +++ b/README.md @@ -150,14 +150,15 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Also, if there is an absolute path, the path may be exposed, so it is recommended to specify a relative path or write it in the configuration file. In such cases, an INFO log is displayed. - See [#1123](https://github.com/kohya-ss/sd-scripts/pull/1123) and PR [#1240](https://github.com/kohya-ss/sd-scripts/pull/1240) for details. - Colab seems to stop with log output. Try specifying `--console_log_simple` option in the training script to disable rich logging. -- Other improvements include the addition of masked loss, DeepSpeed support, dataset settings improvements, and image tagging improvements. See below for details. +- Other improvements include the addition of masked loss, scheduled Huber Loss, DeepSpeed support, dataset settings improvements, and image tagging improvements. See below for details. #### Training scripts - `train_network.py` and `sdxl_train_network.py` are modified to record some dataset settings in the metadata of the trained model (`caption_prefix`, `caption_suffix`, `keep_tokens_separator`, `secondary_separator`, `enable_wildcard`). - Fixed a bug that U-Net and Text Encoders are included in the state in `train_network.py` and `sdxl_train_network.py`. The saving and loading of the state are faster, the file size is smaller, and the memory usage when loading is reduced. - DeepSpeed is supported. PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) and [#1139](https://github.com/kohya-ss/sd-scripts/pull/1139) Thanks to BootsofLagrangian! See PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) for details. -- The masked loss is supported in each training script. PR [#1207](https://github.com/kohya-ss/sd-scripts/pull/1207) See [Masked loss](#masked-loss) for details. +- The masked loss is supported in each training script. PR [#1207](https://github.com/kohya-ss/sd-scripts/pull/1207) See [Masked loss](#about-masked-loss) for details. +- Scheduled Huber Loss has been introduced to each training scripts. PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) Thanks to kabachuha for the PR and cheald, drhead, and others for the discussion! See [Scheduled Huber Loss](#about-scheduled-huber-loss) for details. - The options `--noise_offset_random_strength` and `--ip_noise_gamma_random_strength` are added to each training script. These options can be used to vary the noise offset and ip noise gamma in the range of 0 to the specified value. PR [#1177](https://github.com/kohya-ss/sd-scripts/pull/1177) Thanks to KohakuBlueleaf! - The options `--save_state_on_train_end` are added to each training script. PR [#1168](https://github.com/kohya-ss/sd-scripts/pull/1168) Thanks to gesen2egee! - The options `--sample_every_n_epochs` and `--sample_every_n_steps` in each training script now display a warning and ignore them when a number less than or equal to `0` is specified. Thanks to S-Del for raising the issue. @@ -199,6 +200,23 @@ The feature is not fully tested, so there may be bugs. If you find any issues, p ControlNet dataset is used to specify the mask. The mask images should be the RGB images. The pixel value 255 in R channel is treated as the mask (the loss is calculated only for the pixels with the mask), and 0 is treated as the non-mask. The pixel values 0-255 are converted to 0-1 (i.e., the pixel value 128 is treated as the half weight of the loss). See details for the dataset specification in the [LLLite documentation](./docs/train_lllite_README.md#preparing-the-dataset). +#### About Scheduled Huber Loss + +Scheduled Huber Loss has been introduced to each training scripts. This is a method to improve robustness against outliers or anomalies (data corruption) in the training data. + +With the traditional MSE (L2) loss function, the impact of outliers could be significant, potentially leading to a degradation in the quality of generated images. On the other hand, while the Huber loss function can suppress the influence of outliers, it tends to compromise the reproduction of fine details in images. + +To address this, the proposed method employs a clever application of the Huber loss function. By scheduling the use of Huber loss in the early stages of training (when noise is high) and MSE in the later stages, it strikes a balance between outlier robustness and fine detail reproduction. + +Experimental results have confirmed that this method achieves higher accuracy on data containing outliers compared to pure Huber loss or MSE. The increase in computational cost is minimal. + +The newly added arguments loss_type, huber_schedule, and huber_c allow for the selection of the loss function type (Huber, smooth L1, MSE), scheduling method (exponential, constant, SNR), and Huber's parameter. This enables optimization based on the characteristics of the dataset. + +See PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) for details. + +- `loss_type`: Specify the loss function type. Choose `huber` for Huber loss, `smooth_l1` for smooth L1 loss, and `l2` for MSE loss. The default is `l2`, which is the same as before. +- `huber_schedule`: Specify the scheduling method. Choose `exponential`, `constant`, or `SNR`. The default is `exponential`. +- `huber_c`: Specify the Huber's parameter. The default is `0.1`. #### 主要な変更点 @@ -211,14 +229,15 @@ ControlNet dataset is used to specify the mask. The mask images should be the RG - また絶対パスの指定がある場合、そのパスが公開される可能性がありますので、相対パスを指定するか設定ファイルに記載することをお勧めします。このような場合は INFO ログが表示されます。 - 詳細は [#1123](https://github.com/kohya-ss/sd-scripts/pull/1123) および PR [#1240](https://github.com/kohya-ss/sd-scripts/pull/1240) をご覧ください。 - Colab での動作時、ログ出力で停止してしまうようです。学習スクリプトに `--console_log_simple` オプションを指定し、rich のロギングを無効してお試しください。 -- その他、マスクロス追加、DeepSpeed 対応、データセット設定の改善、画像タグ付けの改善などがあります。詳細は以下をご覧ください。 +- その他、マスクロス追加、Scheduled Huber Loss 追加、DeepSpeed 対応、データセット設定の改善、画像タグ付けの改善などがあります。詳細は以下をご覧ください。 #### 学習スクリプト - `train_network.py` および `sdxl_train_network.py` で、学習したモデルのメタデータに一部のデータセット設定が記録されるよう修正しました(`caption_prefix`、`caption_suffix`、`keep_tokens_separator`、`secondary_separator`、`enable_wildcard`)。 - `train_network.py` および `sdxl_train_network.py` で、state に U-Net および Text Encoder が含まれる不具合を修正しました。state の保存、読み込みが高速化され、ファイルサイズも小さくなり、また読み込み時のメモリ使用量も削減されます。 - DeepSpeed がサポートされました。PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) 、[#1139](https://github.com/kohya-ss/sd-scripts/pull/1139) BootsofLagrangian 氏に感謝します。詳細は PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) をご覧ください。 -- 各学習スクリプトでマスクロスをサポートしました。PR [#1207](https://github.com/kohya-ss/sd-scripts/pull/1207) 詳細は [Masked loss](#masked-loss) をご覧ください。 +- 各学習スクリプトでマスクロスをサポートしました。PR [#1207](https://github.com/kohya-ss/sd-scripts/pull/1207) 詳細は [マスクロスについて](#マスクロスについて) をご覧ください。 +- 各学習スクリプトに Scheduled Huber Loss を追加しました。PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) ご提案いただいた kabachuha 氏、および議論を深めてくださった cheald 氏、drhead 氏を始めとする諸氏に感謝します。詳細は [Scheduled Huber Loss について](#scheduled-huber-loss-について) をご覧ください。 - 各学習スクリプトに、noise offset、ip noise gammaを、それぞれ 0~指定した値の範囲で変動させるオプション `--noise_offset_random_strength` および `--ip_noise_gamma_random_strength` が追加されました。 PR [#1177](https://github.com/kohya-ss/sd-scripts/pull/1177) KohakuBlueleaf 氏に感謝します。 - 各学習スクリプトに、学習終了時に state を保存する `--save_state_on_train_end` オプションが追加されました。 PR [#1168](https://github.com/kohya-ss/sd-scripts/pull/1168) gesen2egee 氏に感謝します。 - 各学習スクリプトで `--sample_every_n_epochs` および `--sample_every_n_steps` オプションに `0` 以下の数値を指定した時、警告を表示するとともにそれらを無視するよう変更しました。問題提起していただいた S-Del 氏に感謝します。 @@ -262,6 +281,26 @@ ControlNet dataset is used to specify the mask. The mask images should be the RG Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates. 最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。 +#### Scheduled Huber Loss について + +各学習スクリプトに、学習データ中の異常値や外れ値(data corruption)への耐性を高めるための手法、Scheduled Huber Lossが導入されました。 + +従来のMSE(L2)損失関数では、異常値の影響を大きく受けてしまい、生成画像の品質低下を招く恐れがありました。一方、Huber損失関数は異常値の影響を抑えられますが、画像の細部再現性が損なわれがちでした。 + +この手法ではHuber損失関数の適用を工夫し、学習の初期段階(ノイズが大きい場合)ではHuber損失を、後期段階ではMSEを用いるようスケジューリングすることで、異常値耐性と細部再現性のバランスを取ります。 + +実験の結果では、この手法が純粋なHuber損失やMSEと比べ、異常値を含むデータでより高い精度を達成することが確認されています。また計算コストの増加はわずかです。 + +具体的には、新たに追加された引数loss_type、huber_schedule、huber_cで、損失関数の種類(Huber, smooth L1, MSE)とスケジューリング方法(exponential, constant, SNR)を選択できます。これによりデータセットに応じた最適化が可能になります。 + +詳細は PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) をご覧ください。 + +- `loss_type` : 損失関数の種類を指定します。`huber` で Huber損失、`smooth_l1` で smooth L1 損失、`l2` で MSE 損失を選択します。デフォルトは `l2` で、従来と同様です。 +- `huber_schedule` : スケジューリング方法を指定します。`exponential` で指数関数的、`constant` で一定、`snr` で信号対雑音比に基づくスケジューリングを選択します。デフォルトは `exponential` です。 +- `huber_c` : Huber損失のパラメータを指定します。デフォルトは `0.1` です。 + +PR 内でいくつかの比較が共有されています。この機能を試す場合、最初は `--loss_type smooth_l1 --huber_schedule snr --huber_c 0.1` などで試してみるとよいかもしれません。 + ## Additional Information ### Naming of LoRA diff --git a/library/train_util.py b/library/train_util.py index 90e6818ad..9ce129bd9 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3241,20 +3241,21 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: type=str, default="l2", choices=["l2", "huber", "smooth_l1"], - help="The type of loss to use and whether it's scheduled based on the timestep" + help="The type of loss function to use (L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L2、Huber、またはsmooth L1)、デフォルトはL2", ) parser.add_argument( "--huber_schedule", type=str, default="exponential", choices=["constant", "exponential", "snr"], - help="The type of loss to use and whether it's scheduled based on the timestep" + help="The scheduling method for Huber loss (constant, exponential, or SNR-based). Only used when loss_type is 'huber' or 'smooth_l1'. default is exponential" + + " / Huber損失のスケジューリング方法(constant、exponential、またはSNRベース)。loss_typeが'huber'または'smooth_l1'の場合に有効、デフォルトはexponential", ) parser.add_argument( "--huber_c", type=float, default=0.1, - help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type.", + help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", ) parser.add_argument( @@ -4862,39 +4863,39 @@ def save_sd_model_on_train_end_common( if args.huggingface_repo_id is not None: huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) + def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device): - #TODO: if a huber loss is selected, it will use constant timesteps for each batch + # TODO: if a huber loss is selected, it will use constant timesteps for each batch # as. In the future there may be a smarter way - if args.loss_type == 'huber' or args.loss_type == 'smooth_l1': - timesteps = torch.randint( - min_timestep, max_timestep, (1,), device='cpu' - ) + if args.loss_type == "huber" or args.loss_type == "smooth_l1": + timesteps = torch.randint(min_timestep, max_timestep, (1,), device="cpu") timestep = timesteps.item() if args.huber_schedule == "exponential": - alpha = - math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps + alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps huber_c = math.exp(-alpha * timestep) elif args.huber_schedule == "snr": alphas_cumprod = noise_scheduler.alphas_cumprod[timestep] sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 - huber_c = (1 - args.huber_c) / (1 + sigmas)**2 + args.huber_c + huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c elif args.huber_schedule == "constant": huber_c = args.huber_c else: - raise NotImplementedError(f'Unknown Huber loss schedule {args.huber_schedule}!') + raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") timesteps = timesteps.repeat(b_size).to(device) - elif args.loss_type == 'l2': + elif args.loss_type == "l2": timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device) - huber_c = 1 # may be anything, as it's not used + huber_c = 1 # may be anything, as it's not used else: - raise NotImplementedError(f'Unknown loss type {args.loss_type}') + raise NotImplementedError(f"Unknown loss type {args.loss_type}") timesteps = timesteps.long() return timesteps, huber_c + def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) @@ -4929,27 +4930,31 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): return noise, noisy_latents, timesteps, huber_c + # NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already -def conditional_loss(model_pred:torch.Tensor, target:torch.Tensor, reduction:str="mean", loss_type:str="l2", huber_c:float=0.1): - - if loss_type == 'l2': +def conditional_loss( + model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1 +): + + if loss_type == "l2": loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction) - elif loss_type == 'huber': + elif loss_type == "huber": loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": loss = torch.mean(loss) elif reduction == "sum": loss = torch.sum(loss) - elif loss_type == 'smooth_l1': + elif loss_type == "smooth_l1": loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": loss = torch.mean(loss) elif reduction == "sum": loss = torch.sum(loss) else: - raise NotImplementedError(f'Unsupported Loss Type {loss_type}') + raise NotImplementedError(f"Unsupported Loss Type {loss_type}") return loss + def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True): names = [] if including_unet: diff --git a/train_network.py b/train_network.py index 31d89276c..c99d37247 100644 --- a/train_network.py +++ b/train_network.py @@ -476,7 +476,7 @@ def save_model_hook(models, weights, output_dir): # pop weights of other models than network to save only network weights if accelerator.is_main_process: remove_indices = [] - for i,model in enumerate(models): + for i, model in enumerate(models): if not isinstance(model, type(accelerator.unwrap_model(network))): remove_indices.append(i) for i in reversed(remove_indices): @@ -569,6 +569,11 @@ def load_model_hook(models, input_dir): "ss_scale_weight_norms": args.scale_weight_norms, "ss_ip_noise_gamma": args.ip_noise_gamma, "ss_debiased_estimation": bool(args.debiased_estimation_loss), + "ss_noise_offset_random_strength": args.noise_offset_random_strength, + "ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength, + "ss_loss_type": args.loss_type, + "ss_huber_schedule": args.huber_schedule, + "ss_huber_c": args.huber_c, } if use_user_config: @@ -873,7 +878,9 @@ def remove_model(old_ckpt_name): else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) if args.masked_loss: loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3])