From 90b18795fce516cb00735dc43a6ee76ecae8ec83 Mon Sep 17 00:00:00 2001
From: kabachuha
Date: Sun, 7 Apr 2024 07:54:21 +0300
Subject: [PATCH] Add option to use Scheduled Huber Loss in all training
 pipelines to improve resilience to data corruption (#1228)

* add huber loss and huber_c compute to train_util

* add reduction modes

* add huber_c retrieval from timestep getter

* move get timesteps and huber to own function

* add conditional loss to all training scripts

* add cond loss to train network

* add (scheduled) huber_loss to args

* fix duplicated timestep retrieval

* PHL (pseudo-Huber loss) schedule should depend on the noise scheduler's num_train_timesteps

* add a *2 multiplier to the Huber loss because of its 1/2 a^2 behavior near zero

The Taylor expansion of huber_c * (sqrt(a^2 + huber_c^2) - huber_c) near zero
gives 1/2 a^2, which differs from the a^2 of the standard MSE loss. The *2
factor scales the two losses better against one another.

* add option for smooth L1 (Huber / delta)

* unify huber scheduling

* add snr huber scheduler

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
---
 fine_tune.py                         |  6 +--
 library/train_util.py                | 79 ++++++++++++++++++++++++++--
 sdxl_train.py                        |  6 +--
 sdxl_train_control_net_lllite.py     |  4 +-
 sdxl_train_control_net_lllite_old.py |  4 +-
 train_controlnet.py                  | 11 ++--
 train_db.py                          |  4 +-
 train_network.py                     |  4 +-
 train_textual_inversion.py           |  4 +-
 train_textual_inversion_XTI.py       |  4 +-
 10 files changed, 96 insertions(+), 30 deletions(-)

diff --git a/fine_tune.py b/fine_tune.py
index 3c4a5a26b..c7e6bbd2e 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -354,7 +354,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
 
                 # Predict the noise residual
                 with accelerator.autocast():
@@ -368,7 +368,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
                 if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
                     # do not mean over batch dimension for snr weight or scale v-pred loss
-                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
                     loss = loss.mean([1, 2, 3])
 
                     if args.min_snr_gamma:
@@ -380,7 +380,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
                     loss = loss.mean()  # mean over batch dimension
                 else:
-                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
+                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
 
                 accelerator.backward(loss)
                 if accelerator.sync_gradients and args.max_grad_norm != 0.0:
diff --git a/library/train_util.py b/library/train_util.py
index c13bb68ee..90e6818ad 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3236,6 +3236,26 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         default=None,
         help="set maximum time step for U-Net training (1~1000, default is 1000) / U-Net学習時のtime stepの最大値を設定する(1~1000で指定、省略時はデフォルト値(1000))",
     )
+    parser.add_argument(
+        "--loss_type",
+        type=str,
+        default="l2",
+        choices=["l2", "huber", "smooth_l1"],
+        help="The type of loss function to use: plain L2 (MSE), Huber, or smooth L1",
+    )
+    parser.add_argument(
+        "--huber_schedule",
+        type=str,
+        default="exponential",
+        choices=["constant", "exponential", "snr"],
+        help="The scheduling method for the Huber loss parameter huber_c over timesteps: constant, exponential, or SNR-based. Only used when loss_type is huber or smooth_l1",
+    )
+    parser.add_argument(
+        "--huber_c",
+        type=float,
+        default=0.1,
+        help="The Huber loss parameter. Only used if one of the Huber loss modes (huber or smooth_l1) is selected with loss_type.",
+    )
 
     parser.add_argument(
         "--lowram",
@@ -4842,6 +4862,38 @@ def save_sd_model_on_train_end_common(
     if args.huggingface_repo_id is not None:
         huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
+def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):
+
+    # TODO: if a Huber loss is selected, a single timestep is currently drawn and shared
+    # across the whole batch; a smarter per-sample approach may be added in the future
+
+    if args.loss_type == 'huber' or args.loss_type == 'smooth_l1':
+        timesteps = torch.randint(
+            min_timestep, max_timestep, (1,), device='cpu'
+        )
+        timestep = timesteps.item()
+
+        if args.huber_schedule == "exponential":
+            alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
+            huber_c = math.exp(-alpha * timestep)
+        elif args.huber_schedule == "snr":
+            alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
+            sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
+            huber_c = (1 - args.huber_c) / (1 + sigmas)**2 + args.huber_c
+        elif args.huber_schedule == "constant":
+            huber_c = args.huber_c
+        else:
+            raise NotImplementedError(f'Unknown Huber loss schedule {args.huber_schedule}!')
+
+        timesteps = timesteps.repeat(b_size).to(device)
+    elif args.loss_type == 'l2':
+        timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
+        huber_c = 1  # may be anything, as it's not used
+    else:
+        raise NotImplementedError(f'Unknown loss type {args.loss_type}')
+    timesteps = timesteps.long()
+
+    return timesteps, huber_c
 
 
 def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
     # Sample noise that we'll add to the latents
@@ -4862,8 +4914,7 @@
     min_timestep = 0 if args.min_timestep is None else args.min_timestep
     max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
 
-    timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=latents.device)
-    timesteps = timesteps.long()
+    timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device)
 
     # Add noise to the latents according to the noise magnitude at each timestep
     # (this is the forward diffusion process)
@@ -4876,8 +4927,28 @@
     else:
         noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
-    return noise, noisy_latents, timesteps
-
+    return noise, noisy_latents, timesteps, huber_c
+
+# NOTE: when using a scheduled Huber loss, huber_c must already have been computed from the timestep (see get_timesteps_and_huber_c)
+def conditional_loss(model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1):
+
+    if loss_type == 'l2':
+        loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
+    elif loss_type == 'huber':
+        loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+        if reduction == "mean":
+            loss = torch.mean(loss)
+        elif reduction == "sum":
+            loss = torch.sum(loss)
+    elif loss_type == 'smooth_l1':
+        loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+        if reduction == "mean":
+            loss = torch.mean(loss)
+        elif reduction == "sum":
+            loss = torch.sum(loss)
+    else:
+        raise NotImplementedError(f'Unsupported Loss Type {loss_type}')
+    return loss
 
 def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True):
     names = []
diff --git a/sdxl_train.py b/sdxl_train.py
index f6d277494..46d7860be 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -582,7 +582,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
 
                 noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype
 
@@ -600,7 +600,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
                     or args.masked_loss
                 ):
                     # do not mean over batch dimension for snr weight or scale v-pred loss
-                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
                     if args.masked_loss:
                         loss = apply_masked_loss(loss, batch)
                     loss = loss.mean([1, 2, 3])
@@ -616,7 +616,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
                     loss = loss.mean()  # mean over batch dimension
                 else:
-                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
+                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
 
                 accelerator.backward(loss)
                 if accelerator.sync_gradients and args.max_grad_norm != 0.0:
diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index e880b57de..f89c3628f 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -439,7 +439,7 @@ def remove_model(old_ckpt_name):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
 
                 noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype
 
@@ -458,7 +458,7 @@ def remove_model(old_ckpt_name):
                 else:
                     target = noise
 
-                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
                 loss = loss.mean([1, 2, 3])
 
                 loss_weights = batch["loss_weights"]  # 各sampleごとのweight
diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py
index 0ea64b824..e85e978c1 100644
--- a/sdxl_train_control_net_lllite_old.py
+++ b/sdxl_train_control_net_lllite_old.py
@@ -406,7 +406,7 @@ def remove_model(old_ckpt_name):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
 
                 noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype
 
@@ -426,7 +426,7 @@ def remove_model(old_ckpt_name):
                 else:
                     target = noise
 
-                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
                 loss = loss.mean([1, 2, 3])
 
                 loss_weights = batch["loss_weights"]  # 各sampleごとのweight
diff --git a/train_controlnet.py b/train_controlnet.py
index 90cac0410..f4c94e8d9 100644
--- a/train_controlnet.py
+++ b/train_controlnet.py
@@ -420,13 +420,8 @@ def remove_model(old_ckpt_name):
                 )
 
                 # Sample a random timestep for each image
-                timesteps = torch.randint(
-                    0,
-                    noise_scheduler.config.num_train_timesteps,
-                    (b_size,),
-                    device=latents.device,
-                )
-                timesteps = timesteps.long()
+                timesteps, huber_c = train_util.get_timesteps_and_huber_c(args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device)
+
                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
@@ -457,7 +452,7 @@ def remove_model(old_ckpt_name):
                 else:
                     target = noise
 
-                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
                 loss = loss.mean([1, 2, 3])
 
                 loss_weights = batch["loss_weights"]  # 各sampleごとのweight
diff --git a/train_db.py b/train_db.py
index c3b7339f3..1de504ed8 100644
--- a/train_db.py
+++ b/train_db.py
@@ -346,7 +346,7 @@ def train(args):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
 
                 # Predict the noise residual
                 with accelerator.autocast():
@@ -358,7 +358,7 @@ def train(args):
                 else:
                     target = noise
 
-                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
                 if args.masked_loss:
                     loss = apply_masked_loss(loss, batch)
                 loss = loss.mean([1, 2, 3])
diff --git a/train_network.py b/train_network.py
index fcf4cd9b6..31d89276c 100644
--- a/train_network.py
+++ b/train_network.py
@@ -843,7 +843,7 @@ def remove_model(old_ckpt_name):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
                     args, noise_scheduler, latents
                 )
 
@@ -873,7 +873,7 @@ def remove_model(old_ckpt_name):
                 else:
                     target = noise
 
-                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
                 if args.masked_loss:
                     loss = apply_masked_loss(loss, batch)
                 loss = loss.mean([1, 2, 3])
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 02edf9525..10fce2677 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -572,7 +572,7 @@ def remove_model(old_ckpt_name):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
                     args, noise_scheduler, latents
                 )
 
@@ -588,7 +588,7 @@ def remove_model(old_ckpt_name):
                 else:
                     target = noise
 
-                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
                 if args.masked_loss:
                     loss = apply_masked_loss(loss, batch)
                 loss = loss.mean([1, 2, 3])
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index f0723f2a7..ddd03d532 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -461,7 +461,7 @@ def remove_model(old_ckpt_name):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
 
                 # Predict the noise residual
                 with accelerator.autocast():
@@ -473,7 +473,7 @@ def remove_model(old_ckpt_name):
                 else:
                     target = noise
 
-                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
                 if args.masked_loss:
                     loss = apply_masked_loss(loss, batch)
                 loss = loss.mean([1, 2, 3])
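
For reference, the scheduled pseudo-Huber loss that this patch routes through train_util.conditional_loss can be exercised in isolation. The sketch below assumes only PyTorch; the formula and the exponential huber_c schedule are taken from the diff above, while the helper names (pseudo_huber, exponential_huber_c) and the toy tensors are illustrative and not part of the patch.

import math

import torch


def pseudo_huber(model_pred: torch.Tensor, target: torch.Tensor, huber_c: float) -> torch.Tensor:
    # Same form as the huber branch of conditional_loss: 2 * c * (sqrt(a^2 + c^2) - c).
    # Near zero this behaves like a^2 (i.e. like MSE); for large residuals it grows
    # roughly like 2 * c * |a|, which damps the influence of corrupted samples.
    diff = model_pred - target
    return 2 * huber_c * (torch.sqrt(diff**2 + huber_c**2) - huber_c)


def exponential_huber_c(timestep: int, num_train_timesteps: int = 1000, huber_c_0: float = 0.1) -> float:
    # Exponential schedule as in get_timesteps_and_huber_c: decays from 1.0 at
    # timestep 0 towards huber_c_0 (the --huber_c argument) at the final timestep.
    alpha = -math.log(huber_c_0) / num_train_timesteps
    return math.exp(-alpha * timestep)


if __name__ == "__main__":
    pred = torch.randn(4, 4)
    target = pred + torch.tensor([0.01, 0.1, 1.0, 10.0])  # small to large residuals
    for t in (0, 500, 999):
        c = exponential_huber_c(t)
        print(f"t={t} huber_c={c:.4f} loss={pseudo_huber(pred, target, c).mean().item():.4f}")

With the default --loss_type l2 the training scripts keep their previous torch.nn.functional.mse_loss behavior; the huber and smooth_l1 modes only change how large residuals are weighted, with huber_c controlling where the quadratic regime gives way to the linear one.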