Soft min SNR gamma #1068
Changes from 2 commits
@@ -68,6 +68,13 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False
    return loss


def apply_soft_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
    soft_min_snr_gamma_weight = 1 / (torch.pow(snr if v_prediction is False else snr + 1, 2) + (1 / float(gamma)))
    loss = loss * soft_min_snr_gamma_weight
    return loss


def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
    scale = get_snr_scale(timesteps, noise_scheduler)
    loss = loss * scale
@@ -106,6 +113,12 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
        default=None,
        help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
    )
    parser.add_argument(
        "--soft_min_snr_gamma",
        type=float,
        default=None,
        help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
we don't recommend
Updated to recommend 1. I mistakenly copied the help text from the min_snr_gamma option. Thank you!
    )
    parser.add_argument(
        "--scale_v_pred_loss_like_noise_pred",
        action="store_true",
The math here is incorrect. SNR is equal to the whole expression 1/sigma**2, not sigma. (At least, that follows from the fact that here they use Min(1/sigma**2, gamma) while the min-snr paper uses Min(SNR, gamma). The variable names are inconsistent between papers, so I don't blame you for getting them confused.) The correct weight should be:
Finally, the given formulation for soft-min-snr is for x_0 prediction. We use epsilon or v-prediction, which according to the original min-snr paper means we need to divide by SNR or SNR+1 respectively, so the final weight calculation should be:
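The reviewer's own snippets are not preserved in this copy of the thread. As a minimal sketch of what the comment describes, assuming the soft-min-snr weight for x_0 prediction is 1 / (sigma**2 + 1/gamma) with sigma**2 = 1/SNR, the per-timestep weights could look like the following. The function name `soft_min_snr_weight` and its `prediction` argument are hypothetical and are not part of this PR or of the repository.

```python
def soft_min_snr_weight(snr: float, gamma: float, prediction: str = "epsilon") -> float:
    """Hypothetical helper: soft-min-SNR loss weight for a single timestep.

    Assumes the x_0-prediction weight is 1 / (sigma**2 + 1/gamma) with
    sigma**2 = 1/SNR, as the review comment argues.
    """
    if snr <= 0 or gamma <= 0:
        raise ValueError("snr and gamma must be positive")

    # x_0-prediction form; algebraically equal to snr * gamma / (snr + gamma).
    x0_weight = 1.0 / (1.0 / snr + 1.0 / gamma)

    if prediction == "x0":
        return x0_weight
    if prediction == "epsilon":
        # Divide by SNR for noise prediction: gamma / (snr + gamma).
        return x0_weight / snr
    if prediction == "v":
        # Divide by SNR + 1 for v-prediction.
        return x0_weight / (snr + 1.0)
    raise ValueError(f"unknown prediction type: {prediction!r}")
```

A tensor version would apply the same arithmetic elementwise to the stacked `snr` values. Whether this matches the reviewer's exact code cannot be confirmed from the text here.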
I tried this formula (run 38 below), and it produced the same loss curve as the current implementation. The paper says the two weightings should match except near the transition, so the loss curves should be similar. I am still seeing a difference from the Min SNR version, though.
38 = weight = snr * gamma / (snr + gamma)
37 = the current PR version
35 = Min SNR version.
(38 and 37 are overlapping in the graph)

Maybe there's something else that is missing in these calculations.
Here are some example snr, gamma values to use with the following test script: snr.txt
I am using this code to test the formulas.
I don't think I fully understand the math yet, but I'm trying to learn.
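The test script and snr.txt referred to above are not included in this copy. As a rough stand-in, the sketch below compares the epsilon-prediction weights numerically, assuming min-SNR uses min(SNR, gamma) / SNR and soft-min-SNR uses gamma / (SNR + gamma). The function names and the sample SNR values are illustrative assumptions, not data from the PR.

```python
def min_snr_weight_eps(snr: float, gamma: float) -> float:
    """Min-SNR weight for epsilon prediction (assumed form): min(SNR, gamma) / SNR."""
    return min(snr, gamma) / snr


def soft_min_snr_weight_eps(snr: float, gamma: float) -> float:
    """Soft-min-SNR weight for epsilon prediction (assumed form): gamma / (SNR + gamma)."""
    return gamma / (snr + gamma)


def weight_ratio(snr: float, gamma: float) -> float:
    """Soft weight divided by hard weight; 1.0 means the two agree."""
    return soft_min_snr_weight_eps(snr, gamma) / min_snr_weight_eps(snr, gamma)


if __name__ == "__main__":
    gamma = 5.0
    # Illustrative SNR values spanning both sides of the transition at SNR == gamma.
    print(f"{'snr':>10} {'min_snr':>10} {'soft':>10} {'ratio':>8}")
    for snr in (0.01, 0.1, 1.0, 5.0, 25.0, 100.0, 1000.0):
        hard = min_snr_weight_eps(snr, gamma)
        soft = soft_min_snr_weight_eps(snr, gamma)
        print(f"{snr:>10.2f} {hard:>10.4f} {soft:>10.4f} {weight_ratio(snr, gamma):>8.3f}")
```

Under these assumptions the two weights converge when SNR is far from gamma on either side, and the soft weight is half the hard weight exactly at SNR == gamma. If that holds for the real schedule, some gap between the soft and Min SNR loss curves would be expected for timesteps near the transition.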