add DPM scheduler with EDM formulation #7120
Conversation
Generally looks good. Let's make sure to add "Copied from ..." statements wherever applicable.
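For reference, a minimal sketch of what such a statement looks like, reusing the `_threshold_sample` example that appears later in this diff (the stub body here is just a placeholder, not the PR's code):

```python
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
    ...
```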
Looks good in principle!
```python
            else:
                raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")

        if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
```
Should `deis` be accepted here as well?
Not sure; for now I followed the logic in the existing DPM scheduler.
I think we just map it to `dpmsolver++` here, so it should be fine; although I absolutely have no clue why "deis" is accepted as an algorithm type (I looked back at the original PR that added it too, and that did not help).
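For reference, a rough sketch of the mapping I mean, paraphrased from memory of the existing DPMSolverMultistepScheduler `__init__` (an assumption, not the exact code in this PR):

```python
# Paraphrased sketch: how the existing multistep scheduler handles "deis".
# Unknown algorithm types raise, but "deis" is silently remapped to "dpmsolver++".
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
    if algorithm_type == "deis":
        self.register_to_config(algorithm_type="dpmsolver++")
    else:
        raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
```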
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Looks good! I left a comment about the `_sigma_to_alpha_sigma_t` function - I think maybe we don't need it here and can simplify the math a little bit.
```python
        return sigmas

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
```
Does this work? Asking because we precondition the inputs, so `sample` has a different scale now.
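For context, a minimal sketch of the EDM-style input preconditioning being referred to (following the Karras et al. formulation; the function name and default `sigma_data` here are illustrative, not necessarily the PR's exact API):

```python
import torch

def precondition_inputs(sample: torch.Tensor, sigma: torch.Tensor, sigma_data: float = 0.5) -> torch.Tensor:
    # EDM c_in scaling: the sample fed to the UNet is divided by sqrt(sigma^2 + sigma_data^2),
    # so it no longer has the raw latent scale that dynamic thresholding was written for.
    c_in = 1.0 / (sigma**2 + sigma_data**2) ** 0.5
    return sample * c_in
```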
```python
        return t

    def _sigma_to_alpha_sigma_t(self, sigma):
        alpha_t = torch.tensor(1)  # Inputs are pre-scaled before going into unet, so alpha_t = 1
```
Ohh nice! That's why this formula still works here:

```python
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
```

Or maybe we do not need this function at all; just change the math directly inside the steps, as:

```python
x_t = (sigma_t / sigma_s) * sample - (torch.exp(-h) - 1.0) * model_output
```
```python
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s = torch.log(alpha_s) - torch.log(sigma_s)

        h = lambda_t - lambda_s
```
Suggested change:

```diff
-        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
-        alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
-        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
-        lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
-        h = lambda_t - lambda_s
+        h = torch.log(sigma_s) - torch.log(sigma_t)
```

I'm not 100% sure the math is correct here, but the idea is that we can calculate `h` from sigma directly.
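A quick sanity check of that idea, as a sketch: with alpha fixed to 1, `lambda = log(alpha) - log(sigma) = -log(sigma)`, so `h = lambda_t - lambda_s = log(sigma_s) - log(sigma_t)`.

```python
import torch

sigma_s, sigma_t = torch.tensor(2.0), torch.tensor(0.5)
alpha = torch.tensor(1.0)  # inputs are pre-scaled, so alpha_t = alpha_s = 1

# Current route through lambda.
lambda_t = torch.log(alpha) - torch.log(sigma_t)
lambda_s = torch.log(alpha) - torch.log(sigma_s)
h_via_lambda = lambda_t - lambda_s

# Proposed shortcut: compute h from the sigmas directly.
h_direct = torch.log(sigma_s) - torch.log(sigma_t)

assert torch.allclose(h_via_lambda, h_direct)
```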
```python
        h = lambda_t - lambda_s
        if self.config.algorithm_type == "dpmsolver++":
            x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
```
Suggested change:

```diff
-            x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
+            x_t = (sigma_t / sigma_s) * sample - (torch.exp(-h) - 1.0) * model_output
```
```python
        assert noise is not None
        x_t = (
            (sigma_t / sigma_s * torch.exp(-h)) * sample
            + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
```
Suggested change:

```diff
-            + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
+            + (1 - torch.exp(-2.0 * h)) * model_output
```
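Putting the two suggestions together, the first-order updates would read roughly as below once `alpha_t == 1` is folded out (a sketch only; the trailing noise term is written from memory of the existing scheduler, so double-check it against this PR):

```python
if self.config.algorithm_type == "dpmsolver++":
    x_t = (sigma_t / sigma_s) * sample - (torch.exp(-h) - 1.0) * model_output
elif self.config.algorithm_type == "sde-dpmsolver++":
    assert noise is not None
    x_t = (
        (sigma_t / sigma_s * torch.exp(-h)) * sample
        + (1 - torch.exp(-2.0 * h)) * model_output
        + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise  # assumed unchanged from the existing scheduler
    )
```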
```python
            self.sigmas[self.step_index - 1],
        )

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
```
same comments as #7120 (comment)
What does this PR do?