Add EDMEulerScheduler #7109
s_churn (`float`):
s_tmin (`float`):
s_tmax (`float`):
Would be nice to have docstrings.
return EDMEulerSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

def add_noise(
👌
if self.config.prediction_type == "epsilon":
    c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
elif self.config.prediction_type == "v_prediction":
    c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
else:
    raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.")
I guess users will have to do this bit manually during training when experimenting with different prediction types?
Design looks very nice to me!
Looking good to me :)
Question: for `v_prediction`, does it work the same as what's currently implemented in `EulerDiscreteScheduler`?
@yiyixuxu Yes, that's right. Actually, I'm not sure if we should put it here; it's used in SVD and follows EDM as well.
I think it's okay to put it there, as it helps the community experiment with different prediction types and give us feedback if needed. We can iterate on top of that.
Co-authored-by: Daniel Gu dgu8957@gmail.com
* Add EDMEulerScheduler
* address review comments
* fix import
* fix test
* add tests
* add co-author

Co-authored-by: @dg845 dgu8957@gmail.com
What does this PR do?
This PR is based on #4481 by @dg845 !
This PR adds an Euler scheduler with the EDM formulation. The difference between this and `EulerDiscreteScheduler` is that `EulerDiscreteScheduler` was essentially designed for DDPM models to use Euler-style sampling algorithms with discrete timesteps. `EDMEulerScheduler` follows the EDM formulation as closely as possible and is intended solely for models that use the EDM formulation, like SVD. It does not support epsilon scaling, as that's already covered by `EulerDiscreteScheduler`. So models that still use DDPM (the alphas and betas schedules) should keep using `EulerDiscreteScheduler`.
Why add a new scheduler class?
While it is possible to support this in `EulerDiscreteScheduler` by introducing an argument, maybe called `scaling_type`, which could correspond to `epsilon (ddpm)`, `v-scaling` or `edm`, the issue is that:
* In `edm`, the model is conditioned on continuous noise scales rather than discrete timesteps.
* `EulerDiscreteScheduler` contains logic that is not required for `EDM`, such as computing betas, rescaling for zero terminal SNR, interpolating sigmas, etc.

Hence, with the current scheduler API, if we support the full EDM formulation in `EulerDiscreteScheduler`, the code will be confusing to follow, with lots of `if/else` branches.

API
These pre-conditioning methods can be used during training to easily scale the input/output based on sigma values, which the user is free to sample in any way they want during training.
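To make the idea of sigma-based input/output scaling concrete, here is a minimal scalar sketch of the preconditioning coefficients as defined in the EDM paper (Karras et al., 2022). The function names `edm_preconditioning` and `denoise` are hypothetical and are not the methods added by this PR. The default `sigma_data=0.5` is the paper's default, assumed here. The sketch uses the positive-sign `c_out` that the reviewed snippet shows for the `"epsilon"` branch, and plain Python floats rather than tensors.

```python
import math


def edm_preconditioning(sigma: float, sigma_data: float = 0.5) -> dict:
    """Scalar EDM preconditioning coefficients (hypothetical helper, not the PR's API)."""
    denom = sigma**2 + sigma_data**2
    return {
        "c_skip": sigma_data**2 / denom,                 # weight of the skip connection
        "c_out": sigma * sigma_data / math.sqrt(denom),  # scale of the raw model output
        "c_in": 1.0 / math.sqrt(denom),                  # scale of the noisy model input
        "c_noise": 0.25 * math.log(sigma),               # noise-level conditioning value
    }


def denoise(model, noisy_sample: float, sigma: float, sigma_data: float = 0.5) -> float:
    """Combine the raw model output into a denoised estimate.

    `model` is any callable taking (scaled_input, noise_conditioning).
    """
    c = edm_preconditioning(sigma, sigma_data)
    raw = model(c["c_in"] * noisy_sample, c["c_noise"])
    return c["c_skip"] * noisy_sample + c["c_out"] * raw
```

During training, the same coefficients would be applied to a batch of sampled sigmas; with tensors, they typically need to be broadcast to the sample's shape.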
We could also consider adding a method to help with sigma sampling as per here
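The link target above is not preserved here, so the following is only an illustration of one common choice: the EDM paper draws training noise levels from a log-normal distribution and weights the loss by a function of sigma. The helper names `sample_sigmas` and `edm_loss_weight` are hypothetical, and the defaults (`p_mean=-1.2`, `p_std=1.2`, `sigma_data=0.5`) are the paper's defaults, assumed rather than taken from this PR.

```python
import math
import random


def sample_sigmas(n: int, p_mean: float = -1.2, p_std: float = 1.2, seed=None) -> list:
    """Draw n noise levels with ln(sigma) ~ Normal(p_mean, p_std**2) (hypothetical helper)."""
    rng = random.Random(seed)
    return [math.exp(rng.gauss(p_mean, p_std)) for _ in range(n)]


def edm_loss_weight(sigma: float, sigma_data: float = 0.5) -> float:
    """Per-sample loss weight from the EDM paper: (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2."""
    return (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2


# One sigma per training example, each with its matching loss weight.
sigmas = sample_sigmas(4, seed=0)
weights = [edm_loss_weight(s) for s in sigmas]
```

Since the user is free to sample sigmas in any way during training, a helper like this would be a convenience rather than a requirement.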
The rest of the API follows the existing scheduler API.
I would love to have your thoughts here @yiyixuxu @sayakpaul @dhruvrnaik @dg845