From 0afee96e0fcc5c7a108371916ccd5d3dd8870104 Mon Sep 17 00:00:00 2001
From: Chris Heald
Date: Sun, 7 Apr 2024 14:32:27 -0700
Subject: [PATCH] Add support for EDM2 timestep weighting network

---
 library/timestep_uncertainty.py | 110 ++++++++++++++++++++++++++++++++
 library/train_util.py           |  10 +++
 train_network.py                |  19 ++++++
 3 files changed, 139 insertions(+)
 create mode 100644 library/timestep_uncertainty.py

diff --git a/library/timestep_uncertainty.py b/library/timestep_uncertainty.py
new file mode 100644
index 000000000..c1e48e94b
--- /dev/null
+++ b/library/timestep_uncertainty.py
@@ -0,0 +1,110 @@
+# adapted from https://github.com/NVlabs/edm2/blob/3a6682d3d25395df64863d3cea563bf3f3380769/training/networks_edm2.py
+
+import torch
+import numpy as np
+import os
+from safetensors.torch import load_file
+
+#----------------------------------------------------------------------------
+# Normalize given tensor to unit magnitude with respect to the given
+# dimensions. Default = all dimensions except the first.
+
+def normalize(x, dim=None, eps=1e-4):
+    if dim is None:
+        dim = list(range(1, x.ndim))
+    norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
+    norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
+    return x / norm.to(x.dtype)
+
+
+class MPFourier(torch.nn.Module):
+    def __init__(self, num_channels, bandwidth=1):
+        super().__init__()
+        self.register_buffer('freqs', 2 * np.pi * torch.randn(num_channels) * bandwidth)
+        self.register_buffer('phases', 2 * np.pi * torch.rand(num_channels))
+
+    def forward(self, x):
+        y = x.to(torch.float32)
+        y = y.ger(self.freqs.to(torch.float32))
+        y = y + self.phases.to(torch.float32)
+        y = y.cos() * np.sqrt(2)
+        return y.to(x.dtype)
+
+class MPConv(torch.nn.Module):
+    def __init__(self, in_channels, out_channels, kernel):
+        super().__init__()
+        self.out_channels = out_channels
+        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))
+
+    def forward(self, x, gain=1):
+        w = self.weight.to(torch.float32)
+        if self.training:
+            with torch.no_grad():
+                self.weight.copy_(normalize(w)) # forced weight normalization
+        w = normalize(w) # traditional weight normalization
+        w = w * (gain / np.sqrt(w[0].numel())) # magnitude-preserving scaling
+        w = w.to(x.dtype)
+        if w.ndim == 2:
+            return x @ w.t()
+        assert w.ndim == 4
+        return torch.nn.functional.conv2d(x, w, padding=(w.shape[-1]//2,))
+
+class TimestepUncertaintyLossNetwork(torch.nn.Module):
+    def __init__(self,
+        logvar_channels = 128, # Intermediate dimensionality for uncertainty estimation.
+    ):
+        super().__init__()
+        self.logvar_fourier = MPFourier(logvar_channels)
+        self.logvar_linear = MPConv(logvar_channels, 1, kernel=[])
+
+    def forward(self, sigma):
+        c_noise = sigma.reshape(-1, 1, 1, 1).flatten().log() / 4
+        logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1)
+        return logvar
+
+    def loss(self, sigma, loss):
+        logvar = self.forward(sigma)
+        return loss / logvar.exp() + logvar
+
+    def load_weights(self, file, dtype=None):
+        if not os.path.exists(file):
+            print(f"WARNING: Could not load weights from '{file}' because the file does not exist.")
+            return
+
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import load_file
+            state_dict = load_file(file)
+        else:
+            state_dict = torch.load(file)
+
+        if dtype is not None:
+            for key in list(state_dict.keys()):
+                v = state_dict[key]
+                v = v.detach().clone().to(dtype)
+                state_dict[key] = v
+
+        self.load_state_dict(state_dict)
+
+    def save_weights(self, file, dtype=torch.float32, metadata=None):
+        metadata = {} if metadata is None else dict(metadata)
+
+        state_dict = self.state_dict()
+
+        if dtype is not None:
+            for key in list(state_dict.keys()):
+                v = state_dict[key]
+                v = v.detach().clone().to("cpu").to(dtype)
+                state_dict[key] = v
+
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import save_file
+            from library import train_util
+
+            # Precalculate model hashes to save time on indexing
+            model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
+            metadata["sshs_model_hash"] = model_hash
+            metadata["sshs_legacy_hash"] = legacy_hash
+
+            save_file(state_dict, file, metadata)
+        else:
+            torch.save(state_dict, file)
diff --git a/library/train_util.py b/library/train_util.py
index 8aada7b5f..f5d080929 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3374,6 +3374,16 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         default=None,
         help="If set, dynamically learn the value for `multires_noise_discount`. 7e-2..5e-2 is a good starting point",
     )
+    parser.add_argument(
+        "--sigma_uncertainty_model",
+        type=str,
+        default=None,
+    )
+    parser.add_argument(
+        "--train_sigma_uncertainty",
+        action="store_true",
+        help="Train the timestep (sigma) uncertainty network and save its weights back to --sigma_uncertainty_model"
+    )
 
     if support_dreambooth:
         # DreamBooth training
diff --git a/train_network.py b/train_network.py
index 99c79ca9a..faa575a8c 100644
--- a/train_network.py
+++ b/train_network.py
@@ -14,6 +14,7 @@
 import torch
 
 from library.device_utils import init_ipex, clean_memory_on_device
+from library.timestep_uncertainty import TimestepUncertaintyLossNetwork
 
 init_ipex()
 
@@ -368,6 +369,16 @@ def train(self, args):
                 network.get_parameter("noise_discount"),
             ], "lr": args.multires_discount_lr * args.gradient_accumulation_steps})
 
+        if args.sigma_uncertainty_model:
+            timestep_uncertainty_loss = TimestepUncertaintyLossNetwork().to(accelerator.device)
+            timestep_uncertainty_loss.load_weights(args.sigma_uncertainty_model)
+            if args.train_sigma_uncertainty:
+                timestep_uncertainty_loss.train()
+                # important that you don't have weight decay here, this model has forced weight norm and its weights will not grow in magnitude - drhead
+                trainable_params.append({"params": timestep_uncertainty_loss.parameters(), "lr": 1e-3, "weight_decay": 0.0})
+            else:
+                timestep_uncertainty_loss.eval()
+
         optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
 
         # dataloaderを準備する
@@ -812,6 +823,8 @@ def remove_model(old_ckpt_name):
         # For --sample_at_first
         self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
 
+        sigmas = ((1 - noise_scheduler.alphas_cumprod) / noise_scheduler.alphas_cumprod).sqrt().to(accelerator.device)
+
         # training loop
         for epoch in range(num_train_epochs):
             accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
@@ -915,6 +928,9 @@ def remove_model(old_ckpt_name):
                 if args.debiased_estimation_loss:
                     loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
 
+                if args.sigma_uncertainty_model:
+                    loss = timestep_uncertainty_loss.loss(sigmas[timesteps], loss)
+
                 loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
 
                 accelerator.backward(loss)
@@ -988,6 +1004,9 @@ def remove_model(old_ckpt_name):
                     ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
                     save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)
 
+                    if args.train_sigma_uncertainty:
+                        accelerator.unwrap_model(timestep_uncertainty_loss).save_weights(args.sigma_uncertainty_model)
+
                     remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
                     if remove_epoch_no is not None:
                         remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
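
Usage note (a sketch, not part of the diff above): TimestepUncertaintyLossNetwork implements the EDM2 uncertainty weighting, returning loss / exp(logvar) + logvar, where logvar is predicted from log(sigma) / 4 through a magnitude-preserving Fourier embedding and a linear head. The snippet below exercises the module on its own; the sigma values, tensor shapes, and output filename are illustrative, and running it assumes the sd-scripts root is on the import path so that library.train_util (used for safetensors hashing in save_weights) resolves.

    import torch
    from library.timestep_uncertainty import TimestepUncertaintyLossNetwork

    net = TimestepUncertaintyLossNetwork()             # logvar_channels defaults to 128
    sigma = torch.tensor([0.5, 1.0, 4.0])              # illustrative per-sample noise levels
    per_sample_loss = torch.rand(3, 4, 8, 8)           # stand-in for the unreduced MSE loss
    weighted = net.loss(sigma, per_sample_loss)        # loss / exp(logvar) + logvar, broadcast per sample
    weighted.mean().backward()                         # gradients flow into the logvar head
    net.save_weights("sigma_uncertainty.safetensors")  # illustrative output path

In the trainer the same call rescales the unreduced loss just before loss.mean(); when --train_sigma_uncertainty is set, the head's parameters join the optimizer with weight_decay=0.0 because MPConv relies on forced weight normalization rather than decay to keep its weight magnitude bounded.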
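
For reference, the sigmas table added before the training loop converts the scheduler's alphas_cumprod into the sigma parameterization the uncertainty network expects, sigma = sqrt((1 - alpha_bar_t) / alpha_bar_t). A minimal sketch of the same computation, assuming the diffusers DDPMScheduler; the beta schedule shown is the usual Stable Diffusion configuration and may differ from the one sd-scripts constructs:

    import torch
    from diffusers import DDPMScheduler

    scheduler = DDPMScheduler(
        num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
    )
    sigmas = ((1 - scheduler.alphas_cumprod) / scheduler.alphas_cumprod).sqrt()  # shape (1000,), grows with t
    timesteps = torch.randint(0, 1000, (4,))  # per-sample timesteps, as drawn in the training loop
    print(sigmas[timesteps])                  # the per-sample sigma passed to timestep_uncertainty_loss.loss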