diff --git a/fine_tune.py b/fine_tune.py
index a44868415..8ae1bb29e 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -20,7 +20,8 @@
     ConfigSanitizer,
     BlueprintGenerator,
 )
-
+import library.custom_train_functions as custom_train_functions
+from library.custom_train_functions import apply_snr_weight

 def train(args):
     train_util.verify_training_args(args)
@@ -309,6 +310,12 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

-                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
+                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                loss = loss.mean([1, 2, 3])  # per-sample loss, shape (batch,)
+
+                if args.min_snr_gamma:
+                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
+
+                loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし

                 accelerator.backward(loss)
                 if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                     params_to_clip = []
@@ -401,6 +408,8 @@ def setup_parser() -> argparse.ArgumentParser:
     train_util.add_sd_saving_arguments(parser)
     train_util.add_optimizer_arguments(parser)
     config_util.add_config_arguments(parser)
+    custom_train_functions.add_custom_train_arguments(parser)
+
     parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
     parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py
new file mode 100644
index 000000000..fd4f6156b
--- /dev/null
+++ b/library/custom_train_functions.py
@@ -0,0 +1,25 @@
+import torch
+import argparse
+
+
+def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
+    # SNR(t) = (alpha_t / sigma_t)^2, computed from the scheduler's cumulative alphas;
+    # moved to the timesteps' device so the indexing also works during GPU training
+    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)
+    alpha = torch.sqrt(alphas_cumprod)
+    sigma = torch.sqrt(1.0 - alphas_cumprod)
+    all_snr = (alpha / sigma) ** 2
+    snr = all_snr[timesteps]
+    # min(gamma / SNR(t), 1) from the Min-SNR paper caps the weight of high-SNR timesteps
+    gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
+    snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float()
+    return loss * snr_weight
+
+
+def add_custom_train_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--min_snr_gamma",
+        type=float,
+        default=None,
+        help="gamma for reducing the weight of high-loss timesteps; lower values have a stronger effect. 5 is recommended by the paper",
+    )
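For reference, apply_snr_weight implements the Min-SNR weighting strategy: each sample's loss is scaled by min(gamma / SNR(t), 1), so timesteps with little added noise (high SNR) stop dominating training. A minimal sketch of the helper's behavior, not part of the patch, assuming the diffusers DDPMScheduler configured exactly as in these training scripts:

import torch
from diffusers import DDPMScheduler
from library.custom_train_functions import apply_snr_weight

noise_scheduler = DDPMScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
loss = torch.ones(4)                          # per-sample loss, shape (batch,)
timesteps = torch.tensor([10, 50, 500, 990])  # low t = little noise = high SNR
print(apply_snr_weight(loss, timesteps, noise_scheduler, 5.0))
# early (high-SNR) timesteps are scaled far below 1; later (low-SNR) ones keep weight 1.0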
diff --git a/library/train_util.py b/library/train_util.py
index 131a263b0..6a5679d3a 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2001,7 +2001,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
     parser.add_argument(
         "--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み"
     )
-    
+

 def verify_training_args(args: argparse.Namespace):
     if args.v_parameterization and not args.v2:
diff --git a/train_db.py b/train_db.py
index 904fcc60a..f441d5d6b 100644
--- a/train_db.py
+++ b/train_db.py
@@ -22,7 +22,8 @@
     ConfigSanitizer,
     BlueprintGenerator,
 )
-
+import library.custom_train_functions as custom_train_functions
+from library.custom_train_functions import apply_snr_weight

 def train(args):
     train_util.verify_training_args(args)
@@ -296,6 +297,9 @@ def train(args):
                 loss_weights = batch["loss_weights"]  # 各sampleごとのweight
                 loss = loss * loss_weights

+                if args.min_snr_gamma:
+                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
+
                 loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし

                 accelerator.backward(loss)
@@ -395,6 +399,7 @@ def setup_parser() -> argparse.ArgumentParser:
     train_util.add_sd_saving_arguments(parser)
     train_util.add_optimizer_arguments(parser)
     config_util.add_config_arguments(parser)
+    custom_train_functions.add_custom_train_arguments(parser)

     parser.add_argument(
         "--no_token_padding",
diff --git a/train_network.py b/train_network.py
index 8d1502cae..20ad2c4d9 100644
--- a/train_network.py
+++ b/train_network.py
@@ -24,6 +24,9 @@
     ConfigSanitizer,
     BlueprintGenerator,
 )
+import library.custom_train_functions as custom_train_functions
+from library.custom_train_functions import apply_snr_weight
+

 # TODO 他のスクリプトと共通化する
 def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
@@ -492,7 +495,6 @@ def train(args):
     noise_scheduler = DDPMScheduler(
         beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
     )
-
     if accelerator.is_main_process:
         accelerator.init_trackers("network_train")
@@ -534,7 +536,6 @@ def train(args):
                 # Sample a random timestep for each image
                 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
                 timesteps = timesteps.long()
-
                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
@@ -554,6 +555,9 @@ def train(args):
                 loss_weights = batch["loss_weights"]  # 各sampleごとのweight
                 loss = loss * loss_weights
+
+                if args.min_snr_gamma:
+                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)

                 loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
@@ -658,6 +662,7 @@ def setup_parser() -> argparse.ArgumentParser:
     train_util.add_training_arguments(parser, True)
     train_util.add_optimizer_arguments(parser)
     config_util.add_config_arguments(parser)
+    custom_train_functions.add_custom_train_arguments(parser)

     parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
     parser.add_argument(
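Note the pattern shared by the hunks above: the SNR weight multiplies the loss after the spatial mean has reduced it to one value per sample, but before the final batch mean, since a scalar loss can no longer be re-weighted per timestep (this is also why the fine_tune.py hunk switches from reduction="mean" to reduction="none"). A self-contained sketch of that ordering, with a made-up weight vector standing in for min(gamma / SNR(t), 1):

import torch

noise_pred = torch.randn(4, 4, 64, 64)           # (batch, channels, height, width)
target = torch.randn(4, 4, 64, 64)
snr_weight = torch.tensor([0.1, 0.5, 1.0, 1.0])  # hypothetical min(gamma/SNR(t), 1) per sample

loss = torch.nn.functional.mse_loss(noise_pred, target, reduction="none")
loss = loss.mean([1, 2, 3])  # shape (batch,): one loss per sample
loss = loss * snr_weight     # per-timestep re-weighting is still possible here
loss = loss.mean()           # scalar, safe to hand to accelerator.backward()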
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index a4ef45e3b..681bc628c 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -18,6 +18,8 @@
     ConfigSanitizer,
     BlueprintGenerator,
 )
+import library.custom_train_functions as custom_train_functions
+from library.custom_train_functions import apply_snr_weight

 imagenet_templates_small = [
     "a photo of a {}",
@@ -383,6 +385,9 @@ def train(args):
                 loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                 loss = loss.mean([1, 2, 3])
+
+                if args.min_snr_gamma:
+                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)

                 loss_weights = batch["loss_weights"]  # 各sampleごとのweight
                 loss = loss * loss_weights
@@ -540,6 +545,7 @@ def setup_parser() -> argparse.ArgumentParser:
     train_util.add_training_arguments(parser, True)
     train_util.add_optimizer_arguments(parser)
     config_util.add_config_arguments(parser)
+    custom_train_functions.add_custom_train_arguments(parser)

     parser.add_argument(
         "--save_model_as",
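With the patch applied, all four entry points expose the same flag via add_custom_train_arguments, so enabling the weighting is one extra argument; gamma = 5 follows the paper's recommendation. For example (the usual training arguments are elided):

accelerate launch train_network.py --min_snr_gamma=5 <remaining training arguments>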