From 4f0457a996596561a7f2074ea33d44e34b645ca5 Mon Sep 17 00:00:00 2001 From: autumn <109412646+autumn-2-net@users.noreply.github.com> Date: Fri, 5 Apr 2024 20:14:03 +0800 Subject: [PATCH] add limit loss weights --- configs/RectifiedFlow_test.yaml | 3 ++- modules/losses/diff_loss.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/configs/RectifiedFlow_test.yaml b/configs/RectifiedFlow_test.yaml index bd19c775..61d6a74e 100644 --- a/configs/RectifiedFlow_test.yaml +++ b/configs/RectifiedFlow_test.yaml @@ -8,4 +8,5 @@ diffusion_type: 'RectifiedFlow' #ddpm diff_accelerator: 'rk4' # euler rk2 rk5 rk5_fp64 euler_fp64 rk4_fp64 rk2_fp64 diff_speedup: 100 -timestep_type: 'continuous' #discrete \ No newline at end of file +timestep_type: 'continuous' #discrete +loss_clip_min: 0.5 \ No newline at end of file diff --git a/modules/losses/diff_loss.py b/modules/losses/diff_loss.py index 50a46a24..05654992 100644 --- a/modules/losses/diff_loss.py +++ b/modules/losses/diff_loss.py @@ -3,6 +3,7 @@ import torch +from utils.hparams import hparams class DiffusionNoiseLoss(nn.Module): def __init__(self, loss_type): super().__init__() @@ -30,6 +31,7 @@ def l2_rf_norm(self, x_recon, noise, timestep): timestep=torch.clip(timestep, 0+eps, 1-eps) weights = 0.398942 / timestep / (1 - timestep) * torch.exp( -0.5 * torch.log(timestep / (1 - timestep)) ** 2) + eps + weights = torch.clip(weights, hparams['loss_clip_min'], ) return weights[:, None, None, None] * self.loss(x_recon, noise) def _forward(self, x_recon, noise, timestep=None):