diff --git a/configs/RectifiedFlow_test.yaml b/configs/RectifiedFlow_test.yaml index af55342b..d0e8fae2 100644 --- a/configs/RectifiedFlow_test.yaml +++ b/configs/RectifiedFlow_test.yaml @@ -8,3 +8,7 @@ T_start: 0 T_start_infer: 0 sampling_steps: 10 time_scale_factor: 1000 + +diff_speedup: 100 +timestep_type: 'continuous' #discrete +loss_clip_min: 0.5 \ No newline at end of file diff --git a/modules/losses/diff_loss.py b/modules/losses/diff_loss.py index 50a46a24..05654992 100644 --- a/modules/losses/diff_loss.py +++ b/modules/losses/diff_loss.py @@ -3,6 +3,7 @@ import torch +from utils.hparams import hparams class DiffusionNoiseLoss(nn.Module): def __init__(self, loss_type): super().__init__() @@ -30,6 +31,7 @@ def l2_rf_norm(self, x_recon, noise, timestep): timestep=torch.clip(timestep, 0+eps, 1-eps) weights = 0.398942 / timestep / (1 - timestep) * torch.exp( -0.5 * torch.log(timestep / (1 - timestep)) ** 2) + eps + weights = torch.clip(weights, hparams['loss_clip_min'], ) return weights[:, None, None, None] * self.loss(x_recon, noise) def _forward(self, x_recon, noise, timestep=None):