diff --git a/fine_tune.py b/fine_tune.py index 1acf478f4..ff33eb9c9 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -19,7 +19,8 @@ ConfigSanitizer, BlueprintGenerator, ) - +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_snr_weight def collate_fn(examples): return examples[0] @@ -304,6 +305,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") + if args.min_snr_gamma: + loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma) + accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: params_to_clip = [] @@ -396,6 +400,8 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する") parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py new file mode 100644 index 000000000..f60ec7436 --- /dev/null +++ b/library/custom_train_functions.py @@ -0,0 +1,17 @@ +import torch +import argparse + +def apply_snr_weight(loss, latents, noisy_latents, gamma): + sigma = torch.sub(noisy_latents, latents) #find noise as applied by scheduler + zeros = torch.zeros_like(sigma) + alpha_mean_sq = torch.nn.functional.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3]) #trick to get Mean Square/Second Moment + sigma_mean_sq = torch.nn.functional.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3]) #trick to get Mean Square/Second Moment + snr = torch.div(alpha_mean_sq,sigma_mean_sq) #Signal to Noise Ratio = ratio of Mean Squares + gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr) + snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper + loss = loss * snr_weight + print(snr_weight) + return loss + +def add_custom_train_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--min_snr_gamma", type=float, default=0, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper.") \ No newline at end of file diff --git a/library/train_util.py b/library/train_util.py index ffe81d693..a0e98cb12 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1963,7 +1963,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み" ) - parser.add_argument("--min_snr_gamma", type=float, default=0, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper.") + def verify_training_args(args: argparse.Namespace): if args.v_parameterization and not args.v2: diff --git a/train_db.py b/train_db.py index 527f8e9bc..ee9beda9d 100644 --- a/train_db.py +++ b/train_db.py @@ -21,7 +21,8 @@ ConfigSanitizer, BlueprintGenerator, ) - +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_snr_weight def collate_fn(examples): return examples[0] @@ -291,6 +292,9 @@ def train(args): loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights + if args.min_snr_gamma: + loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma) + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし accelerator.backward(loss) @@ -390,6 +394,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) parser.add_argument( "--no_token_padding", diff --git a/train_network.py b/train_network.py index dce706186..715da8c11 100644 --- a/train_network.py +++ b/train_network.py @@ -23,7 +23,8 @@ ConfigSanitizer, BlueprintGenerator, ) - +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_snr_weight def collate_fn(examples): return examples[0] @@ -548,16 +549,9 @@ def train(args): loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights - gamma = args.min_snr_gamma - if gamma: - sigma = torch.sub(noisy_latents, latents) #find noise as applied - zeros = torch.zeros_like(sigma) - alpha_mean_sq = torch.nn.functional.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3]) #trick to get Mean Square - sigma_mean_sq = torch.nn.functional.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3]) #trick to get Mean Square - snr = torch.div(alpha_mean_sq,sigma_mean_sq) #Signal to Noise Ratio = ratio of Mean Squares - gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr) - snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper - loss = loss * snr_weight + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし @@ -662,6 +656,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_training_arguments(parser, True) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない") parser.add_argument( diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 85f0d57c3..5fe662f6a 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -17,6 +17,8 @@ ConfigSanitizer, BlueprintGenerator, ) +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_snr_weight imagenet_templates_small = [ "a photo of a {}", @@ -377,6 +379,9 @@ def train(args): loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = loss.mean([1, 2, 3]) + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma) loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights @@ -534,6 +539,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_training_arguments(parser, True) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) parser.add_argument( "--save_model_as",