Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/RectifiedFlow' into RectifiedFlow
Browse files Browse the repository at this point in the history
  • Loading branch information
yqzhishen committed Apr 5, 2024
2 parents 821b3dd + b908cb8 commit 7e5bd96
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 0 deletions.
4 changes: 4 additions & 0 deletions configs/RectifiedFlow_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,7 @@ T_start: 0
T_start_infer: 0
sampling_steps: 10
time_scale_factor: 1000

diff_speedup: 100
timestep_type: 'continuous' #discrete
loss_clip_min: 0.5
2 changes: 2 additions & 0 deletions modules/losses/diff_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch


from utils.hparams import hparams
class DiffusionNoiseLoss(nn.Module):
def __init__(self, loss_type):
super().__init__()
Expand Down Expand Up @@ -30,6 +31,7 @@ def l2_rf_norm(self, x_recon, noise, timestep):
timestep=torch.clip(timestep, 0+eps, 1-eps)
weights = 0.398942 / timestep / (1 - timestep) * torch.exp(
-0.5 * torch.log(timestep / (1 - timestep)) ** 2) + eps
weights = torch.clip(weights, hparams['loss_clip_min'], )
return weights[:, None, None, None] * self.loss(x_recon, noise)

def _forward(self, x_recon, noise, timestep=None):
Expand Down

0 comments on commit 7e5bd96

Please sign in to comment.